Commit eefea5e
feat: medusa v2 (#1734)
1 parent: 1b2670c

3 files changed: +147 -46 lines

server/text_generation_server/models/__init__.py (1 addition, 1 deletion)

@@ -145,7 +145,7 @@ def get_model(
         if speculate is not None:
             if speculate > speculate_medusa:
                 raise RuntimeError(
-                    "Speculate is set to `{speculate}` but this medusa models only has `{speculate_medusa}` heads, please make them match"
+                    f"Speculate is set to `{speculate}` but this medusa model only has `{speculate_medusa}` heads, please make them match"
                 )
         else:
             set_speculate(speculate)
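
The fix in this hunk is the missing `f` prefix: without it, Python keeps the braces as literal text instead of interpolating `speculate` and `speculate_medusa`. A quick illustration:

speculate, speculate_medusa = 8, 4
print("Speculate is set to `{speculate}`")   # prints the braces verbatim
print(f"Speculate is set to `{speculate}`")  # prints: Speculate is set to `8`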

server/text_generation_server/models/flash_causal_lm.py (9 additions, 17 deletions)

@@ -814,7 +814,7 @@ def warmup(self, batch: FlashCausalLMBatch):
                 for bs in CUDA_GRAPHS:
                     if self.speculate is None or self.speculate + 1 <= bs:
                         self.cuda_graph_warmup(bs, max_s, max_bt)
-            except Exception:
+            except torch.cuda.OutOfMemoryError:
                 logger.exception(f"Decode cuda graph warmup failed")
 
         return int(num_blocks * BLOCK_SIZE)
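
Narrowing the clause from `except Exception` to `torch.cuda.OutOfMemoryError` means only the recoverable out-of-memory case is swallowed during warmup; any real bug in graph capture now propagates. A minimal sketch of the pattern, with `record_decode_graph` and `batch_sizes` as hypothetical stand-ins for the warmup internals:

import torch
from loguru import logger

def warmup_cuda_graphs(batch_sizes, record_decode_graph):
    try:
        for bs in batch_sizes:
            record_decode_graph(bs)  # may OOM on large batch sizes
    except torch.cuda.OutOfMemoryError:
        # Recoverable: decoding falls back to eager mode for batch
        # sizes that never got a recorded graph.
        logger.exception("Decode cuda graph warmup failed")
    # Any other exception (shape mismatch, driver error, ...) now
    # surfaces instead of being logged and ignored.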
@@ -874,22 +874,14 @@ def forward(
             lm_head_indices = batch.prefill_head_indices
 
         bs = input_ids.shape[0]
-        padded_bs = bs
-        if bs == 3:
-            padded_bs = 4
-        elif 3 < bs <= 8:
-            padded_bs = 8
-        elif bs > 8:
-            padded_bs = (bs + 7) // 8 * 8
-
-        # Try to find an associated cuda graph
-        cuda_graph = self.cuda_graphs.get(padded_bs, None)
-
-        if (
-            cu_seqlen_prefill is not None
-            or cuda_graph is None
-            or batch.speculative_ids is not None
-        ):
+        sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
+        if sorted_padded_bs:
+            # Get associated cuda graph
+            cuda_graph = self.cuda_graphs[sorted_padded_bs[0]]
+        else:
+            cuda_graph = None
+
+        if cu_seqlen_prefill is not None or cuda_graph is None:
             return self.model.forward(
                 input_ids=input_ids,
                 position_ids=position_ids,
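
The forward path now derives the padding bucket from whatever graphs were actually recorded instead of the hard-coded 3/4/8 ladder: it picks the smallest recorded batch size that can hold the current batch. A self-contained sketch of the lookup, with dummy graph objects standing in for `self.cuda_graphs`:

def select_cuda_graph(cuda_graphs, bs):
    # Smallest recorded graph whose batch size can hold `bs`,
    # or None to fall back to a regular eager forward pass.
    sorted_padded_bs = sorted(k for k in cuda_graphs if k >= bs)
    return cuda_graphs[sorted_padded_bs[0]] if sorted_padded_bs else None

graphs = {k: f"graph_bs{k}" for k in (1, 2, 4, 8, 16)}
assert select_cuda_graph(graphs, 3) == "graph_bs4"
assert select_cuda_graph(graphs, 9) == "graph_bs16"
assert select_cuda_graph(graphs, 32) is None  # larger than any recorded graph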

server/text_generation_server/utils/layers.py (137 additions, 28 deletions)

@@ -432,12 +432,12 @@ def forward(self, x):
 
 
 class MedusaModel(torch.nn.Module):
-    def __init__(self, config, weights):
+    def __init__(self, config, medusa_config, weights):
         super().__init__()
         self.heads = torch.nn.ModuleList(
             [
-                MedusaHead(config, prefix=f"{i}", weights=weights)
-                for i in range(config["medusa_num_heads"])
+                MedusaHead(config, medusa_config, prefix=f"{i}", weights=weights)
+                for i in range(medusa_config["medusa_num_heads"])
             ]
         )
 
@@ -447,12 +447,12 @@ def forward(self, x):
 
 
 class MedusaHead(torch.nn.Module):
-    def __init__(self, config, prefix, weights):
+    def __init__(self, config, medusa_config, prefix, weights):
         super().__init__()
         self.blocks = torch.nn.ModuleList(
             [
                 ResBlock(config, prefix=f"{prefix}.{i}", weights=weights)
-                for i in range(config["medusa_num_layers"])
+                for i in range(medusa_config["medusa_num_layers"])
             ]
         )
         n = len(self.blocks)
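
With the extra `medusa_config` argument, the model's own `config` keeps the transformer hyperparameters while the head layout comes from the `config.json` stored next to `medusa_lm_head.safetensors`. A hedged illustration of the two keys read above (the values are hypothetical):

import json

# Hypothetical contents of <medusa_dir>/config.json; only the keys
# used by MedusaModel and MedusaHead are shown.
medusa_config = json.loads('{"medusa_num_heads": 4, "medusa_num_layers": 1}')

n_heads = medusa_config["medusa_num_heads"]    # how many MedusaHeads to build
n_layers = medusa_config["medusa_num_layers"]  # ResBlocks per head (V2 asserts 1)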
@@ -467,46 +467,155 @@ def forward(self, x):
         return x
 
 
-class SpeculativeHead(nn.Module):
+class MedusaHeadV1(nn.Module):
     def __init__(self, lm_head, medusa):
         super().__init__()
         self.lm_head = lm_head
         self.medusa = medusa
 
     @staticmethod
     def load(config, prefix: str, weights):
+        from pathlib import Path
+        from safetensors import safe_open
+        import json
+
+        use_medusa = config.use_medusa
+
+        medusa_config = str(Path(use_medusa) / "config.json")
+        filename = str(Path(use_medusa) / "medusa_lm_head.safetensors")
+
+        with open(medusa_config, "r") as f:
+            medusa_config = json.load(f)
+        routing = weights.routing
+        with safe_open(filename, framework="pytorch") as f:
+            for k in f.keys():
+                if k in routing and routing[k] != filename:
+                    raise RuntimeError(
+                        f"Key {k} was found in multiple files: {filename} and {routing[k]}"
+                    )
+                routing[k] = filename
+
+        medusa = MedusaModel(config, medusa_config, weights)
         lm_head = TensorParallelHead.load(config, prefix, weights)
+        return MedusaHeadV1(lm_head, medusa)
+
+    def forward(
+        self, input: torch.Tensor
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        logits = self.lm_head(input)
+        speculative_logits = self.medusa(input)
+        return logits, speculative_logits
+
+
+class MedusaHeadV2(nn.Module):
+    def __init__(self, config, prefix, weights):
+        super().__init__()
+        from pathlib import Path
+        from safetensors import safe_open
+        import json
+
+        use_medusa = config.use_medusa
+
+        medusa_config = str(Path(use_medusa) / "config.json")
+        filename = str(Path(use_medusa) / "medusa_lm_head.safetensors")
+
+        with open(medusa_config, "r") as f:
+            medusa_config = json.load(f)
+        routing = weights.routing
+        with safe_open(filename, framework="pytorch") as f:
+            for k in f.keys():
+                if k in routing and routing[k] != filename:
+                    raise RuntimeError(
+                        f"Key {k} was found in multiple files: {filename} and {routing[k]}"
+                    )
+                routing[k] = filename
+
+        self.n_medusa_heads = medusa_config["medusa_num_heads"]
+
+        assert medusa_config["medusa_num_layers"] == 1
+        self.linear = TensorParallelColumnLinear.load_multi(
+            config,
+            prefixes=[f"{i}.0.linear" for i in range(self.n_medusa_heads)],
+            dim=0,
+            weights=weights,
+            bias=True,
+        )
+        self.process_group = weights.process_group
+        self.world_size = self.process_group.size()
+        self.rank = self.process_group.rank()
+
+        self.act = torch.nn.SiLU()
+
+        self.lm_head = TensorParallelHead.load(config, prefix, weights)
+
+    def forward(self, x):
+        size = x.shape[-1]
+        block_size = (size + self.world_size - 1) // self.world_size
+        start = self.rank * block_size
+        stop = (self.rank + 1) * block_size
+
+        x_block = x[:, start:stop]
+
+        # Compute all medusa heads at the same time, then reshape and move the n_medusa_heads dim to dim 1
+        medusa_res = self.act(self.linear(x)).reshape(
+            *x_block.shape[:-1], self.n_medusa_heads, x_block.shape[-1]
+        )
+
+        # Apply all residual medusa heads
+        output = x[:, start:stop].unsqueeze(-2) + medusa_res
+
+        # Gather medusa heads
+        world_output = [
+            torch.empty_like(output) for _ in range(self.process_group.size())
+        ]
+        torch.distributed.all_gather(world_output, output, group=self.process_group)
+        world_output = torch.cat(world_output, dim=-1)
+
+        # Stack x and medusa residual x
+        stacked_x = torch.cat([x.unsqueeze(-2), world_output], dim=-2)
+
+        # Compute lm head on x + medusa residual x
+        logits = self.lm_head(stacked_x)
+
+        # Finally, split logits from speculative logits
+        logits, speculative_logits = torch.split(
+            logits, [1, self.n_medusa_heads], dim=-2
+        )
+        # Squeeze added dimension
+        logits = logits.squeeze(-2)
+
+        return logits, speculative_logits
+
+
+class SpeculativeHead(nn.Module):
+    def __init__(self, lm_head, medusa):
+        super().__init__()
+        self.head = lm_head
+        self.medusa = medusa
+
+    @staticmethod
+    def load(config, prefix: str, weights):
         use_medusa = config.use_medusa
         if use_medusa:
-            from pathlib import Path
-            from safetensors import safe_open
-            import json
-
-            medusa_config = str(Path(use_medusa) / "config.json")
-            filename = str(Path(use_medusa) / "medusa_lm_head.safetensors")
-
-            with open(medusa_config, "r") as f:
-                config = json.load(f)
-            routing = weights.routing
-            with safe_open(filename, framework="pytorch") as f:
-                for k in f.keys():
-                    if k in routing:
-                        raise RuntimeError(
-                            f"Key {k} was found in multiple files: {filename} and {routing[k]}"
-                        )
-                    weights.routing[k] = filename
-
-            medusa = MedusaModel(config, weights)
+            lm_head = None
+            try:
+                medusa = MedusaHeadV1.load(config, prefix, weights)
+            except:
+                medusa = MedusaHeadV2(config, prefix, weights)
         else:
+            lm_head = TensorParallelHead.load(config, prefix, weights)
             medusa = None
         return SpeculativeHead(lm_head, medusa)
 
     def forward(
         self, input: torch.Tensor
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
-        logits = self.lm_head(input)
-        speculative_logits = self.medusa(input) if self.medusa is not None else None
-        return logits, speculative_logits
+        if self.medusa is not None:
+            return self.medusa(input)
+
+        assert self.head is not None
+        logits = self.head(input)
+        return logits, None
 
 
 class TensorParallelHead(SuperLayer):
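
MedusaHeadV2 is the core of this change: all heads share one fused linear, each head's output is added as a residual to the hidden states, and the base lm_head runs once over the stacked `[batch, 1 + n_heads, hidden]` tensor. A single-process sketch of that shape flow, with plain `torch.nn.Linear` standing in for the tensor-parallel layers (so the rank slicing and `all_gather` collapse away at world size 1); all sizes are made up:

import torch

B, H, V, n = 2, 16, 32, 4  # hypothetical batch, hidden, vocab, head count

fused = torch.nn.Linear(H, n * H)        # the n head MLPs, fused on dim 0
lm_head = torch.nn.Linear(H, V, bias=False)
act = torch.nn.SiLU()

x = torch.randn(B, H)                    # decoder hidden states

# One matmul for all heads, then split the head dimension back out.
medusa_res = act(fused(x)).reshape(B, n, H)

# Residual connection per head (at world size 1 the local block is all of x).
output = x.unsqueeze(-2) + medusa_res                     # [B, n, H]

# Stack plain and medusa hidden states; run the lm_head once over all of them.
stacked_x = torch.cat([x.unsqueeze(-2), output], dim=-2)  # [B, 1 + n, H]
logits = lm_head(stacked_x)                               # [B, 1 + n, V]

# Split next-token logits from the n speculative logits.
logits, speculative_logits = torch.split(logits, [1, n], dim=-2)
logits = logits.squeeze(-2)              # [B, V]
print(logits.shape, speculative_logits.shape)  # [2, 32] and [2, 4, 32]

Note also that `SpeculativeHead.load` now tries `MedusaHeadV1.load` first and falls back to `MedusaHeadV2` on failure, so either head layout can be loaded through the same entry point.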
