add LogitsMetadata (#604)
hnyls2002 authored Jul 9, 2024
1 parent f4e885b commit f25b76c
Showing 7 changed files with 66 additions and 40 deletions.
6 changes: 3 additions & 3 deletions benchmark/line_retrieval/gen_data.py
@@ -48,9 +48,9 @@ def generate_lines(random_words, num_lines, redirect_ratio):
)
for i in redirect_indices:
target_idx = np.random.choice(min(i * 2 + 100, num_lines))
lines[i] = (
f"Line {indices[i]}: The REGISTER_CONTENT is the same as Line {indices[target_idx]}."
)
lines[
i
] = f"Line {indices[i]}: The REGISTER_CONTENT is the same as Line {indices[target_idx]}."
redirects[i] = target_idx

# Build links and find sources
69 changes: 50 additions & 19 deletions python/sglang/srt/layers/logits_processor.py
@@ -1,7 +1,7 @@
"""Logits processing."""

import dataclasses
from typing import List
from typing import List, Union

import torch
from torch import nn
@@ -31,21 +31,42 @@ class LogitProcessorOutput:
decode_top_logprobs: List


@dataclasses.dataclass
class LogitsMetadata:
forward_mode: ForwardMode
extend_seq_lens: torch.Tensor
extend_start_loc: torch.Tensor

# For logprobs
return_logprob: bool
top_logprobs_nums: List[int]

@classmethod
def from_input_metadata(cls, input_metadata: InputMetadata):
return cls(
forward_mode=input_metadata.forward_mode,
extend_seq_lens=input_metadata.extend_seq_lens,
extend_start_loc=input_metadata.extend_start_loc,
return_logprob=input_metadata.return_logprob,
top_logprobs_nums=input_metadata.top_logprobs_nums,
)


class LogitsProcessor(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.tp_size = get_tensor_model_parallel_world_size()

def _get_normalized_prompt_logprobs(
self, prefill_token_logprobs, input_metadata: InputMetadata
self, prefill_token_logprobs, logits_metadata: LogitsMetadata
):
logprobs_cumsum = torch.cumsum(
prefill_token_logprobs, dim=0, dtype=torch.float32
)

start = input_metadata.extend_start_loc.clone()
end = start + input_metadata.extend_seq_lens - 2
start = logits_metadata.extend_start_loc.clone()
end = start + logits_metadata.extend_seq_lens - 2
start.clamp_(min=0, max=prefill_token_logprobs.shape[0] - 1)
end.clamp_(min=0, max=prefill_token_logprobs.shape[0] - 1)
sum_logp = (
@@ -54,17 +75,17 @@ def _get_normalized_prompt_logprobs(
+ prefill_token_logprobs[start]
)
normalized_prompt_logprobs = sum_logp / (
(input_metadata.extend_seq_lens - 1).clamp(min=1)
(logits_metadata.extend_seq_lens - 1).clamp(min=1)
)

return normalized_prompt_logprobs

def _get_top_logprobs(self, all_logprobs, input_metadata: InputMetadata):
def _get_top_logprobs(self, all_logprobs, logits_metadata: LogitsMetadata):
# TODO: vectorize the code below
if input_metadata.forward_mode == ForwardMode.DECODE:
if logits_metadata.forward_mode == ForwardMode.DECODE:
decode_top_logprobs = []
for i in range(all_logprobs.shape[0]):
k = input_metadata.top_logprobs_nums[i]
k = logits_metadata.top_logprobs_nums[i]
t = all_logprobs[i].topk(k)
v_cpu = t.values.tolist()
p_cpu = t.indices.tolist()
@@ -73,13 +94,13 @@ def _get_top_logprobs(all_logprobs, input_metadata: InputMetadata):
else:
prefill_top_logprobs, decode_top_logprobs = [], []
pt = 0
extend_seq_lens_cpu = input_metadata.extend_seq_lens.tolist()
extend_seq_lens_cpu = logits_metadata.extend_seq_lens.tolist()
for i, extend_seq_len in enumerate(extend_seq_lens_cpu):
if extend_seq_len == 0:
prefill_top_logprobs.append([])
decode_top_logprobs.append([])
continue
k = input_metadata.top_logprobs_nums[i]
k = logits_metadata.top_logprobs_nums[i]
t = all_logprobs[pt : pt + extend_seq_len].topk(k)
vs_cpu = t.values.tolist()
ps_cpu = t.indices.tolist()
@@ -91,14 +112,24 @@ def _get_top_logprobs(all_logprobs, input_metadata: InputMetadata):

return prefill_top_logprobs, decode_top_logprobs

def forward(self, input_ids, hidden_states, weight, input_metadata: InputMetadata):
def forward(
self,
input_ids,
hidden_states,
weight,
logits_metadata: Union[LogitsMetadata, InputMetadata],
):
if isinstance(logits_metadata, InputMetadata):
logits_metadata = LogitsMetadata.from_input_metadata(logits_metadata)
assert isinstance(logits_metadata, LogitsMetadata)

# Get the last hidden states and last logits for the next token prediction
if input_metadata.forward_mode == ForwardMode.DECODE:
if logits_metadata.forward_mode == ForwardMode.DECODE:
last_index = None
last_hidden = hidden_states
else:
last_index = (
torch.cumsum(input_metadata.extend_seq_lens, dim=0, dtype=torch.long)
torch.cumsum(logits_metadata.extend_seq_lens, dim=0, dtype=torch.long)
- 1
)
last_hidden = hidden_states[last_index]
@@ -114,7 +145,7 @@ def forward(self, input_ids, hidden_states, weight, input_metadata: InputMetadat
last_logits *= self.config.final_logit_softcapping

# Return only last_logits if logprob is not requested
if not input_metadata.return_logprob:
if not logits_metadata.return_logprob:
return LogitProcessorOutput(
next_token_logits=last_logits,
next_token_logprobs=None,
@@ -125,7 +156,7 @@ def forward(self, input_ids, hidden_states, weight, input_metadata: InputMetadat
)
else:
# When logprob is requested, compute the logits for all tokens.
if input_metadata.forward_mode == ForwardMode.DECODE:
if logits_metadata.forward_mode == ForwardMode.DECODE:
all_logits = last_logits
else:
all_logits = torch.matmul(hidden_states, weight.T)
@@ -138,15 +169,15 @@ def forward(self, input_ids, hidden_states, weight, input_metadata: InputMetadat
all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)

# Get the logprob of top-k tokens
return_top_logprob = any(x > 0 for x in input_metadata.top_logprobs_nums)
return_top_logprob = any(x > 0 for x in logits_metadata.top_logprobs_nums)
if return_top_logprob:
prefill_top_logprobs, decode_top_logprobs = self._get_top_logprobs(
all_logprobs, input_metadata
all_logprobs, logits_metadata
)
else:
prefill_top_logprobs = decode_top_logprobs = None

if input_metadata.forward_mode == ForwardMode.DECODE:
if logits_metadata.forward_mode == ForwardMode.DECODE:
return LogitProcessorOutput(
next_token_logits=last_logits,
next_token_logprobs=all_logprobs,
@@ -166,7 +197,7 @@ def forward(self, input_ids, hidden_states, weight, input_metadata: InputMetadat
]

normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
prefill_token_logprobs, input_metadata
prefill_token_logprobs, logits_metadata
)

return LogitProcessorOutput(
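For context, here is a minimal usage sketch of the new `LogitsMetadata` path introduced in this file. It is not part of the diff; the variable names (`logits_processor`, `lm_head_weight`, `input_metadata`) are illustrative placeholders, not code from the repository.

```python
# Illustrative sketch only; names below are assumptions, not repository code.
from sglang.srt.layers.logits_processor import LogitsMetadata

# Old call style keeps working: forward() detects an InputMetadata instance
# and converts it internally via LogitsMetadata.from_input_metadata().
out = logits_processor(input_ids, hidden_states, lm_head_weight, input_metadata)

# New call style: build the lightweight metadata once and pass it directly,
# so logit computation only depends on the fields it actually uses.
logits_metadata = LogitsMetadata.from_input_metadata(input_metadata)
out = logits_processor(input_ids, hidden_states, lm_head_weight, logits_metadata)
```

The design choice here is to decouple `LogitsProcessor` from the full `InputMetadata` object: the dataclass carries only the forward mode, extend lengths/offsets, and logprob options that logits processing needs.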
3 changes: 1 addition & 2 deletions python/sglang/srt/layers/radix_attention.py
@@ -2,9 +2,8 @@

import numpy as np
import torch
from torch import nn

from flashinfer.cascade import merge_state
from torch import nn

from sglang.global_config import global_config
from sglang.srt.layers.extend_attention import extend_attention_fwd
16 changes: 8 additions & 8 deletions python/sglang/srt/managers/tokenizer_manager.py
@@ -334,15 +334,15 @@ def convert_logprob_style(
ret["meta_info"]["decode_token_logprobs"], return_text_in_logprobs
)
if top_logprobs_num > 0:
ret["meta_info"]["prefill_top_logprobs"] = (
self.detokenize_top_logprobs_tokens(
ret["meta_info"]["prefill_top_logprobs"], return_text_in_logprobs
)
ret["meta_info"][
"prefill_top_logprobs"
] = self.detokenize_top_logprobs_tokens(
ret["meta_info"]["prefill_top_logprobs"], return_text_in_logprobs
)
ret["meta_info"]["decode_top_logprobs"] = (
self.detokenize_top_logprobs_tokens(
ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs
)
ret["meta_info"][
"decode_top_logprobs"
] = self.detokenize_top_logprobs_tokens(
ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs
)
return ret

5 changes: 0 additions & 5 deletions python/sglang/srt/models/gemma2.py
@@ -81,7 +81,6 @@ def forward_cuda(


class GemmaRotaryEmbedding(RotaryEmbedding):

def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
# https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/gemma/modeling_gemma.py#L107
inv_freq = 1.0 / (
@@ -95,7 +94,6 @@ def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:


class Gemma2MLP(nn.Module):

def __init__(
self,
hidden_size: int,
@@ -127,7 +125,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:


class Gemma2Attention(nn.Module):

def __init__(
self,
layer_idx: int,
@@ -218,7 +215,6 @@ def forward(


class Gemma2DecoderLayer(nn.Module):

def __init__(
self,
layer_idx: int,
@@ -287,7 +283,6 @@ def forward(


class Gemma2Model(nn.Module):

def __init__(
self,
config: PretrainedConfig,
6 changes: 3 additions & 3 deletions python/sglang/srt/models/llama2.py
@@ -163,9 +163,9 @@ def __init__(
if rope_scaling is not None and getattr(
config, "original_max_position_embeddings", None
):
rope_scaling["original_max_position_embeddings"] = (
config.original_max_position_embeddings
)
rope_scaling[
"original_max_position_embeddings"
] = config.original_max_position_embeddings
rope_is_neox_style = getattr(config, "rope_is_neox_style", True)
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
self.self_attn = LlamaAttention(
1 change: 1 addition & 0 deletions python/sglang/srt/utils.py
@@ -459,6 +459,7 @@ def monkey_patch_vllm_p2p_access_check(gpu_id: int):
"""

import vllm.distributed.device_communicators.custom_all_reduce_utils as tgt

setattr(tgt, "gpu_p2p_access_check", lambda *arg, **kwargs: True)


