feat: medusa v2 (#1734)
OlivierDehaene authored Apr 12, 2024
1 parent 1b2670c commit eefea5e
Showing 3 changed files with 147 additions and 46 deletions.
2 changes: 1 addition & 1 deletion server/text_generation_server/models/__init__.py
@@ -145,7 +145,7 @@ def get_model(
         if speculate is not None:
             if speculate > speculate_medusa:
                 raise RuntimeError(
-                    "Speculate is set to `{speculate}` but this medusa models only has `{speculate_medusa}` heads, please make them match"
+                    f"Speculate is set to `{speculate}` but this medusa models only has `{speculate_medusa}` heads, please make them match"
                 )
             else:
                 set_speculate(speculate)
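The only change in this file adds the missing f prefix: without it, Python emits the braces literally instead of interpolating the values. A standalone illustration with hypothetical values:

# Hypothetical values, for illustration only.
speculate, speculate_medusa = 5, 3

print("Speculate is set to `{speculate}`")   # braces emitted literally
print(f"Speculate is set to `{speculate}`")  # -> Speculate is set to `5`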
26 changes: 9 additions & 17 deletions server/text_generation_server/models/flash_causal_lm.py
@@ -814,7 +814,7 @@ def warmup(self, batch: FlashCausalLMBatch):
                 for bs in CUDA_GRAPHS:
                     if self.speculate is None or self.speculate + 1 <= bs:
                         self.cuda_graph_warmup(bs, max_s, max_bt)
-            except Exception:
+            except torch.cuda.OutOfMemoryError:
                 logger.exception(f"Decode cuda graph warmup failed")

         return int(num_blocks * BLOCK_SIZE)
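This hunk narrows the warmup guard so only an allocation failure during graph capture is swallowed (CUDA graphs are an optimization; serving can proceed without them), while any other bug now propagates instead of being logged and ignored. A toy sketch of the narrowed handler, assuming loguru's logger as used elsewhere in this module:

import torch
from loguru import logger

try:
    # Hypothetical stand-in for cuda_graph_warmup exhausting memory.
    raise torch.cuda.OutOfMemoryError("simulated capture OOM")
except torch.cuda.OutOfMemoryError:
    logger.exception("Decode cuda graph warmup failed")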
@@ -874,22 +874,14 @@ def forward(
             lm_head_indices = batch.prefill_head_indices

         bs = input_ids.shape[0]
-        padded_bs = bs
-        if bs == 3:
-            padded_bs = 4
-        elif 3 < bs <= 8:
-            padded_bs = 8
-        elif bs > 8:
-            padded_bs = (bs + 7) // 8 * 8
-
-        # Try to find an associated cuda graph
-        cuda_graph = self.cuda_graphs.get(padded_bs, None)
-
-        if (
-            cu_seqlen_prefill is not None
-            or cuda_graph is None
-            or batch.speculative_ids is not None
-        ):
+        sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
+        if sorted_padded_bs:
+            # Get associated cuda graph
+            cuda_graph = self.cuda_graphs[sorted_padded_bs[0]]
+        else:
+            cuda_graph = None
+
+        if cu_seqlen_prefill is not None or cuda_graph is None:
             return self.model.forward(
                 input_ids=input_ids,
                 position_ids=position_ids,
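The hard-coded padding ladder (4, 8, then multiples of 8) is replaced by a lookup over the batch sizes actually captured at warmup: pick the smallest captured graph that can hold the live batch, otherwise fall back to eager forward. A minimal sketch of the new rule with hypothetical captured sizes:

# Hypothetical captured graphs keyed by padded batch size, mirroring
# self.cuda_graphs; string values stand in for captured graph state.
cuda_graphs = {1: "g1", 2: "g2", 4: "g4", 8: "g8", 16: "g16", 32: "g32"}

def select_graph(bs: int):
    # Smallest captured batch size that fits, else None -> eager path.
    candidates = sorted(k for k in cuda_graphs if k >= bs)
    return cuda_graphs[candidates[0]] if candidates else None

assert select_graph(3) == "g4"    # padded up to the next captured size
assert select_graph(9) == "g16"   # nothing captured between 9 and 15
assert select_graph(64) is None   # larger than any graph: model.forward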
165 changes: 137 additions & 28 deletions server/text_generation_server/utils/layers.py
@@ -432,12 +432,12 @@ def forward(self, x):


 class MedusaModel(torch.nn.Module):
-    def __init__(self, config, weights):
+    def __init__(self, config, medusa_config, weights):
         super().__init__()
         self.heads = torch.nn.ModuleList(
             [
-                MedusaHead(config, prefix=f"{i}", weights=weights)
-                for i in range(config["medusa_num_heads"])
+                MedusaHead(config, medusa_config, prefix=f"{i}", weights=weights)
+                for i in range(medusa_config["medusa_num_heads"])
             ]
         )

@@ -447,12 +447,12 @@ def forward(self, x):


 class MedusaHead(torch.nn.Module):
-    def __init__(self, config, prefix, weights):
+    def __init__(self, config, medusa_config, prefix, weights):
         super().__init__()
         self.blocks = torch.nn.ModuleList(
             [
                 ResBlock(config, prefix=f"{prefix}.{i}", weights=weights)
-                for i in range(config["medusa_num_layers"])
+                for i in range(medusa_config["medusa_num_layers"])
             ]
         )
         n = len(self.blocks)
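MedusaModel and MedusaHead now receive the Medusa checkpoint's own config as a separate medusa_config argument rather than reading head counts from the model config (which the old SpeculativeHead.load clobbered with config = json.load(f)). A sketch, assuming a sidecar config.json carrying just the two keys this diff reads:

import json

# Hypothetical sidecar config; real Medusa checkpoints may carry more fields.
medusa_config = json.loads('{"medusa_num_heads": 4, "medusa_num_layers": 1}')

num_heads = medusa_config["medusa_num_heads"]    # drives len(MedusaModel.heads)
num_layers = medusa_config["medusa_num_layers"]  # ResBlocks per MedusaHead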
@@ -467,46 +467,155 @@ def forward(self, x):
         return x


-class SpeculativeHead(nn.Module):
+class MedusaHeadV1(nn.Module):
     def __init__(self, lm_head, medusa):
         super().__init__()
         self.lm_head = lm_head
         self.medusa = medusa

+    @staticmethod
+    def load(config, prefix: str, weights):
+        from pathlib import Path
+        from safetensors import safe_open
+        import json
+
+        use_medusa = config.use_medusa
+
+        medusa_config = str(Path(use_medusa) / "config.json")
+        filename = str(Path(use_medusa) / "medusa_lm_head.safetensors")
+
+        with open(medusa_config, "r") as f:
+            medusa_config = json.load(f)
+        routing = weights.routing
+        with safe_open(filename, framework="pytorch") as f:
+            for k in f.keys():
+                if k in routing and routing[k] != filename:
+                    raise RuntimeError(
+                        f"Key {k} was found in multiple files: {filename} and {routing[k]}"
+                    )
+                routing[k] = filename
+
+        medusa = MedusaModel(config, medusa_config, weights)
+        lm_head = TensorParallelHead.load(config, prefix, weights)
+        return MedusaHeadV1(lm_head, medusa)
+
+    def forward(
+        self, input: torch.Tensor
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        logits = self.lm_head(input)
+        speculative_logits = self.medusa(input)
+        return logits, speculative_logits
+
+
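Note the relaxed routing guard in the new loaders: a key already routed to the same file no longer raises, so the medusa safetensors can be registered idempotently; only a genuine cross-file collision is an error. A standalone sketch of that check, assuming a plain dict routing table:

routing = {"0.0.linear.weight": "medusa_lm_head.safetensors"}
filename = "medusa_lm_head.safetensors"

for k in ("0.0.linear.weight", "1.0.linear.weight"):
    if k in routing and routing[k] != filename:
        raise RuntimeError(
            f"Key {k} was found in multiple files: {filename} and {routing[k]}"
        )
    routing[k] = filename  # re-registering from the same file is a no-op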
+class MedusaHeadV2(nn.Module):
+    def __init__(self, config, prefix, weights):
+        super().__init__()
+        from pathlib import Path
+        from safetensors import safe_open
+        import json
+
+        use_medusa = config.use_medusa
+
+        medusa_config = str(Path(use_medusa) / "config.json")
+        filename = str(Path(use_medusa) / "medusa_lm_head.safetensors")
+
+        with open(medusa_config, "r") as f:
+            medusa_config = json.load(f)
+        routing = weights.routing
+        with safe_open(filename, framework="pytorch") as f:
+            for k in f.keys():
+                if k in routing and routing[k] != filename:
+                    raise RuntimeError(
+                        f"Key {k} was found in multiple files: {filename} and {routing[k]}"
+                    )
+                routing[k] = filename
+
+        self.n_medusa_heads = medusa_config["medusa_num_heads"]
+
+        assert medusa_config["medusa_num_layers"] == 1
+        self.linear = TensorParallelColumnLinear.load_multi(
+            config,
+            prefixes=[f"{i}.0.linear" for i in range(self.n_medusa_heads)],
+            dim=0,
+            weights=weights,
+            bias=True,
+        )
+        self.process_group = weights.process_group
+        self.world_size = self.process_group.size()
+        self.rank = self.process_group.rank()
+
+        self.act = torch.nn.SiLU()
+
+        self.lm_head = TensorParallelHead.load(config, prefix, weights)
+
+    def forward(self, x):
+        size = x.shape[-1]
+        block_size = (size + self.world_size - 1) // self.world_size
+        start = self.rank * block_size
+        stop = (self.rank + 1) * block_size
+
+        x_block = x[:, start:stop]
+
+        # Compute all medusa heads at the same time, then reshape and move the n_medusa_heads dim to dim 1
+        medusa_res = self.act(self.linear(x)).reshape(
+            *x_block.shape[:-1], self.n_medusa_heads, x_block.shape[-1]
+        )
+
+        # Apply all residual medusa heads
+        output = x[:, start:stop].unsqueeze(-2) + medusa_res
+
+        # Gather medusa heads
+        world_output = [
+            torch.empty_like(output) for _ in range(self.process_group.size())
+        ]
+        torch.distributed.all_gather(world_output, output, group=self.process_group)
+        world_output = torch.cat(world_output, dim=-1)
+
+        # Stack x and medusa residual x
+        stacked_x = torch.cat([x.unsqueeze(-2), world_output], dim=-2)
+
+        # Compute lm head on x + medusa residual x
+        logits = self.lm_head(stacked_x)
+
+        # Finally, split logits from speculative logits
+        logits, speculative_logits = torch.split(
+            logits, [1, self.n_medusa_heads], dim=-2
+        )
+        # Squeeze added dimension
+        logits = logits.squeeze(-2)
+
+        return logits, speculative_logits
+
+
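MedusaHeadV2 collapses the per-head ResBlocks (valid only when medusa_num_layers == 1, hence the assert) into one column-parallel matmul, adds the residual on each rank's shard of x, then pays a single all_gather and a single lm_head call for all heads at once. A single-process sketch of the shape bookkeeping, with made-up sizes and the distributed pieces faked:

import torch

# Made-up sizes: hidden=6, world_size=2, n_heads=3, bs=4; the column-parallel
# linear is faked with a random tensor of the per-rank output width.
hidden, world_size, n_heads, bs = 6, 2, 3, 4
block_size = (hidden + world_size - 1) // world_size  # ceil division -> 3

x = torch.randn(bs, hidden)
for rank in range(world_size):
    start, stop = rank * block_size, (rank + 1) * block_size
    x_block = x[:, start:stop]                       # this rank's slice of x
    sharded = torch.randn(bs, n_heads * block_size)  # stand-in for act(linear(x))
    medusa_res = sharded.reshape(bs, n_heads, block_size)
    output = x_block.unsqueeze(-2) + medusa_res      # residual add, per head
    assert output.shape == (bs, n_heads, block_size)

# all_gather + cat(dim=-1) would recompose (bs, n_heads, hidden); stacking
# x.unsqueeze(-2) on top gives (bs, 1 + n_heads, hidden) for one lm_head call,
# which torch.split separates back into logits and speculative logits.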
+class SpeculativeHead(nn.Module):
+    def __init__(self, lm_head, medusa):
+        super().__init__()
+        self.head = lm_head
+        self.medusa = medusa
+
     @staticmethod
     def load(config, prefix: str, weights):
         use_medusa = config.use_medusa
         if use_medusa:
-            from pathlib import Path
-            from safetensors import safe_open
-            import json
-
-            medusa_config = str(Path(use_medusa) / "config.json")
-            filename = str(Path(use_medusa) / "medusa_lm_head.safetensors")
-
-            with open(medusa_config, "r") as f:
-                config = json.load(f)
-            routing = weights.routing
-            with safe_open(filename, framework="pytorch") as f:
-                for k in f.keys():
-                    if k in routing:
-                        raise RuntimeError(
-                            f"Key {k} was found in multiple files: {filename} and {routing[k]}"
-                        )
-                    weights.routing[k] = filename
-
-            medusa = MedusaModel(config, weights)
+            lm_head = None
+            try:
+                medusa = MedusaHeadV1.load(config, prefix, weights)
+            except:
+                medusa = MedusaHeadV2(config, prefix, weights)
         else:
-            medusa = None
-        lm_head = TensorParallelHead.load(config, prefix, weights)
+            lm_head = TensorParallelHead.load(config, prefix, weights)
+            medusa = None
         return SpeculativeHead(lm_head, medusa)

     def forward(
         self, input: torch.Tensor
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
-        logits = self.lm_head(input)
-        speculative_logits = self.medusa(input) if self.medusa is not None else None
-        return logits, speculative_logits
+        if self.medusa is not None:
+            return self.medusa(input)
+
+        assert self.head is not None
+        logits = self.head(input)
+        return logits, None


 class TensorParallelHead(SuperLayer):
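SpeculativeHead is now a thin dispatcher: head and medusa are mutually exclusive, V1 is tried first and the bare except falls back to V2 on any load failure, and the medusa variants return the (logits, speculative_logits) pair themselves. A self-contained sketch of that contract (names are illustrative, not from this diff):

from typing import Optional, Tuple
import torch
import torch.nn as nn

class TinySpeculativeHead(nn.Module):
    # Illustrative stand-in for SpeculativeHead's dispatch logic.
    def __init__(self, head: Optional[nn.Module], medusa: Optional[nn.Module]):
        super().__init__()
        self.head, self.medusa = head, medusa

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        if self.medusa is not None:
            return self.medusa(x)  # medusa heads return the pair themselves
        assert self.head is not None
        return self.head(x), None  # eager path: no speculative logits

# Usage: without medusa, speculative_logits is None.
head = TinySpeculativeHead(nn.Linear(8, 32), None)
logits, spec = head(torch.randn(2, 8))
assert spec is None and logits.shape == (2, 32)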
