Skip to content

Commit

Permalink
Migrate llama_classification to use the /classify interface (#2417)
Browse files Browse the repository at this point in the history
  • Loading branch information
merrymercy committed Dec 9, 2024
1 parent 3844feb commit 835f8af
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 25 deletions.
33 changes: 11 additions & 22 deletions python/sglang/srt/models/llama_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from torch import nn
from transformers import LlamaConfig

from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
Expand All @@ -40,7 +40,7 @@ def __init__(
self.classification_head = nn.Linear(
config.hidden_size, config.classification_out_size, bias=False
)
self.eos_token_id = config.eos_token_id
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=False)

@torch.no_grad()
def forward(
Expand All @@ -49,28 +49,17 @@ def forward(
positions: torch.Tensor,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
is_eos_token = input_ids == self.eos_token_id
hidden_states = hidden_states[is_eos_token]
scores = self.classification_head(hidden_states)

if scores.shape[0] != forward_batch.batch_size:
print("Warning: the EOS tokens are missing in some sentences.")
scores = torch.ones(
(forward_batch.batch_size, self.config.classification_out_size)
).to(input_ids.device)
get_embedding: bool = True,
) -> EmbeddingPoolerOutput:
assert (
get_embedding
), "LlamaForClassification is only used for embedding. Please add --is-embedding when you launch the server."

logits_output = LogitsProcessorOutput(
next_token_logits=scores,
next_token_logprobs=scores,
normalized_prompt_logprobs=scores,
input_token_logprobs=torch.ones_like(input_ids),
input_top_logprobs=None,
output_top_logprobs=None,
)
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
last_token_hidden = self.pooler(hidden_states, forward_batch).embeddings
scores = self.classification_head(last_token_hidden)

return logits_output
return EmbeddingPoolerOutput(scores)

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters())
Expand Down
22 changes: 19 additions & 3 deletions scripts/deprecated/test_httpserver_classify.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""
Usage:
python3 -m sglang.launch_server --disable-cuda-graph --model-path /model/llama-classification
python3 -m sglang.launch_server --model-path /model/llama-classification --is-embedding --disable-radix-cache
python3 test_httpserver_classify.py
"""
Expand All @@ -11,7 +11,7 @@
import requests


def get_logits(url, prompt):
def get_logits_deprecated(url: str, prompt: str):
response = requests.post(
url + "/generate",
json={
Expand All @@ -25,7 +25,7 @@ def get_logits(url, prompt):
return response.json()["meta_info"]["normalized_prompt_logprob"]


def get_logits_batch(url, prompts):
def get_logits_batch_deprecated(url: str, prompts: list[str]):
response = requests.post(
url + "/generate",
json={
Expand All @@ -46,6 +46,22 @@ def get_logits_batch(url, prompts):
return logits


def get_logits(url: str, prompt: str):
response = requests.post(
url + "/classify",
json={"text": prompt},
)
return response.json()["embedding"]


def get_logits_batch(url: str, prompts: list[str]):
response = requests.post(
url + "/classify",
json={"text": prompts},
)
return np.array([x["embedding"] for x in response.json()])


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="http://127.0.0.1")
Expand Down

0 comments on commit 835f8af

Please sign in to comment.