DO NOT MERGE - WIP - Sglang fork verification v2 #3

Draft · wants to merge 5 commits into base: main
77 changes: 75 additions & 2 deletions python/sglang/srt/layers/logits_processor.py
@@ -96,6 +96,10 @@ class LogitsMetadata:
temperature: torch.Tensor = None
top_p_normalized_logprobs: bool = False
top_p: torch.Tensor = None
top_k_normalized_logprobs: bool = False
top_k: torch.Tensor = None
min_p_normalized_logprobs: bool = False
min_p: torch.Tensor = None

# DP attention metadata. Not needed when DP attention is not used.
# Number of tokens in the request.
@@ -161,6 +165,14 @@ def from_forward_batch(cls, forward_batch: ForwardBatch):
forward_batch_gathered_buffer=forward_batch.gathered_buffer,
global_num_tokens_for_logprob_cpu=forward_batch.global_num_tokens_for_logprob_cpu,
global_num_tokens_for_logprob_gpu=forward_batch.global_num_tokens_for_logprob_gpu,
temp_scaled_logprobs=forward_batch.temp_scaled_logprobs,
temperature=forward_batch.temperature,
top_p_normalized_logprobs=forward_batch.top_p_normalized_logprobs,
top_p=forward_batch.top_p,
top_k_normalized_logprobs=forward_batch.top_k_normalized_logprobs,
top_k=forward_batch.top_k,
min_p_normalized_logprobs=forward_batch.min_p_normalized_logprobs,
min_p=forward_batch.min_p,
)

def compute_dp_attention_metadata(self, hidden_states: torch.Tensor):
@@ -224,6 +236,7 @@ def forward(
lm_head: VocabParallelEmbedding,
logits_metadata: Union[LogitsMetadata, ForwardBatch],
) -> LogitsProcessorOutput:
print("in .forwrd of LogitsProcessor")
if isinstance(logits_metadata, ForwardBatch):
logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata)

@@ -336,6 +349,7 @@ def forward(
hidden_states=hidden_states_to_store,
)
else:
assert logits_metadata.forward_mode.is_extend(), "Extend mode is required for return_logprob"
input_logprobs = logits[input_logprob_indices]
del hidden_states, logits

@@ -354,12 +368,24 @@
logits_metadata.top_p,
pruned_lens,
)

input_logprobs = self.compute_temp_top_p_normalized_logprobs(
input_logprobs, logits_metadata
)

# IDEA: do top_k using top_p of np.max(input_logprobs, axis = 1)

# Going to have to leave the top-k normalization out for now (CUDA OOM)
#input_logprobs = self.compute_temp_top_p_top_k_normalized_logprobs(
# input_logprobs, logits_metadata
#)

# Get the logprob of top-k tokens
# Note: this is the number of top logprobs to return, not the top_k used for sampling
if logits_metadata.extend_return_top_logprob:
# Clamp to avoid -inf, which will occur whenever top-p or top-k zeroes out tokens
input_logprobs = input_logprobs.clamp(min=torch.finfo(input_logprobs.dtype).min)

(
input_top_logprobs_val,
input_top_logprobs_idx,
@@ -496,6 +522,53 @@ def get_token_ids_logprobs(

return input_token_ids_logprobs_val, input_token_ids_logprobs_idx


@staticmethod
def compute_temp_top_p_top_k_normalized_logprobs(
last_logits: torch.Tensor, logits_metadata: LogitsMetadata
) -> torch.Tensor:

# Note: custom logit processors and/or grammar backend masks should eventually be incorporated here as well

if logits_metadata.temp_scaled_logprobs:
last_logits.div_(logits_metadata.temperature)

needs_top_p = logits_metadata.top_p_normalized_logprobs \
and (logits_metadata.top_p != 1.0).any()

needs_top_k = logits_metadata.top_k_normalized_logprobs \
and ((logits_metadata.top_k != -1) & (logits_metadata.top_k < last_logits.shape[-1])).any()

if not needs_top_p and not needs_top_k:
return torch.nn.functional.log_softmax(last_logits, dim=-1)

probs = torch.softmax(last_logits, dim=-1)
del last_logits

"""
if needs_top_p or needs_top_k:
from sglang.srt.layers.sampler import top_p_top_k_normalize_probs_torch
probs = top_p_top_k_normalize_probs_torch(
probs,
logits_metadata.top_p,
logits_metadata.top_k
)
"""


if needs_top_p:
from sglang.srt.layers.sampler import top_p_normalize_probs_torch
probs = top_p_normalize_probs_torch(probs, logits_metadata.top_p)
"""
if needs_top_k:
print(" Applying top k")
from sglang.srt.layers.sampler import top_k_normalize_probs_torch
probs = top_k_normalize_probs_torch(probs, logits_metadata.top_k)
print(f"After top k: {probs[0,:20]}")
"""

return torch.log(probs)
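
A minimal sketch of what this temperature → softmax → top-p renormalize → log pipeline does to a toy batch, and why the clamp in the forward path above is needed once tokens get zeroed out. Illustrative only: the tensor values are made up and the renormalization is inlined rather than calling the sglang helper.

import torch

# Toy batch: 2 rows, vocab of 4 tokens.
logits = torch.tensor([[2.0, 1.0, 0.5, -1.0],
                       [0.1, 0.1, 0.1, 0.1]])
temperature = torch.tensor([[0.7], [1.0]])
top_p = torch.tensor([0.8, 1.0])

probs = torch.softmax(logits / temperature, dim=-1)

# Inline top-p renormalization (same idea as top_p_normalize_probs_torch).
sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
cumsum = sorted_probs.cumsum(dim=-1)
sorted_probs[(cumsum - sorted_probs) > top_p.view(-1, 1)] = 0.0
sorted_probs /= sorted_probs.sum(dim=-1, keepdim=True)
probs = torch.zeros_like(probs).scatter_(-1, sorted_idx, sorted_probs)

logprobs = torch.log(probs)  # zeroed tokens become -inf here
clamped = logprobs.clamp(min=torch.finfo(logprobs.dtype).min)
print(logprobs[0])  # approx. [-0.21, -1.64, -inf, -inf]
print(clamped[0])   # -inf replaced by the dtype minimum, so downstream math stays finite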

@staticmethod
def compute_temp_top_p_normalized_logprobs(
last_logits: torch.Tensor, logits_metadata: LogitsMetadata
@@ -506,9 +579,10 @@ def compute_temp_top_p_normalized_logprobs(
Returns:
torch.Tensor: logprobs from logits
"""

# Scale logits if temperature scaling is enabled
if logits_metadata.temp_scaled_logprobs:
last_logits = last_logits / logits_metadata.temperature
last_logits = last_logits.div_(logits_metadata.temperature)

# Normalize logprobs if top_p normalization is enabled
# NOTE: only normalize logprobs when top_p is set and not equal to 1.0
@@ -517,7 +591,6 @@
and (logits_metadata.top_p != 1.0).any()
):
from sglang.srt.layers.sampler import top_p_normalize_probs_torch

probs = torch.softmax(last_logits, dim=-1)
del last_logits
probs = top_p_normalize_probs_torch(probs, logits_metadata.top_p)
30 changes: 30 additions & 0 deletions python/sglang/srt/layers/sampler.py
@@ -243,6 +243,36 @@ def top_p_normalize_probs_torch(
return torch.zeros_like(probs_sort).scatter_(-1, probs_idx, probs_sort)


def top_k_normalize_probs_torch(
probs: torch.Tensor, # [T,V]
top_ks: torch.Tensor # [T]
):
probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
mask = torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1) >= top_ks.view(-1, 1)
probs_sort[mask] = 0.0
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
return torch.zeros_like(probs_sort).scatter_(-1, probs_idx, probs_sort)
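
As a quick illustration of the new helper (toy numbers, assuming the function as defined above):

import torch

probs = torch.tensor([[0.50, 0.30, 0.15, 0.05]])  # already sorted, for readability
top_ks = torch.tensor([2])

# Ranks >= 2 are zeroed, then the survivors are rescaled:
# [0.50, 0.30, 0.00, 0.00] / 0.80 -> [0.625, 0.375, 0.0, 0.0]
print(top_k_normalize_probs_torch(probs, top_ks))
# tensor([[0.6250, 0.3750, 0.0000, 0.0000]])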


def top_p_top_k_normalize_probs_torch(
probs: torch.Tensor, # [T,V]
top_ps: torch.Tensor, # [T]
top_ks: torch.Tensor, # [T]
):
probs_sort, probs_idx = probs.sort(dim=-1, descending=True)

# top_ps
probs_sum = torch.cumsum(probs_sort, dim=-1)
probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0

# top_ks
mask = torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1) >= top_ks.view(-1, 1)
probs_sort[mask] = 0.0

probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
return torch.zeros_like(probs_sort).scatter_(-1, probs_idx, probs_sort)
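
One detail worth noting about the combined helper: the mask condition (probs_sum - probs_sort) > top_ps compares the cumulative mass accumulated *before* each rank, so the highest-probability token always survives even for a very small top_p, matching the existing top_p_normalize_probs_torch. A small usage sketch with made-up numbers:

import torch

probs = torch.tensor([[0.40, 0.25, 0.20, 0.10, 0.05]])
out = top_p_top_k_normalize_probs_torch(
    probs,
    torch.tensor([0.70]),  # top_ps: cumulative-mass cutoff
    torch.tensor([4]),     # top_ks: rank cutoff
)
# top-p keeps ranks whose preceding cumulative mass is <= 0.70 -> first three tokens;
# top-k (k=4) keeps ranks 0-3, so the intersection is ranks 0-2, renormalized by 0.85.
print(out)  # approx. tensor([[0.4706, 0.2941, 0.2353, 0.0000, 0.0000]])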


def get_top_logprobs(logprobs: torch.Tensor, top_logprobs_nums: List[int]):
assert len(top_logprobs_nums) == logprobs.shape[0], (
len(top_logprobs_nums),
2 changes: 2 additions & 0 deletions python/sglang/srt/managers/schedule_batch.py
@@ -267,6 +267,8 @@ def __init__(
self.session_id = session_id
self.input_embeds = input_embeds

print("In Req constructor - temperature: ", sampling_params.temperature)

# Sampling info
if isinstance(sampling_params.custom_params, dict):
sampling_params = copy.copy(sampling_params)
2 changes: 2 additions & 0 deletions python/sglang/srt/managers/scheduler.py
@@ -631,6 +631,8 @@ def handle_generate_request(
)
custom_logit_processor = None

print("In handle_generate_request - temperature: ", recv_req.sampling_params.temperature)

req = Req(
recv_req.rid,
recv_req.input_text,
@@ -291,6 +291,7 @@ def add_input_logprob_return_values(
Some of input logprob operation should only happen at the last
prefill (e.g., computing input token logprobs).
"""

assert output.input_token_logprobs is not None
if req.input_token_logprobs is None:
req.input_token_logprobs = []
4 changes: 4 additions & 0 deletions python/sglang/srt/managers/tokenizer_manager.py
@@ -408,11 +408,15 @@ async def _tokenize_one_request(
f"messages or the completion to fit within the limit."
)

print("In tokenizer_manager _tokenize_one_request (before normalize) - temperature: ", obj.sampling_params)

# Parse sampling parameters
sampling_params = SamplingParams(**obj.sampling_params)
sampling_params.normalize(self.tokenizer)
sampling_params.verify()

print("In tokenizer_manager _tokenize_one_request (after normalize) - temperature: ", sampling_params.temperature)

# Build return object
if isinstance(obj, GenerateReqInput):
tokenized_obj = TokenizedGenerateReqInput(
12 changes: 12 additions & 0 deletions python/sglang/srt/model_executor/forward_batch_info.py
@@ -157,6 +157,10 @@ class ForwardBatch:
temperature: torch.Tensor = None
top_p_normalized_logprobs: bool = False
top_p: torch.Tensor = None
top_k_normalized_logprobs: bool = False
top_k: torch.Tensor = None
min_p_normalized_logprobs: bool = False
min_p: torch.Tensor = None

# Position information
positions: torch.Tensor = None
@@ -261,6 +265,14 @@ def init_new(
capture_hidden_mode=batch.capture_hidden_mode,
input_embeds=batch.input_embeds,
extend_input_logprob_token_ids_gpu=extend_input_logprob_token_ids_gpu,
temp_scaled_logprobs=batch.sampling_info.temperatures is not None,
temperature=batch.sampling_info.temperatures,
top_p_normalized_logprobs=batch.sampling_info.top_ps is not None,
top_p=batch.sampling_info.top_ps,
top_k_normalized_logprobs=batch.sampling_info.top_ks is not None,
top_k=batch.sampling_info.top_ks,
min_p_normalized_logprobs=batch.sampling_info.min_ps is not None,
min_p=batch.sampling_info.min_ps,
)

# For DP attention
10 changes: 8 additions & 2 deletions python/sglang/srt/openai_api/adapter.py
@@ -868,6 +868,7 @@ def v1_chat_generate_request(
top_logprobs_nums = []
modalities_list = []
lora_paths = []
return_hidden_states = []

# NOTE: with openai API, the prompt's logprobs are always not computed

@@ -961,10 +962,11 @@
modalities = []
input_ids.append(prompt_ids)
return_logprobs.append(request.logprobs)
logprob_start_lens.append(-1)
top_logprobs_nums.append(request.top_logprobs or 0)
lora_paths.append(request.lora_path)

logprob_start_lens.append(request.__pydantic_extra__.get("logprob_start_len", -1))
return_hidden_states.append(request.__pydantic_extra__.get("return_hidden_states", False))

sampling_params = {
"temperature": request.temperature,
"max_new_tokens": request.max_tokens,
@@ -1029,8 +1031,10 @@
rid=request_ids,
modalities=modalities_list,
lora_path=lora_paths,
return_hidden_states=return_hidden_states
)


return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0]


@@ -1046,6 +1050,7 @@ def v1_chat_generate_response(

for idx, ret_item in enumerate(ret):
logprobs = False

if isinstance(request, list) and request[idx].logprobs:
logprobs = True
elif (not isinstance(request, list)) and request.logprobs:
@@ -1226,6 +1231,7 @@

async def v1_chat_completions(tokenizer_manager, raw_request: Request):
request_json = await raw_request.json()

all_requests = [ChatCompletionRequest(**request_json)]
adapted_request, request = v1_chat_generate_request(all_requests, tokenizer_manager)

4 changes: 3 additions & 1 deletion python/sglang/srt/openai_api/protocol.py
@@ -16,7 +16,7 @@
import time
from typing import Dict, List, Optional, Union

from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, ConfigDict
from typing_extensions import Literal


@@ -301,6 +301,8 @@ class ToolChoice(BaseModel):
class ChatCompletionRequest(BaseModel):
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/chat/create
model_config = ConfigDict(extra='allow')

messages: List[ChatCompletionMessageParam]
model: str
frequency_penalty: float = 0.0
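
Since ChatCompletionRequest now allows extra fields, the adapter hunk above can read them back through pydantic v2's __pydantic_extra__. A self-contained sketch of that mechanism (the model and its fields here are illustrative, not sglang's own):

from pydantic import BaseModel, ConfigDict

class DemoRequest(BaseModel):
    model_config = ConfigDict(extra="allow")
    model: str
    temperature: float = 1.0

req = DemoRequest(model="demo", temperature=0.5, logprob_start_len=3)

# Unknown fields are kept instead of rejected and land in __pydantic_extra__.
print(req.__pydantic_extra__)                                      # {'logprob_start_len': 3}
print(req.__pydantic_extra__.get("return_hidden_states", False))   # False (not supplied)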
4 changes: 4 additions & 0 deletions python/sglang/srt/sampling/sampling_batch_info.py
@@ -60,6 +60,10 @@ class SamplingBatchInfo:
def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
reqs = batch.reqs
device = batch.device
print("IN from_schedule_batch temperature: ", [r.sampling_params.temperature for r in reqs])
print("IN from_schedule_batch top_p: ", [r.sampling_params.top_p for r in reqs])
print("IN from_schedule_batch top_k: ", [r.sampling_params.top_k for r in reqs])
print("IN from_schedule_batch min_p: ", [r.sampling_params.min_p for r in reqs])
temperatures = (
torch.tensor(
[r.sampling_params.temperature for r in reqs],