From 79b873ece3dd47f11ed998c08348065a93588ea1 Mon Sep 17 00:00:00 2001 From: Kyle Pena Date: Sat, 19 Apr 2025 23:50:05 +0000 Subject: [PATCH 1/5] personal fork for sglang - verification stuff --- .../srt/managers/scheduler_output_processor_mixin.py | 1 + python/sglang/srt/openai_api/adapter.py | 10 ++++++++-- python/sglang/srt/openai_api/protocol.py | 4 +++- 3 files changed, 12 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/managers/scheduler_output_processor_mixin.py b/python/sglang/srt/managers/scheduler_output_processor_mixin.py index 13158d93726..115cd34e41c 100644 --- a/python/sglang/srt/managers/scheduler_output_processor_mixin.py +++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py @@ -291,6 +291,7 @@ def add_input_logprob_return_values( Some of input logprob operation should only happen at the last prefill (e.g., computing input token logprobs). """ + assert output.input_token_logprobs is not None if req.input_token_logprobs is None: req.input_token_logprobs = [] diff --git a/python/sglang/srt/openai_api/adapter.py b/python/sglang/srt/openai_api/adapter.py index 2ac4e3ed85a..fbcb2d8139c 100644 --- a/python/sglang/srt/openai_api/adapter.py +++ b/python/sglang/srt/openai_api/adapter.py @@ -868,6 +868,7 @@ def v1_chat_generate_request( top_logprobs_nums = [] modalities_list = [] lora_paths = [] + return_hidden_states = [] # NOTE: with openai API, the prompt's logprobs are always not computed @@ -961,10 +962,11 @@ def v1_chat_generate_request( modalities = [] input_ids.append(prompt_ids) return_logprobs.append(request.logprobs) - logprob_start_lens.append(-1) top_logprobs_nums.append(request.top_logprobs or 0) lora_paths.append(request.lora_path) - + logprob_start_lens.append(request.__pydantic_extra__.get("logprob_start_len", -1)) + return_hidden_states.append(request.__pydantic_extra__.get("return_hidden_states", False)) + sampling_params = { "temperature": request.temperature, "max_new_tokens": request.max_tokens, @@ -1029,8 +1031,10 @@ def v1_chat_generate_request( rid=request_ids, modalities=modalities_list, lora_path=lora_paths, + return_hidden_states=return_hidden_states ) + return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0] @@ -1046,6 +1050,7 @@ def v1_chat_generate_response( for idx, ret_item in enumerate(ret): logprobs = False + if isinstance(request, list) and request[idx].logprobs: logprobs = True elif (not isinstance(request, list)) and request.logprobs: @@ -1226,6 +1231,7 @@ def v1_chat_generate_response( async def v1_chat_completions(tokenizer_manager, raw_request: Request): request_json = await raw_request.json() + all_requests = [ChatCompletionRequest(**request_json)] adapted_request, request = v1_chat_generate_request(all_requests, tokenizer_manager) diff --git a/python/sglang/srt/openai_api/protocol.py b/python/sglang/srt/openai_api/protocol.py index 0b614832173..b761db1ebfd 100644 --- a/python/sglang/srt/openai_api/protocol.py +++ b/python/sglang/srt/openai_api/protocol.py @@ -16,7 +16,7 @@ import time from typing import Dict, List, Optional, Union -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, ConfigDict from typing_extensions import Literal @@ -301,6 +301,8 @@ class ToolChoice(BaseModel): class ChatCompletionRequest(BaseModel): # Ordered by official OpenAI API documentation # https://platform.openai.com/docs/api-reference/chat/create + model_config = ConfigDict(extra='allow') + messages: List[ChatCompletionMessageParam] model: str frequency_penalty: float = 0.0 From 
ca2c925a2a6b020fbb6911484f1cf91476da95b4 Mon Sep 17 00:00:00 2001 From: Kyle Pena Date: Mon, 21 Apr 2025 17:31:22 +0000 Subject: [PATCH 2/5] checkpoint for implementing changes needed to correctly compute temp / top-k (and also top-p) for input logprobs --- python/sglang/srt/layers/logits_processor.py | 93 ++++++++++++++++++- python/sglang/srt/layers/sampler.py | 11 +++ python/sglang/srt/managers/schedule_batch.py | 2 + python/sglang/srt/managers/scheduler.py | 2 + .../sglang/srt/managers/tokenizer_manager.py | 4 + .../srt/model_executor/forward_batch_info.py | 12 +++ .../srt/sampling/sampling_batch_info.py | 4 + 7 files changed, 127 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index b398e052da7..8fc0837a356 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -96,6 +96,10 @@ class LogitsMetadata: temperature: torch.Tensor = None top_p_normalized_logprobs: bool = False top_p: torch.Tensor = None + top_k_normalized_logprobs: bool = False + top_k: torch.Tensor = None + min_p_normalized_logprobs: bool = False + min_p: torch.Tensor = None # DP attention metadata. Not needed when DP attention is not used. # Number of tokens in the request. @@ -140,6 +144,16 @@ def from_forward_batch(cls, forward_batch: ForwardBatch): extend_token_ids_logprob ) = extend_logprob_pruned_lens_cpu = False + print(f"In from_forward_batch: {forward_batch}") + print(f"temp_scaled_logprobs: {forward_batch.temp_scaled_logprobs}") + print(f"temperature: {forward_batch.temperature}") + print(f"top_p: {forward_batch.top_p}") + print(f"top_p_normalized_logprobs: {forward_batch.top_p_normalized_logprobs}") + print(f"top_k: {forward_batch.top_k}") + print(f"top_k_normalized_logprobs: {forward_batch.top_k_normalized_logprobs}") + print(f"min_p: {forward_batch.min_p}") + print(f"min_p_normalized_logprobs: {forward_batch.min_p_normalized_logprobs}") + return cls( forward_mode=forward_batch.forward_mode, capture_hidden_mode=forward_batch.capture_hidden_mode, @@ -161,6 +175,14 @@ def from_forward_batch(cls, forward_batch: ForwardBatch): forward_batch_gathered_buffer=forward_batch.gathered_buffer, global_num_tokens_for_logprob_cpu=forward_batch.global_num_tokens_for_logprob_cpu, global_num_tokens_for_logprob_gpu=forward_batch.global_num_tokens_for_logprob_gpu, + temp_scaled_logprobs=forward_batch.temp_scaled_logprobs, + temperature=forward_batch.temperature, + top_p_normalized_logprobs=forward_batch.top_p_normalized_logprobs, + top_p=forward_batch.top_p, + top_k_normalized_logprobs=forward_batch.top_k_normalized_logprobs, + top_k=forward_batch.top_k, + min_p_normalized_logprobs=forward_batch.min_p_normalized_logprobs, + min_p=forward_batch.min_p, ) def compute_dp_attention_metadata(self, hidden_states: torch.Tensor): @@ -224,6 +246,7 @@ def forward( lm_head: VocabParallelEmbedding, logits_metadata: Union[LogitsMetadata, ForwardBatch], ) -> LogitsProcessorOutput: + print("in .forwrd of LogitsProcessor") if isinstance(logits_metadata, ForwardBatch): logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata) @@ -336,6 +359,7 @@ def forward( hidden_states=hidden_states_to_store, ) else: + assert logits_metadata.forward_mode.is_extend(), "Extend mode is required for return_logprob" input_logprobs = logits[input_logprob_indices] del hidden_states, logits @@ -354,12 +378,23 @@ def forward( logits_metadata.top_p, pruned_lens, ) - input_logprobs = 
self.compute_temp_top_p_normalized_logprobs( + + #input_logprobs = self.compute_temp_top_p_normalized_logprobs( + # input_logprobs, logits_metadata + #) + + input_logprobs = self.compute_temp_top_p_top_k_normalized_logprobs( input_logprobs, logits_metadata ) # Get the logprob of top-k tokens + # Note: this is how many "k" we want to return, not the top_k for sampling purposes if logits_metadata.extend_return_top_logprob: + # Clamp to avoid -inf, which certainly happens if we use top-p or top-k + print("Before clamp", input_logprobs[0,:20]) + input_logprobs = input_logprobs.clamp(min=torch.finfo(probs.dtype).min) + print("After clamp", input_logprobs[0,:20]) + ( input_top_logprobs_val, input_top_logprobs_idx, @@ -496,6 +531,53 @@ def get_token_ids_logprobs( return input_token_ids_logprobs_val, input_token_ids_logprobs_idx + + @staticmethod + def compute_temp_top_p_top_k_normalized_logprobs( + last_logits: torch.Tensor, logits_metadata: LogitsMetadata + ) -> torch.Tensor: + + # Note: We should also incorporate custom logit processors and/or grammar backend masks as well + + if logits_metadata.temp_scaled_logprobs: + print("Before temp scaling", last_logits[0,:20]) + print(f"Scaling logits by temperature: {logits_metadata.temperature}") + last_logits = last_logits / logits_metadata.temperature + print("After temp scaling", last_logits[0,:20]) + + needs_top_p = logits_metadata.top_p_normalized_logprobs \ + and (logits_metadata.top_p != 1.0).any() + + needs_top_k = logits_metadata.top_k_normalized_logprobs \ + and ((logits_metadata.top_k != -1) & (logits_metadata.top_k < last_logits.shape[-1])).any() + + + print("in compute_temp_top_p_top_k_normalized_logprobs") + print(f"needs_top_p: {needs_top_p}") + print(f"needs_top_k: {needs_top_k}") + print(f"last_logits.shape[-1]: {last_logits.shape[-1]}") + + if not needs_top_p and not needs_top_k: + return torch.nn.functional.log_softmax(last_logits, dim=-1) + + probs = torch.softmax(last_logits, dim=-1) + del last_logits + + if needs_top_p: + print(" Applying top p") + from sglang.srt.layers.sampler import top_p_normalize_probs_torch + probs = top_p_normalize_probs_torch(probs, logits_metadata.top_p) + print(f"After top p: {probs[0,:20]}") + + if needs_top_k: + print(" Applying top k") + from sglang.srt.layers.sampler import top_k_normalize_probs_torch + probs = top_k_normalize_probs_torch(probs, logits_metadata.top_k) + print(f"After top k: {probs[0,:20]}") + + + return torch.log(probs) + @staticmethod def compute_temp_top_p_normalized_logprobs( last_logits: torch.Tensor, logits_metadata: LogitsMetadata @@ -506,8 +588,16 @@ def compute_temp_top_p_normalized_logprobs( Returns: torch.Tensor: logprobs from logits """ + + print(f"In compute_temp_top_p_normalized_logprobs: {logits_metadata.temp_scaled_logprobs}") + print(f"temp_scaled_logprobs: {logits_metadata.temp_scaled_logprobs}") + print(f"temperature: {logits_metadata.temperature}") + print(f"top_p_normalized_logprobs: {logits_metadata.top_p_normalized_logprobs}") + print(f"top_p: {logits_metadata.top_p}") + # Scale logits if temperature scaling is enabled if logits_metadata.temp_scaled_logprobs: + print(f"Scaling logits by temperature: {logits_metadata.temperature}") last_logits = last_logits / logits_metadata.temperature # Normalize logprobs if top_p normalization is enabled @@ -518,6 +608,7 @@ def compute_temp_top_p_normalized_logprobs( ): from sglang.srt.layers.sampler import top_p_normalize_probs_torch + print(f"Normalizing logprobs by top_p: {logits_metadata.top_p}") probs = 
torch.softmax(last_logits, dim=-1) del last_logits probs = top_p_normalize_probs_torch(probs, logits_metadata.top_p) diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py index 37f22ec21ca..93d7f803c07 100644 --- a/python/sglang/srt/layers/sampler.py +++ b/python/sglang/srt/layers/sampler.py @@ -243,6 +243,17 @@ def top_p_normalize_probs_torch( return torch.zeros_like(probs_sort).scatter_(-1, probs_idx, probs_sort) +def top_k_normalize_probs_torch( + probs: torch.Tensor, # [T,V] + top_ks: torch.Tensor # [T] +): + probs_sort, probs_idx = probs.sort(dim=-1, descending=True) + mask = torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1) >= top_ks.view(-1, 1) + probs_sort[mask] = 0.0 + probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) + return torch.zeros_like(probs_sort).scatter_(-1, probs_idx, probs_sort) + + def get_top_logprobs(logprobs: torch.Tensor, top_logprobs_nums: List[int]): assert len(top_logprobs_nums) == logprobs.shape[0], ( len(top_logprobs_nums), diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 27f87d8a284..635c715b490 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -267,6 +267,8 @@ def __init__( self.session_id = session_id self.input_embeds = input_embeds + print("In Req constructor - temperature: ", sampling_params.temperature) + # Sampling info if isinstance(sampling_params.custom_params, dict): sampling_params = copy.copy(sampling_params) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index d5ce3bc71ae..5b91e850004 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -631,6 +631,8 @@ def handle_generate_request( ) custom_logit_processor = None + print("In handle_generate_request - temperature: ", recv_req.sampling_params.temperature) + req = Req( recv_req.rid, recv_req.input_text, diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 3132060ed29..95ce59d977f 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -408,11 +408,15 @@ async def _tokenize_one_request( f"messages or the completion to fit within the limit." 
) + print("In tokenizer_manager _tokenize_one_request (before normalize) - temperature: ", obj.sampling_params) + # Parse sampling parameters sampling_params = SamplingParams(**obj.sampling_params) sampling_params.normalize(self.tokenizer) sampling_params.verify() + print("In tokenizer_manager _tokenize_one_request (after normalize) - temperature: ", sampling_params.temperature) + # Build return object if isinstance(obj, GenerateReqInput): tokenized_obj = TokenizedGenerateReqInput( diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index b732b033e39..8f46fb52c85 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -157,6 +157,10 @@ class ForwardBatch: temperature: torch.Tensor = None top_p_normalized_logprobs: bool = False top_p: torch.Tensor = None + top_k_normalized_logprobs: bool = False + top_k: torch.Tensor = None + min_p_normalized_logprobs: bool = False + min_p: torch.Tensor = None # Position information positions: torch.Tensor = None @@ -261,6 +265,14 @@ def init_new( capture_hidden_mode=batch.capture_hidden_mode, input_embeds=batch.input_embeds, extend_input_logprob_token_ids_gpu=extend_input_logprob_token_ids_gpu, + temp_scaled_logprobs=batch.sampling_info.temperatures is not None, + temperature=batch.sampling_info.temperatures, + top_p_normalized_logprobs=batch.sampling_info.top_ps is not None, + top_p=batch.sampling_info.top_ps, + top_k_normalized_logprobs=batch.sampling_info.top_ks is not None, + top_k=batch.sampling_info.top_ks, + min_p_normalized_logprobs=batch.sampling_info.min_ps is not None, + min_p=batch.sampling_info.min_ps, ) # For DP attention diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py index 5942b827087..329b2c6d89e 100644 --- a/python/sglang/srt/sampling/sampling_batch_info.py +++ b/python/sglang/srt/sampling/sampling_batch_info.py @@ -60,6 +60,10 @@ class SamplingBatchInfo: def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int): reqs = batch.reqs device = batch.device + print("IN from_schedule_batch temperature: ", [r.sampling_params.temperature for r in reqs]) + print("IN from_schedule_batch top_p: ", [r.sampling_params.top_p for r in reqs]) + print("IN from_schedule_batch top_k: ", [r.sampling_params.top_k for r in reqs]) + print("IN from_schedule_batch min_p: ", [r.sampling_params.min_p for r in reqs]) temperatures = ( torch.tensor( [r.sampling_params.temperature for r in reqs], From 71067c584b5d9d29f760f5ce1b38e2ba5f7b188e Mon Sep 17 00:00:00 2001 From: Kyle Pena Date: Tue, 22 Apr 2025 15:55:24 +0000 Subject: [PATCH 3/5] now only doing top_p normalization for input logprobs --- python/sglang/srt/layers/logits_processor.py | 35 ++++++++++++++------ python/sglang/srt/layers/sampler.py | 19 +++++++++++ 2 files changed, 44 insertions(+), 10 deletions(-) diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index 8fc0837a356..43f6158b9cb 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -379,20 +379,23 @@ def forward( pruned_lens, ) - #input_logprobs = self.compute_temp_top_p_normalized_logprobs( - # input_logprobs, logits_metadata - #) - - input_logprobs = self.compute_temp_top_p_top_k_normalized_logprobs( + input_logprobs = self.compute_temp_top_p_normalized_logprobs( input_logprobs, logits_metadata ) + # IDEA: do top_k 
using top_p of np.max(input_logprobs, axis = 1) + + # Going to have to leave the top-k normalization out for now (CUDA OOM) + #input_logprobs = self.compute_temp_top_p_top_k_normalized_logprobs( + # input_logprobs, logits_metadata + #) + # Get the logprob of top-k tokens # Note: this is how many "k" we want to return, not the top_k for sampling purposes if logits_metadata.extend_return_top_logprob: # Clamp to avoid -inf, which certainly happens if we use top-p or top-k print("Before clamp", input_logprobs[0,:20]) - input_logprobs = input_logprobs.clamp(min=torch.finfo(probs.dtype).min) + input_logprobs = input_logprobs.clamp(min=torch.finfo(input_logprobs.dtype).min) print("After clamp", input_logprobs[0,:20]) ( @@ -542,7 +545,8 @@ def compute_temp_top_p_top_k_normalized_logprobs( if logits_metadata.temp_scaled_logprobs: print("Before temp scaling", last_logits[0,:20]) print(f"Scaling logits by temperature: {logits_metadata.temperature}") - last_logits = last_logits / logits_metadata.temperature + #last_logits = last_logits / logits_metadata.temperature + last_logits.div_(logits_metadata.temperature) print("After temp scaling", last_logits[0,:20]) needs_top_p = logits_metadata.top_p_normalized_logprobs \ @@ -563,18 +567,29 @@ def compute_temp_top_p_top_k_normalized_logprobs( probs = torch.softmax(last_logits, dim=-1) del last_logits + """ + if needs_top_p or needs_top_k: + from sglang.srt.layers.sampler import top_p_top_k_normalize_probs_torch + probs = top_p_top_k_normalize_probs_torch( + probs, + logits_metadata.top_p, + logits_metadata.top_k + ) + """ + + if needs_top_p: print(" Applying top p") from sglang.srt.layers.sampler import top_p_normalize_probs_torch probs = top_p_normalize_probs_torch(probs, logits_metadata.top_p) print(f"After top p: {probs[0,:20]}") - + """ if needs_top_k: print(" Applying top k") from sglang.srt.layers.sampler import top_k_normalize_probs_torch probs = top_k_normalize_probs_torch(probs, logits_metadata.top_k) print(f"After top k: {probs[0,:20]}") - + """ return torch.log(probs) @@ -608,7 +623,7 @@ def compute_temp_top_p_normalized_logprobs( ): from sglang.srt.layers.sampler import top_p_normalize_probs_torch - print(f"Normalizing logprobs by top_p: {logits_metadata.top_p}") + print(f"!!!! 
Normalizing logprobs by top_p: {logits_metadata.top_p}") probs = torch.softmax(last_logits, dim=-1) del last_logits probs = top_p_normalize_probs_torch(probs, logits_metadata.top_p) diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py index 93d7f803c07..e27b576971e 100644 --- a/python/sglang/srt/layers/sampler.py +++ b/python/sglang/srt/layers/sampler.py @@ -254,6 +254,25 @@ def top_k_normalize_probs_torch( return torch.zeros_like(probs_sort).scatter_(-1, probs_idx, probs_sort) +def top_p_top_k_normalize_probs_torch( + probs: torch.Tensor, # [T,V] + top_ps: torch.Tensor, # [T] + top_ks: torch.Tensor, # [T] + ): + probs_sort, probs_idx = probs.sort(dim=-1, descending=True) + + # top_ps + probs_sum = torch.cumsum(probs_sort, dim=-1) + probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0 + + # top_ks + mask = torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1) >= top_ks.view(-1, 1) + probs_sort[mask] = 0.0 + + probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) + return torch.zeros_like(probs_sort).scatter_(-1, probs_idx, probs_sort) + + def get_top_logprobs(logprobs: torch.Tensor, top_logprobs_nums: List[int]): assert len(top_logprobs_nums) == logprobs.shape[0], ( len(top_logprobs_nums), From 2e69ffdd95752f8a3f0b00699966bd55cd333536 Mon Sep 17 00:00:00 2001 From: Kyle Pena Date: Tue, 22 Apr 2025 18:30:10 +0000 Subject: [PATCH 4/5] replaced allocation with in-place operation --- python/sglang/srt/layers/logits_processor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index 43f6158b9cb..c3ca4620f50 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -613,7 +613,7 @@ def compute_temp_top_p_normalized_logprobs( # Scale logits if temperature scaling is enabled if logits_metadata.temp_scaled_logprobs: print(f"Scaling logits by temperature: {logits_metadata.temperature}") - last_logits = last_logits / logits_metadata.temperature + last_logits = last_logits.div_(logits_metadata.temperature) # Normalize logprobs if top_p normalization is enabled # NOTE: only normalize logprobs when top_p is set and not equal to 1.0 From 023e7151b24fc76aa022a9de34b44a7178342dc5 Mon Sep 17 00:00:00 2001 From: Kyle Pena Date: Wed, 23 Apr 2025 01:08:14 +0000 Subject: [PATCH 5/5] removed a bunch of print statements. 
checkpoint before implementing full logits dump --- python/sglang/srt/layers/logits_processor.py | 33 -------------------- 1 file changed, 33 deletions(-) diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index c3ca4620f50..f4cbb2de2d4 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -144,16 +144,6 @@ def from_forward_batch(cls, forward_batch: ForwardBatch): extend_token_ids_logprob ) = extend_logprob_pruned_lens_cpu = False - print(f"In from_forward_batch: {forward_batch}") - print(f"temp_scaled_logprobs: {forward_batch.temp_scaled_logprobs}") - print(f"temperature: {forward_batch.temperature}") - print(f"top_p: {forward_batch.top_p}") - print(f"top_p_normalized_logprobs: {forward_batch.top_p_normalized_logprobs}") - print(f"top_k: {forward_batch.top_k}") - print(f"top_k_normalized_logprobs: {forward_batch.top_k_normalized_logprobs}") - print(f"min_p: {forward_batch.min_p}") - print(f"min_p_normalized_logprobs: {forward_batch.min_p_normalized_logprobs}") - return cls( forward_mode=forward_batch.forward_mode, capture_hidden_mode=forward_batch.capture_hidden_mode, @@ -394,9 +384,7 @@ def forward( # Note: this is how many "k" we want to return, not the top_k for sampling purposes if logits_metadata.extend_return_top_logprob: # Clamp to avoid -inf, which certainly happens if we use top-p or top-k - print("Before clamp", input_logprobs[0,:20]) input_logprobs = input_logprobs.clamp(min=torch.finfo(input_logprobs.dtype).min) - print("After clamp", input_logprobs[0,:20]) ( input_top_logprobs_val, @@ -543,11 +531,7 @@ def compute_temp_top_p_top_k_normalized_logprobs( # Note: We should also incorporate custom logit processors and/or grammar backend masks as well if logits_metadata.temp_scaled_logprobs: - print("Before temp scaling", last_logits[0,:20]) - print(f"Scaling logits by temperature: {logits_metadata.temperature}") - #last_logits = last_logits / logits_metadata.temperature last_logits.div_(logits_metadata.temperature) - print("After temp scaling", last_logits[0,:20]) needs_top_p = logits_metadata.top_p_normalized_logprobs \ and (logits_metadata.top_p != 1.0).any() @@ -555,12 +539,6 @@ def compute_temp_top_p_top_k_normalized_logprobs( needs_top_k = logits_metadata.top_k_normalized_logprobs \ and ((logits_metadata.top_k != -1) & (logits_metadata.top_k < last_logits.shape[-1])).any() - - print("in compute_temp_top_p_top_k_normalized_logprobs") - print(f"needs_top_p: {needs_top_p}") - print(f"needs_top_k: {needs_top_k}") - print(f"last_logits.shape[-1]: {last_logits.shape[-1]}") - if not needs_top_p and not needs_top_k: return torch.nn.functional.log_softmax(last_logits, dim=-1) @@ -579,10 +557,8 @@ def compute_temp_top_p_top_k_normalized_logprobs( if needs_top_p: - print(" Applying top p") from sglang.srt.layers.sampler import top_p_normalize_probs_torch probs = top_p_normalize_probs_torch(probs, logits_metadata.top_p) - print(f"After top p: {probs[0,:20]}") """ if needs_top_k: print(" Applying top k") @@ -604,15 +580,8 @@ def compute_temp_top_p_normalized_logprobs( torch.Tensor: logprobs from logits """ - print(f"In compute_temp_top_p_normalized_logprobs: {logits_metadata.temp_scaled_logprobs}") - print(f"temp_scaled_logprobs: {logits_metadata.temp_scaled_logprobs}") - print(f"temperature: {logits_metadata.temperature}") - print(f"top_p_normalized_logprobs: {logits_metadata.top_p_normalized_logprobs}") - print(f"top_p: {logits_metadata.top_p}") - # Scale logits if temperature 
scaling is enabled if logits_metadata.temp_scaled_logprobs: - print(f"Scaling logits by temperature: {logits_metadata.temperature}") last_logits = last_logits.div_(logits_metadata.temperature) # Normalize logprobs if top_p normalization is enabled @@ -622,8 +591,6 @@ def compute_temp_top_p_normalized_logprobs( and (logits_metadata.top_p != 1.0).any() ): from sglang.srt.layers.sampler import top_p_normalize_probs_torch - - print(f"!!!! Normalizing logprobs by top_p: {logits_metadata.top_p}") probs = torch.softmax(last_logits, dim=-1) del last_logits probs = top_p_normalize_probs_torch(probs, logits_metadata.top_p)
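
For reference, the renormalization that patches 2 and 3 converge on can be exercised outside the server with a small standalone sketch. The function below is illustrative only: its name, argument shapes, and the per-row [T, V] layout are assumptions mirroring logits_processor.py and sampler.py, not code taken from the patches. It temperature-scales the logits, takes a softmax, masks by per-row top-p and top-k in rank space, renormalizes the surviving mass, and clamps the log-probabilities so masked tokens produce a finite minimum rather than -inf, following the clamp added before get_top_logprobs.

import torch

def normalized_input_logprobs(
    logits: torch.Tensor,       # [T, V] logits, one row per scored token position
    temperature: torch.Tensor,  # [T] per-row temperature (> 0)
    top_p: torch.Tensor,        # [T] per-row nucleus mass (1.0 disables top-p)
    top_k: torch.Tensor,        # [T] per-row top-k (-1 disables top-k)
) -> torch.Tensor:
    # Temperature scaling before the softmax (temp_scaled_logprobs).
    logits = logits / temperature.view(-1, 1)
    probs = torch.softmax(logits, dim=-1)

    # Sort once and apply both masks in rank space, analogous to
    # top_p_top_k_normalize_probs_torch in sampler.py.
    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)

    # Top-p: drop tokens whose preceding cumulative mass already exceeds
    # top_p, keeping the token that crosses the threshold.
    cumsum = torch.cumsum(probs_sort, dim=-1)
    probs_sort[(cumsum - probs_sort) > top_p.view(-1, 1)] = 0.0

    # Top-k: drop everything ranked at or beyond k; -1 means "keep all".
    vocab = probs.shape[-1]
    k = torch.where(top_k < 0, torch.full_like(top_k, vocab), top_k)
    ranks = torch.arange(vocab, device=probs.device).view(1, -1)
    probs_sort[ranks >= k.view(-1, 1)] = 0.0

    # Renormalize the surviving mass and scatter back to vocabulary order.
    probs_sort = probs_sort / probs_sort.sum(dim=-1, keepdim=True)
    probs = torch.zeros_like(probs_sort).scatter_(-1, probs_idx, probs_sort)

    # log(0) = -inf for masked-out tokens; clamp to the dtype minimum so the
    # top-logprob extraction downstream never sees -inf.
    logprobs = torch.log(probs)
    return logprobs.clamp(min=torch.finfo(logprobs.dtype).min)


if __name__ == "__main__":
    out = normalized_input_logprobs(
        torch.randn(2, 8),
        temperature=torch.tensor([0.7, 1.0]),
        top_p=torch.tensor([0.9, 1.0]),
        top_k=torch.tensor([4, -1]),
    )
    print(out.shape, out.isfinite().all().item())  # torch.Size([2, 8]) True

This sketch also shows why the clamp matters: with top-p or top-k active, some vocabulary entries receive exactly zero probability, so without the clamp the returned tensor would contain -inf entries.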