From 1ad76cd4e04d2348fe79ed048bbceeec52857bec Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sun, 8 Dec 2024 23:30:09 -0800 Subject: [PATCH] Migrate llama_classification to use the /classify interface --- .../sglang/srt/models/llama_classification.py | 33 +++++++------------ .../deprecated/test_httpserver_classify.py | 22 +++++++++++-- 2 files changed, 30 insertions(+), 25 deletions(-) diff --git a/python/sglang/srt/models/llama_classification.py b/python/sglang/srt/models/llama_classification.py index 038732476ed..c4ee76379b6 100644 --- a/python/sglang/srt/models/llama_classification.py +++ b/python/sglang/srt/models/llama_classification.py @@ -18,7 +18,7 @@ from torch import nn from transformers import LlamaConfig -from sglang.srt.layers.logits_processor import LogitsProcessorOutput +from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader @@ -40,7 +40,7 @@ def __init__( self.classification_head = nn.Linear( config.hidden_size, config.classification_out_size, bias=False ) - self.eos_token_id = config.eos_token_id + self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=False) @torch.no_grad() def forward( @@ -49,28 +49,17 @@ def forward( positions: torch.Tensor, forward_batch: ForwardBatch, input_embeds: torch.Tensor = None, - ) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) - is_eos_token = input_ids == self.eos_token_id - hidden_states = hidden_states[is_eos_token] - scores = self.classification_head(hidden_states) - - if scores.shape[0] != forward_batch.batch_size: - print("Warning: the EOS tokens are missing in some sentences.") - scores = torch.ones( - (forward_batch.batch_size, self.config.classification_out_size) - ).to(input_ids.device) + get_embedding: bool = True, + ) -> EmbeddingPoolerOutput: + assert ( + get_embedding + ), "LlamaForClassification is only used for embedding. Please add --is-embedding when you launch the server." - logits_output = LogitsProcessorOutput( - next_token_logits=scores, - next_token_logprobs=scores, - normalized_prompt_logprobs=scores, - input_token_logprobs=torch.ones_like(input_ids), - input_top_logprobs=None, - output_top_logprobs=None, - ) + hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) + last_token_hidden = self.pooler(hidden_states, forward_batch).embeddings + scores = self.classification_head(last_token_hidden) - return logits_output + return EmbeddingPoolerOutput(scores) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): params_dict = dict(self.named_parameters()) diff --git a/scripts/deprecated/test_httpserver_classify.py b/scripts/deprecated/test_httpserver_classify.py index dbcafb88d7d..cb88802999a 100644 --- a/scripts/deprecated/test_httpserver_classify.py +++ b/scripts/deprecated/test_httpserver_classify.py @@ -1,6 +1,6 @@ """ Usage: -python3 -m sglang.launch_server --disable-cuda-graph --model-path /model/llama-classification +python3 -m sglang.launch_server --model-path /model/llama-classification --is-embedding --disable-radix-cache python3 test_httpserver_classify.py """ @@ -11,7 +11,7 @@ import requests -def get_logits(url, prompt): +def get_logits_deprecated(url: str, prompt: str): response = requests.post( url + "/generate", json={ @@ -25,7 +25,7 @@ def get_logits(url, prompt): return response.json()["meta_info"]["normalized_prompt_logprob"] -def get_logits_batch(url, prompts): +def get_logits_batch_deprecated(url: str, prompts: list[str]): response = requests.post( url + "/generate", json={ @@ -46,6 +46,22 @@ def get_logits_batch(url, prompts): return logits +def get_logits(url: str, prompt: str): + response = requests.post( + url + "/classify", + json={"text": prompt}, + ) + return response.json()["embedding"] + + +def get_logits_batch(url: str, prompts: list[str]): + response = requests.post( + url + "/classify", + json={"text": prompts}, + ) + return np.array([x["embedding"] for x in response.json()]) + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--host", type=str, default="http://127.0.0.1")