diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py
index b398e052da7..f4cbb2de2d4 100644
--- a/python/sglang/srt/layers/logits_processor.py
+++ b/python/sglang/srt/layers/logits_processor.py
@@ -96,6 +96,10 @@ class LogitsMetadata:
     temperature: torch.Tensor = None
     top_p_normalized_logprobs: bool = False
     top_p: torch.Tensor = None
+    top_k_normalized_logprobs: bool = False
+    top_k: torch.Tensor = None
+    min_p_normalized_logprobs: bool = False
+    min_p: torch.Tensor = None
 
     # DP attention metadata. Not needed when DP attention is not used.
     # Number of tokens in the request.
@@ -161,6 +165,14 @@ def from_forward_batch(cls, forward_batch: ForwardBatch):
             forward_batch_gathered_buffer=forward_batch.gathered_buffer,
             global_num_tokens_for_logprob_cpu=forward_batch.global_num_tokens_for_logprob_cpu,
             global_num_tokens_for_logprob_gpu=forward_batch.global_num_tokens_for_logprob_gpu,
+            temp_scaled_logprobs=forward_batch.temp_scaled_logprobs,
+            temperature=forward_batch.temperature,
+            top_p_normalized_logprobs=forward_batch.top_p_normalized_logprobs,
+            top_p=forward_batch.top_p,
+            top_k_normalized_logprobs=forward_batch.top_k_normalized_logprobs,
+            top_k=forward_batch.top_k,
+            min_p_normalized_logprobs=forward_batch.min_p_normalized_logprobs,
+            min_p=forward_batch.min_p,
         )
 
     def compute_dp_attention_metadata(self, hidden_states: torch.Tensor):
@@ -224,6 +236,7 @@ def forward(
         lm_head: VocabParallelEmbedding,
         logits_metadata: Union[LogitsMetadata, ForwardBatch],
     ) -> LogitsProcessorOutput:
+        print("in .forward of LogitsProcessor")
         if isinstance(logits_metadata, ForwardBatch):
             logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata)
 
@@ -336,6 +349,7 @@ def forward(
                 hidden_states=hidden_states_to_store,
             )
         else:
+            assert logits_metadata.forward_mode.is_extend(), "Extend mode is required for return_logprob"
             input_logprobs = logits[input_logprob_indices]
             del hidden_states, logits
 
@@ -354,12 +368,24 @@ def forward(
                     logits_metadata.top_p,
                     pruned_lens,
                 )
+
             input_logprobs = self.compute_temp_top_p_normalized_logprobs(
                 input_logprobs, logits_metadata
             )
+            # IDEA: do top_k using top_p of np.max(input_logprobs, axis=1)
+
+            # Going to have to leave the top-k normalization out for now (CUDA OOM)
+            # input_logprobs = self.compute_temp_top_p_top_k_normalized_logprobs(
+            #     input_logprobs, logits_metadata
+            # )
+
             # Get the logprob of top-k tokens
+            # Note: this "k" is how many logprobs we return, not the top_k used for sampling
             if logits_metadata.extend_return_top_logprob:
+                # Clamp to avoid -inf, which certainly happens if we use top-p or top-k
+                input_logprobs = input_logprobs.clamp(min=torch.finfo(input_logprobs.dtype).min)
+
                 (
                     input_top_logprobs_val,
                     input_top_logprobs_idx,
@@ -496,6 +522,53 @@ def get_token_ids_logprobs(
 
         return input_token_ids_logprobs_val, input_token_ids_logprobs_idx
 
+    @staticmethod
+    def compute_temp_top_p_top_k_normalized_logprobs(
+        last_logits: torch.Tensor, logits_metadata: LogitsMetadata
+    ) -> torch.Tensor:
+        # Note: we should also incorporate custom logit processors and/or grammar backend masks
+        if logits_metadata.temp_scaled_logprobs:
+            last_logits.div_(logits_metadata.temperature)
+
+        needs_top_p = logits_metadata.top_p_normalized_logprobs \
+            and (logits_metadata.top_p != 1.0).any()
+
+        needs_top_k = logits_metadata.top_k_normalized_logprobs \
+            and ((logits_metadata.top_k != -1) & (logits_metadata.top_k < last_logits.shape[-1])).any()
+
+        if not needs_top_p and not needs_top_k:
+            return torch.nn.functional.log_softmax(last_logits, dim=-1)
+
+        probs = torch.softmax(last_logits, dim=-1)
+        del last_logits
+
+        """
+        if needs_top_p or needs_top_k:
+            from sglang.srt.layers.sampler import top_p_top_k_normalize_probs_torch
+            probs = top_p_top_k_normalize_probs_torch(
+                probs,
+                logits_metadata.top_p,
+                logits_metadata.top_k,
+            )
+        """
+
+        if needs_top_p:
+            from sglang.srt.layers.sampler import top_p_normalize_probs_torch
+            probs = top_p_normalize_probs_torch(probs, logits_metadata.top_p)
+
+        """
+        if needs_top_k:
+            print(" Applying top k")
+            from sglang.srt.layers.sampler import top_k_normalize_probs_torch
+            probs = top_k_normalize_probs_torch(probs, logits_metadata.top_k)
+            print(f"After top k: {probs[0, :20]}")
+        """
+
+        return torch.log(probs)
+
     @staticmethod
     def compute_temp_top_p_normalized_logprobs(
         last_logits: torch.Tensor, logits_metadata: LogitsMetadata
@@ -506,9 +579,10 @@ def compute_temp_top_p_normalized_logprobs(
         Returns:
             torch.Tensor: logprobs from logits
         """
+
         # Scale logits if temperature scaling is enabled
         if logits_metadata.temp_scaled_logprobs:
-            last_logits = last_logits / logits_metadata.temperature
+            last_logits = last_logits.div_(logits_metadata.temperature)
 
         # Normalize logprobs if top_p normalization is enabled
         # NOTE: only normalize logprobs when top_p is set and not equal to 1.0
@@ -517,7 +591,6 @@ def compute_temp_top_p_normalized_logprobs(
             and (logits_metadata.top_p != 1.0).any()
         ):
             from sglang.srt.layers.sampler import top_p_normalize_probs_torch
-
             probs = torch.softmax(last_logits, dim=-1)
             del last_logits
             probs = top_p_normalize_probs_torch(probs, logits_metadata.top_p)
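# ---- Illustration (separate from the patch above) ----
# A minimal, self-contained sketch of what the temperature + top-k normalized
# logprob path above is meant to produce for a single request. The logits,
# temperature, and top_k values here are made up for illustration only.
import torch

logits = torch.tensor([2.0, 1.0, 0.5, -1.0])
temperature, top_k = 0.7, 2

probs = torch.softmax(logits / temperature, dim=-1)
topk_vals, topk_idx = probs.topk(top_k)
masked = torch.zeros_like(probs).scatter_(-1, topk_idx, topk_vals)
# Tokens outside the top-k get probability 0 (logprob -inf); the survivors are
# renormalized to sum to 1, i.e. the distribution the sampler actually draws from.
normalized_logprobs = torch.log(masked / masked.sum())
print(normalized_logprobs)  # two finite values, the other entries are -inf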
diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py
index 37f22ec21ca..e27b576971e 100644
--- a/python/sglang/srt/layers/sampler.py
+++ b/python/sglang/srt/layers/sampler.py
@@ -243,6 +243,36 @@ def top_p_normalize_probs_torch(
     return torch.zeros_like(probs_sort).scatter_(-1, probs_idx, probs_sort)
 
 
+def top_k_normalize_probs_torch(
+    probs: torch.Tensor,  # [T, V]
+    top_ks: torch.Tensor,  # [T]
+):
+    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
+    mask = torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1) >= top_ks.view(-1, 1)
+    probs_sort[mask] = 0.0
+    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
+    return torch.zeros_like(probs_sort).scatter_(-1, probs_idx, probs_sort)
+
+
+def top_p_top_k_normalize_probs_torch(
+    probs: torch.Tensor,  # [T, V]
+    top_ps: torch.Tensor,  # [T]
+    top_ks: torch.Tensor,  # [T]
+):
+    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
+
+    # top_ps
+    probs_sum = torch.cumsum(probs_sort, dim=-1)
+    probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
+
+    # top_ks
+    mask = torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1) >= top_ks.view(-1, 1)
+    probs_sort[mask] = 0.0
+
+    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
+    return torch.zeros_like(probs_sort).scatter_(-1, probs_idx, probs_sort)
+
+
 def get_top_logprobs(logprobs: torch.Tensor, top_logprobs_nums: List[int]):
     assert len(top_logprobs_nums) == logprobs.shape[0], (
         len(top_logprobs_nums),
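# ---- Sanity check (separate from the patch above) ----
# Hedged example of exercising the two new helpers in isolation. It assumes they
# are importable from sglang.srt.layers.sampler exactly as added in the hunk above;
# the probability tensors and per-request top_p / top_k values are made up.
import torch
from sglang.srt.layers.sampler import (
    top_k_normalize_probs_torch,
    top_p_top_k_normalize_probs_torch,
)

probs = torch.tensor([[0.5, 0.3, 0.1, 0.1],
                      [0.4, 0.4, 0.1, 0.1]])
top_ps = torch.tensor([0.9, 0.5])
top_ks = torch.tensor([2, 3])

# Per-row top-k mask, then renormalize: each row sums to 1 again.
out_k = top_k_normalize_probs_torch(probs, top_ks)
assert torch.allclose(out_k.sum(dim=-1), torch.ones(2))

# Combined top-p + top-k mask, then renormalize.
out_pk = top_p_top_k_normalize_probs_torch(probs, top_ps, top_ks)
assert torch.allclose(out_pk.sum(dim=-1), torch.ones(2))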
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index 27f87d8a284..635c715b490 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -267,6 +267,8 @@ def __init__(
         self.session_id = session_id
         self.input_embeds = input_embeds
 
+        print("In Req constructor - temperature: ", sampling_params.temperature)
+
         # Sampling info
         if isinstance(sampling_params.custom_params, dict):
             sampling_params = copy.copy(sampling_params)
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index d5ce3bc71ae..5b91e850004 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -631,6 +631,8 @@ def handle_generate_request(
             )
             custom_logit_processor = None
 
+        print("In handle_generate_request - temperature: ", recv_req.sampling_params.temperature)
+
         req = Req(
             recv_req.rid,
             recv_req.input_text,
diff --git a/python/sglang/srt/managers/scheduler_output_processor_mixin.py b/python/sglang/srt/managers/scheduler_output_processor_mixin.py
index 13158d93726..115cd34e41c 100644
--- a/python/sglang/srt/managers/scheduler_output_processor_mixin.py
+++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py
@@ -291,6 +291,7 @@ def add_input_logprob_return_values(
         Some of input logprob operation should only happen at the last prefill
         (e.g., computing input token logprobs).
         """
+        assert output.input_token_logprobs is not None
         if req.input_token_logprobs is None:
             req.input_token_logprobs = []
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index 3132060ed29..95ce59d977f 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -408,11 +408,15 @@ async def _tokenize_one_request(
                 f"messages or the completion to fit within the limit."
             )
 
+        print("In tokenizer_manager _tokenize_one_request (before normalize) - sampling_params: ", obj.sampling_params)
+
         # Parse sampling parameters
         sampling_params = SamplingParams(**obj.sampling_params)
         sampling_params.normalize(self.tokenizer)
         sampling_params.verify()
 
+        print("In tokenizer_manager _tokenize_one_request (after normalize) - temperature: ", sampling_params.temperature)
+
         # Build return object
         if isinstance(obj, GenerateReqInput):
             tokenized_obj = TokenizedGenerateReqInput(
diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py
index b732b033e39..8f46fb52c85 100644
--- a/python/sglang/srt/model_executor/forward_batch_info.py
+++ b/python/sglang/srt/model_executor/forward_batch_info.py
@@ -157,6 +157,10 @@ class ForwardBatch:
     temperature: torch.Tensor = None
     top_p_normalized_logprobs: bool = False
    top_p: torch.Tensor = None
+    top_k_normalized_logprobs: bool = False
+    top_k: torch.Tensor = None
+    min_p_normalized_logprobs: bool = False
+    min_p: torch.Tensor = None
 
     # Position information
     positions: torch.Tensor = None
@@ -261,6 +265,14 @@ def init_new(
             capture_hidden_mode=batch.capture_hidden_mode,
             input_embeds=batch.input_embeds,
             extend_input_logprob_token_ids_gpu=extend_input_logprob_token_ids_gpu,
+            temp_scaled_logprobs=batch.sampling_info.temperatures is not None,
+            temperature=batch.sampling_info.temperatures,
+            top_p_normalized_logprobs=batch.sampling_info.top_ps is not None,
+            top_p=batch.sampling_info.top_ps,
+            top_k_normalized_logprobs=batch.sampling_info.top_ks is not None,
+            top_k=batch.sampling_info.top_ks,
+            min_p_normalized_logprobs=batch.sampling_info.min_ps is not None,
+            min_p=batch.sampling_info.min_ps,
         )
 
         # For DP attention
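# ---- Client-side usage sketch (separate from the patch) ----
# The OpenAI adapter and protocol changes below let extra request fields pass
# through pydantic (ConfigDict(extra='allow')) and read them back from
# __pydantic_extra__. A hedged sketch of how a client might set those fields:
# only the field names logprob_start_len and return_hidden_states come from the
# diff; the base URL/port, model name, and the OpenAI SDK's extra_body mechanism
# are assumptions.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:30000/v1", api_key="EMPTY")
resp = client.chat.completions.create(
    model="default",
    messages=[{"role": "user", "content": "Hello"}],
    logprobs=True,
    top_logprobs=5,
    # Non-standard fields; they land in ChatCompletionRequest.__pydantic_extra__.
    extra_body={"logprob_start_len": 0, "return_hidden_states": True},
)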
diff --git a/python/sglang/srt/openai_api/adapter.py b/python/sglang/srt/openai_api/adapter.py
index 2ac4e3ed85a..fbcb2d8139c 100644
--- a/python/sglang/srt/openai_api/adapter.py
+++ b/python/sglang/srt/openai_api/adapter.py
@@ -868,6 +868,7 @@ def v1_chat_generate_request(
     top_logprobs_nums = []
     modalities_list = []
     lora_paths = []
+    return_hidden_states = []
 
     # NOTE: with openai API, the prompt's logprobs are always not computed
 
@@ -961,10 +962,11 @@ def v1_chat_generate_request(
             modalities = []
         input_ids.append(prompt_ids)
         return_logprobs.append(request.logprobs)
-        logprob_start_lens.append(-1)
         top_logprobs_nums.append(request.top_logprobs or 0)
         lora_paths.append(request.lora_path)
-
+        logprob_start_lens.append(request.__pydantic_extra__.get("logprob_start_len", -1))
+        return_hidden_states.append(request.__pydantic_extra__.get("return_hidden_states", False))
+
         sampling_params = {
             "temperature": request.temperature,
             "max_new_tokens": request.max_tokens,
@@ -1029,8 +1031,10 @@ def v1_chat_generate_request(
         rid=request_ids,
         modalities=modalities_list,
         lora_path=lora_paths,
+        return_hidden_states=return_hidden_states,
     )
 
+
     return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0]
 
@@ -1046,6 +1050,7 @@ def v1_chat_generate_response(
     for idx, ret_item in enumerate(ret):
         logprobs = False
+
         if isinstance(request, list) and request[idx].logprobs:
             logprobs = True
         elif (not isinstance(request, list)) and request.logprobs:
@@ -1226,6 +1231,7 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
     request_json = await raw_request.json()
+
     all_requests = [ChatCompletionRequest(**request_json)]
     adapted_request, request = v1_chat_generate_request(all_requests, tokenizer_manager)
diff --git a/python/sglang/srt/openai_api/protocol.py b/python/sglang/srt/openai_api/protocol.py
index 0b614832173..b761db1ebfd 100644
--- a/python/sglang/srt/openai_api/protocol.py
+++ b/python/sglang/srt/openai_api/protocol.py
@@ -16,7 +16,7 @@
 import time
 from typing import Dict, List, Optional, Union
 
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, ConfigDict
 from typing_extensions import Literal
 
@@ -301,6 +301,8 @@ class ToolChoice(BaseModel):
 class ChatCompletionRequest(BaseModel):
     # Ordered by official OpenAI API documentation
     # https://platform.openai.com/docs/api-reference/chat/create
+    model_config = ConfigDict(extra='allow')
+
     messages: List[ChatCompletionMessageParam]
     model: str
     frequency_penalty: float = 0.0
diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py
index 5942b827087..329b2c6d89e 100644
--- a/python/sglang/srt/sampling/sampling_batch_info.py
+++ b/python/sglang/srt/sampling/sampling_batch_info.py
@@ -60,6 +60,10 @@ class SamplingBatchInfo:
     def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
         reqs = batch.reqs
         device = batch.device
+        print("IN from_schedule_batch temperature: ", [r.sampling_params.temperature for r in reqs])
+        print("IN from_schedule_batch top_p: ", [r.sampling_params.top_p for r in reqs])
+        print("IN from_schedule_batch top_k: ", [r.sampling_params.top_k for r in reqs])
+        print("IN from_schedule_batch min_p: ", [r.sampling_params.min_p for r in reqs])
         temperatures = (
             torch.tensor(
                 [r.sampling_params.temperature for r in reqs],