Fix the chat template for llava-v1.6-34b & format code (#177)

merrymercy authored Feb 11, 2024
1 parent 50afed4 commit c51020c

Showing 23 changed files with 101 additions and 44 deletions.
1 change: 1 addition & 0 deletions python/sglang/api.py
@@ -1,4 +1,5 @@
"""Public API"""

import re
from typing import Callable, List, Optional, Union

33 changes: 24 additions & 9 deletions python/sglang/backend/runtime_endpoint.py
@@ -19,7 +19,9 @@ def __init__(self, base_url, auth_token=None):
self.base_url = base_url
self.auth_token = auth_token

res = http_request(self.base_url + "/get_model_info", auth_token=self.auth_token)
res = http_request(
self.base_url + "/get_model_info", auth_token=self.auth_token
)
assert res.status_code == 200
self.model_info = res.json()

@@ -37,22 +39,24 @@ def cache_prefix(self, prefix_str: str):
res = http_request(
self.base_url + "/generate",
json={"text": prefix_str, "sampling_params": {"max_new_tokens": 0}},
auth_token=self.auth_token
auth_token=self.auth_token,
)
assert res.status_code == 200

def commit_lazy_operations(self, s: StreamExecutor):
res = http_request(
self.base_url + "/generate",
json={"text": s.text_, "sampling_params": {"max_new_tokens": 0}},
auth_token=self.auth_token
auth_token=self.auth_token,
)
assert res.status_code == 200

def fill_image(self, s: StreamExecutor):
data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
self._add_images(s, data)
res = http_request(self.base_url + "/generate", json=data, auth_token=self.auth_token)
res = http_request(
self.base_url + "/generate", json=data, auth_token=self.auth_token
)
assert res.status_code == 200

def generate(
@@ -82,7 +86,9 @@ def generate(

self._add_images(s, data)

res = http_request(self.base_url + "/generate", json=data, auth_token=self.auth_token)
res = http_request(
self.base_url + "/generate", json=data, auth_token=self.auth_token
)
obj = res.json()
comp = obj["text"]
return comp, obj["meta_info"]
@@ -115,7 +121,12 @@ def generate_stream(
data["stream"] = True
self._add_images(s, data)

response = http_request(self.base_url + "/generate", json=data, stream=True, auth_token=self.auth_token)
response = http_request(
self.base_url + "/generate",
json=data,
stream=True,
auth_token=self.auth_token,
)
pos = 0

incomplete_text = ""
@@ -145,7 +156,9 @@ def select(
# Cache common prefix
data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
self._add_images(s, data)
res = http_request(self.base_url + "/generate", json=data, auth_token=self.auth_token)
res = http_request(
self.base_url + "/generate", json=data, auth_token=self.auth_token
)
assert res.status_code == 200
prompt_len = res.json()["meta_info"]["prompt_tokens"]

@@ -157,7 +170,9 @@ def select(
"logprob_start_len": max(prompt_len - 2, 0),
}
self._add_images(s, data)
res = http_request(self.base_url + "/generate", json=data, auth_token=self.auth_token)
res = http_request(
self.base_url + "/generate", json=data, auth_token=self.auth_token
)
assert res.status_code == 200
obj = res.json()
normalized_prompt_logprob = [
@@ -172,7 +187,7 @@ def concatenate_and_append(self, src_rids: List[str], dst_rid: str):
res = http_request(
self.base_url + "/concate_and_append_request",
json={"src_rids": src_rids, "dst_rid": dst_rid},
auth_token=self.auth_token
auth_token=self.auth_token,
)
assert res.status_code == 200

19 changes: 18 additions & 1 deletion python/sglang/lang/chat_template.py
@@ -116,6 +116,21 @@ def get_chat_template_by_model_path(model_path):
)


register_chat_template(
ChatTemplate(
name="chatml-llava",
default_system_prompt="Answer the questions.",
role_prefix_and_suffix={
"system": ("<|im_start|>system\n", "\n<|im_end|>\n"),
"user": ("<|im_start|>user\n", "\n<|im_end|>\n"),
"assistant": ("<|im_start|>assistant\n", "\n<|im_end|>\n"),
},
style=ChatTemplateStyle.PLAIN,
stop_str=("<|im_end|>",),
image_token=" <image>\n",
)
)
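Illustrative sketch (not part of the commit): composing a chatml-framed prompt by hand from the role prefix/suffix pairs registered above, to show the framing the chatml-llava template targets. All names below are local to this example.

# Build a prompt in the chatml framing used by the "chatml-llava" template.
roles = {
    "system": ("<|im_start|>system\n", "\n<|im_end|>\n"),
    "user": ("<|im_start|>user\n", "\n<|im_end|>\n"),
    "assistant": ("<|im_start|>assistant\n", "\n<|im_end|>\n"),
}

def wrap(role, content):
    prefix, suffix = roles[role]
    return prefix + content + suffix

prompt = (
    wrap("system", "Answer the questions.")
    + wrap("user", "<image>\nWhat is shown in this image?")
    + roles["assistant"][0]  # leave the assistant turn open for generation
)
print(prompt)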

register_chat_template(
ChatTemplate(
name="vicuna_v1.1",
@@ -168,7 +183,7 @@ def get_chat_template_by_model_path(model_path):
def match_vicuna(model_path: str):
if "vicuna" in model_path.lower():
return get_chat_template("vicuna_v1.1")
if "llava" in model_path.lower():
if "llava-v1.5" in model_path.lower():
return get_chat_template("vicuna_v1.1")


@@ -192,6 +207,8 @@ def match_chat_ml(model_path: str):
return get_chat_template("chatml")
if "qwen" in model_path and "chat" in model_path:
return get_chat_template("chatml")
if "llava-v1.6-34b" in model_path:
return get_chat_template("chatml-llava")


@register_chat_template_matching_function
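Quick check (illustrative, not part of the commit): with the matcher change above, a model path containing "llava-v1.6-34b" should now resolve to the new chatml-llava template, while llava-v1.5 paths continue to resolve to vicuna_v1.1. The paths below are hypothetical examples, and the expected names assume the template registry resolves matchers as the hunks above suggest.

from sglang.lang.chat_template import get_chat_template_by_model_path

# Expected to hit the new rule in match_chat_ml and return the chatml-llava template.
print(get_chat_template_by_model_path("liuhaotian/llava-v1.6-34b").name)

# Expected to keep matching the llava-v1.5 rule in match_vicuna (vicuna_v1.1).
print(get_chat_template_by_model_path("liuhaotian/llava-v1.5-7b").name)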
6 changes: 3 additions & 3 deletions python/sglang/lang/ir.py
@@ -74,9 +74,9 @@ def to_anthropic_kwargs(self):
)
return {
"max_tokens_to_sample": self.max_new_tokens,
"stop_sequences": self.stop
if isinstance(self.stop, (list, tuple))
else [self.stop],
"stop_sequences": (
self.stop if isinstance(self.stop, (list, tuple)) else [self.stop]
),
"temperature": self.temperature,
"top_p": self.top_p,
"top_k": self.top_k,
1 change: 1 addition & 0 deletions python/sglang/lang/tracer.py
@@ -1,4 +1,5 @@
"""Tracing a program."""

import uuid
from typing import Any, Callable, Dict, List, Optional, Union

1 change: 1 addition & 0 deletions python/sglang/srt/backend_config.py
@@ -1,6 +1,7 @@
"""
Backend configurations, may vary with different serving platforms.
"""

from dataclasses import dataclass


3 changes: 2 additions & 1 deletion python/sglang/srt/conversation.py
@@ -366,7 +366,8 @@ def generate_chat_conv(
if content.type == "text":
real_content += content.text
elif content.type == "image_url":
real_content += "<image>"
# NOTE: Only works for llava
real_content += "<image>\n"
conv.append_image(content.image_url.url)
conv.append_message(conv.roles[0], real_content)
elif msg_role == "assistant":
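For context, a hedged sketch (not part of the commit) of an OpenAI-style chat request that exercises this code path. It assumes the SRT server exposes its OpenAI-compatible /v1/chat/completions route on a locally launched instance; the model name, port, and image URL are placeholders. generate_chat_conv above renders the image part as an appended "<image>\n" marker and attaches the image URL separately.

import requests

payload = {
    "model": "llava-v1.6-34b",  # placeholder model name
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What is shown in this image?"},
                {"type": "image_url", "image_url": {"url": "https://example.com/cat.jpg"}},
            ],
        }
    ],
}
res = requests.post("http://localhost:30000/v1/chat/completions", json=payload)
print(res.json())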
14 changes: 10 additions & 4 deletions python/sglang/srt/managers/router/model_rpc.py
14 changes: 10 additions & 4 deletions python/sglang/srt/managers/router/model_rpc.py
@@ -31,6 +31,7 @@
is_multimodal_model,
set_random_seed,
)
from vllm.logger import _default_handler as vllm_default_handler

logger = logging.getLogger("model_rpc")

@@ -50,6 +51,9 @@ def exposed_init_model(
self.tp_size = server_args.tp_size
self.schedule_heuristic = server_args.schedule_heuristic
self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
vllm_default_handler.setLevel(
level=getattr(logging, server_args.log_level.upper())
)

# Init model and tokenizer
self.model_config = ModelConfig(
@@ -83,9 +87,11 @@ def exposed_init_model(
self.max_num_running_seq = self.max_total_num_token // 2
self.max_prefill_num_token = max(
self.model_config.context_len,
self.max_total_num_token // 6
if server_args.max_prefill_num_token is None
else server_args.max_prefill_num_token,
(
self.max_total_num_token // 6
if server_args.max_prefill_num_token is None
else server_args.max_prefill_num_token
),
)
self.int_token_logit_bias = torch.tensor(
get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
@@ -534,7 +540,7 @@ def handle_finished_requests(self, batch: Batch):
output_skip_special_tokens.append(
req.sampling_params.skip_special_tokens
)

# For the length of input_ids, which will be accumulated during jump-forward.
# Use the original length of input_ids to calculate the token usage info.
meta_info = {
12 changes: 9 additions & 3 deletions python/sglang/srt/managers/router/model_runner.py
@@ -112,7 +112,9 @@ def init_flashinfer_args(self, tp_size):
(self.batch_size,), dtype=torch.int32, device="cuda"
)

workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8, device="cuda")
workspace_buffer = torch.empty(
32 * 1024 * 1024, dtype=torch.int8, device="cuda"
)
if (
self.forward_mode == ForwardMode.PREFILL
or self.forward_mode == ForwardMode.EXTEND
@@ -121,7 +123,9 @@ def init_flashinfer_args(self, tp_size):
(self.batch_size + 1,), dtype=torch.int32, device="cuda"
)
self.qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)
self.prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(workspace_buffer, "NHD")
self.prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
workspace_buffer, "NHD"
)
self.prefill_wrapper.begin_forward(
self.qo_indptr,
self.kv_indptr,
@@ -131,7 +135,9 @@ def init_flashinfer_args(self, tp_size):
self.model_runner.model_config.num_key_value_heads // tp_size,
)
else:
self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD")
self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
workspace_buffer, "NHD"
)
self.decode_wrapper.begin_forward(
self.kv_indptr,
self.kv_indices,
1 change: 1 addition & 0 deletions python/sglang/srt/memory_pool.py
@@ -1,4 +1,5 @@
"""Memory pool."""

import logging

import torch
2 changes: 1 addition & 1 deletion python/sglang/srt/models/llava.py
@@ -1,4 +1,5 @@
"""Inference-only LLaVa model compatible with HuggingFace weights."""

from typing import List, Optional

import numpy as np
@@ -269,7 +270,6 @@ def load_weights(
raise ValueError(f"Unexpected select feature: {self.select_feature}")

# load mm_projector
# TODO: support TP?
projector_weights = {
"model.mm_projector.0": "multi_modal_projector.linear_1",
"model.mm_projector.2": "multi_modal_projector.linear_2",
1 change: 1 addition & 0 deletions python/sglang/srt/models/mistral.py
@@ -1,4 +1,5 @@
"""Inference-only Mistral model."""

from sglang.srt.models.llama2 import LlamaForCausalLM


16 changes: 9 additions & 7 deletions python/sglang/srt/models/mixtral.py
@@ -97,14 +97,16 @@ def __init__(

self.experts = nn.ModuleList(
[
MixtralMLP(
self.num_total_experts,
config.hidden_size,
config.intermediate_size,
linear_method=linear_method,
(
MixtralMLP(
self.num_total_experts,
config.hidden_size,
config.intermediate_size,
linear_method=linear_method,
)
if idx in self.expert_indicies
else None
)
if idx in self.expert_indicies
else None
for idx in range(self.num_total_experts)
]
)
1 change: 1 addition & 0 deletions python/sglang/srt/models/yivl.py
@@ -1,4 +1,5 @@
"""Inference-only Yi-VL model."""

import os
from typing import List, Optional

1 change: 1 addition & 0 deletions python/sglang/srt/sampling_params.py
@@ -1,4 +1,5 @@
"""Sampling parameters for text generation."""

from typing import List, Optional, Union

_SAMPLING_EPS = 1e-6
7 changes: 4 additions & 3 deletions python/sglang/srt/server.py
@@ -1,4 +1,5 @@
"""SRT: SGLang Runtime"""

import asyncio
import json
import multiprocessing as mp
@@ -493,7 +494,7 @@ def _launch_server():

# Warmup
try:
print("Warmup...", flush=True)
# print("Warmup...", flush=True)
res = requests.post(
url + "/generate",
json={
@@ -505,8 +506,8 @@ def _launch_server():
},
timeout=60,
)
print(f"Warmup done. model response: {res.json()['text']}")
print("=" * 20, "Server is ready", "=" * 20, flush=True)
# print(f"Warmup done. model response: {res.json()['text']}")
# print("=" * 20, "Server is ready", "=" * 20, flush=True)
except requests.exceptions.RequestException as e:
if pipe_finish_writer is not None:
pipe_finish_writer.send(str(e))
4 changes: 1 addition & 3 deletions python/sglang/srt/utils.py
@@ -122,7 +122,7 @@ def handle_port_init(
# first check on server port
if not check_port(port):
new_port = alloc_usable_network_port(1, used_list=[port])[0]
print(f"Port {port} is not available, using {new_port} instead.")
print(f"WARNING: Port {port} is not available. Use {new_port} instead.")
port = new_port

# then we check on additional ports
@@ -157,8 +157,6 @@ def get_int_token_logit_bias(tokenizer, vocab_size):
ss = tokenizer.decode([t_id]).strip()
if not (ss.isdigit() or len(ss) == 0 or t_id == tokenizer.eos_token_id):
logit_bias[t_id] = -1e5
# else:
# print(ss, t_id)

return logit_bias

1 change: 1 addition & 0 deletions python/sglang/test/test_utils.py
@@ -1,4 +1,5 @@
"""Common utilities for testing and benchmarking"""

import numpy as np
import requests
from sglang.backend.openai import OpenAI
4 changes: 2 additions & 2 deletions python/sglang/utils.py
@@ -22,7 +22,7 @@ def get_available_gpu_memory(gpu_id, distributed=True):

if torch.cuda.current_device() != gpu_id:
print(
f"WARN: current device is not {gpu_id}, but {torch.cuda.current_device()}, ",
f"WARNING: current device is not {gpu_id}, but {torch.cuda.current_device()}, ",
"which may cause useless memory allocation for torch CUDA context.",
)

@@ -95,7 +95,7 @@ def http_request(url, json=None, stream=False, auth_token=None):
return requests.post(url, json=json, stream=True)
headers = {
"Content-Type": "application/json",
"Authentication": f"Bearer {auth_token}"
"Authentication": f"Bearer {auth_token}",
}
return requests.post(url, json=json, stream=True, headers=headers)
else:
1 change: 1 addition & 0 deletions test/lang/test_srt_backend.py
@@ -1,6 +1,7 @@
"""
python3 -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
"""

import json
import unittest
