@@ -432,12 +432,12 @@ def forward(self, x):
432
432
433
433
434
434
class MedusaModel (torch .nn .Module ):
435
- def __init__ (self , config , weights ):
435
+ def __init__ (self , config , medusa_config , weights ):
436
436
super ().__init__ ()
437
437
self .heads = torch .nn .ModuleList (
438
438
[
439
- MedusaHead (config , prefix = f"{ i } " , weights = weights )
440
- for i in range (config ["medusa_num_heads" ])
439
+ MedusaHead (config , medusa_config , prefix = f"{ i } " , weights = weights )
440
+ for i in range (medusa_config ["medusa_num_heads" ])
441
441
]
442
442
)
443
443
@@ -447,12 +447,12 @@ def forward(self, x):
447
447
448
448
449
449
class MedusaHead (torch .nn .Module ):
450
- def __init__ (self , config , prefix , weights ):
450
+ def __init__ (self , config , medusa_config , prefix , weights ):
451
451
super ().__init__ ()
452
452
self .blocks = torch .nn .ModuleList (
453
453
[
454
454
ResBlock (config , prefix = f"{ prefix } .{ i } " , weights = weights )
455
- for i in range (config ["medusa_num_layers" ])
455
+ for i in range (medusa_config ["medusa_num_layers" ])
456
456
]
457
457
)
458
458
n = len (self .blocks )
@@ -467,46 +467,155 @@ def forward(self, x):
467
467
return x
468
468
469
469
470
- class SpeculativeHead (nn .Module ):
470
class MedusaHeadV1(nn.Module):
    """Pair a language-model head with a separately evaluated Medusa model.

    ``forward`` feeds the same input to both sub-modules and returns their
    outputs as a ``(logits, speculative_logits)`` tuple.
    """

    def __init__(self, lm_head, medusa):
        """Store the two sub-modules.

        Args:
            lm_head: module called on the input to produce ``logits``.
            medusa: module called on the input to produce ``speculative_logits``.
        """
        super().__init__()
        self.lm_head = lm_head
        self.medusa = medusa

    @staticmethod
    def load(config, prefix: str, weights):
        """Build a ``MedusaHeadV1`` from the directory named by ``config.use_medusa``.

        Reads ``config.json`` from that directory, records every tensor key of
        ``medusa_lm_head.safetensors`` in ``weights.routing``, then constructs the
        ``MedusaModel`` and the ``TensorParallelHead``.

        Args:
            config: model config; must expose ``use_medusa`` (a directory path).
            prefix: weight-name prefix forwarded to ``TensorParallelHead.load``.
            weights: weights accessor; its ``routing`` mapping is updated in place.

        Returns:
            A new ``MedusaHeadV1``.

        Raises:
            RuntimeError: if a key of the Medusa file is already routed to a
                different file.
        """
        # Kept function-local, as in the original, so importing this module does
        # not itself require these names.
        import json
        from pathlib import Path

        from safetensors import safe_open

        medusa_dir = Path(config.use_medusa)
        config_path = str(medusa_dir / "config.json")
        filename = str(medusa_dir / "medusa_lm_head.safetensors")

        # JSON is UTF-8; do not depend on the platform's default encoding.
        with open(config_path, "r", encoding="utf-8") as f:
            medusa_config = json.load(f)

        routing = weights.routing
        with safe_open(filename, framework="pytorch") as f:
            for k in f.keys():
                # Re-registering a key for the same file is tolerated; only a
                # conflict with a different file is an error.
                if k in routing and routing[k] != filename:
                    raise RuntimeError(
                        f"Key {k} was found in multiple files: {filename} and {routing[k]}"
                    )
                routing[k] = filename

        medusa = MedusaModel(config, medusa_config, weights)
        lm_head = TensorParallelHead.load(config, prefix, weights)
        return MedusaHeadV1(lm_head, medusa)

    def forward(
        self, input: torch.Tensor
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Return ``(lm_head(input), medusa(input))``."""
        logits = self.lm_head(input)
        speculative_logits = self.medusa(input)
        return logits, speculative_logits
508
+
509
+
510
class MedusaHeadV2(nn.Module):
    """Medusa head that evaluates all heads with one fused linear layer.

    The per-head linear layers are loaded as a single tensor-parallel linear,
    their residual outputs are gathered across the process group, and the
    language-model head is applied once to the input stacked with the
    per-head residual outputs.
    """

    def __init__(self, config, prefix, weights):
        """Load the Medusa config, register its weights, and build the layers.

        Args:
            config: model config; must expose ``use_medusa`` (a directory path).
            prefix: weight-name prefix forwarded to ``TensorParallelHead.load``.
            weights: weights accessor; its ``routing`` mapping is updated in
                place and its ``process_group`` is kept for the gather.

        Raises:
            RuntimeError: if a key of the Medusa file is already routed to a
                different file.
            AssertionError: if the Medusa config has ``medusa_num_layers != 1``.
        """
        super().__init__()
        # Kept function-local, as in the original.
        import json
        from pathlib import Path

        from safetensors import safe_open

        medusa_dir = Path(config.use_medusa)
        config_path = str(medusa_dir / "config.json")
        filename = str(medusa_dir / "medusa_lm_head.safetensors")

        # JSON is UTF-8; do not depend on the platform's default encoding.
        with open(config_path, "r", encoding="utf-8") as f:
            medusa_config = json.load(f)

        routing = weights.routing
        with safe_open(filename, framework="pytorch") as f:
            for k in f.keys():
                # Only a conflict with a *different* file is an error.
                if k in routing and routing[k] != filename:
                    raise RuntimeError(
                        f"Key {k} was found in multiple files: {filename} and {routing[k]}"
                    )
                routing[k] = filename

        self.n_medusa_heads = medusa_config["medusa_num_heads"]

        # Only the "<head>.0.linear" weights are loaded below, so deeper heads
        # are not handled by this class.
        assert (
            medusa_config["medusa_num_layers"] == 1
        ), "MedusaHeadV2 only supports medusa_num_layers == 1"
        self.linear = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{i}.0.linear" for i in range(self.n_medusa_heads)],
            dim=0,
            weights=weights,
            bias=True,
        )
        self.process_group = weights.process_group
        self.world_size = self.process_group.size()
        self.rank = self.process_group.rank()

        self.act = torch.nn.SiLU()

        self.lm_head = TensorParallelHead.load(config, prefix, weights)

    def forward(self, x):
        """Return ``(logits, speculative_logits)`` for ``x``.

        NOTE(review): the slice below is taken on dim 1 while the block size is
        derived from the last dim, so this is only consistent for 2-D ``x``
        (presumably ``(tokens, hidden)``) - confirm against callers.
        """
        # This rank's slice of the last dimension (ceil-divided over ranks).
        size = x.shape[-1]
        block_size = (size + self.world_size - 1) // self.world_size
        start = self.rank * block_size
        stop = (self.rank + 1) * block_size
        x_block = x[:, start:stop]

        # Compute all medusa heads at once, then reshape so the head dimension
        # sits just before the (local) feature dimension.
        medusa_res = self.act(self.linear(x)).reshape(
            *x_block.shape[:-1], self.n_medusa_heads, x_block.shape[-1]
        )

        # Residual connection, broadcast over the head dimension.
        output = x_block.unsqueeze(-2) + medusa_res

        # Gather every rank's slice and rebuild the full feature dimension.
        gathered = [torch.empty_like(output) for _ in range(self.world_size)]
        torch.distributed.all_gather(gathered, output, group=self.process_group)
        world_output = torch.cat(gathered, dim=-1)

        # Stack x in front of the per-head residual outputs so the lm head runs
        # once over all of them.
        stacked_x = torch.cat([x.unsqueeze(-2), world_output], dim=-2)
        stacked_logits = self.lm_head(stacked_x)

        # First entry along dim -2 is the plain lm-head output; the rest are the
        # speculative logits, one per medusa head.
        logits, speculative_logits = torch.split(
            stacked_logits, [1, self.n_medusa_heads], dim=-2
        )
        logits = logits.squeeze(-2)

        return logits, speculative_logits
588
+
589
+
590
class SpeculativeHead(nn.Module):
    """Output head that optionally produces speculative (Medusa) logits.

    Holds either a Medusa head (which computes both outputs itself) or a plain
    language-model head; ``forward`` always returns a 2-tuple.
    """

    def __init__(self, lm_head, medusa):
        """Store the sub-modules.

        Args:
            lm_head: plain head used when ``medusa`` is ``None``; may be ``None``
                when a Medusa head is supplied.
            medusa: Medusa head returning ``(logits, speculative_logits)``, or
                ``None`` to disable speculation.
        """
        super().__init__()
        self.head = lm_head
        self.medusa = medusa

    @staticmethod
    def load(config, prefix: str, weights):
        """Build a ``SpeculativeHead`` according to ``config.use_medusa``.

        When ``config.use_medusa`` is truthy, ``MedusaHeadV1.load`` is tried
        first and ``MedusaHeadV2`` is used if that fails; no separate lm head is
        loaded in that case. Otherwise only a ``TensorParallelHead`` is loaded.
        """
        use_medusa = config.use_medusa
        if use_medusa:
            lm_head = None
            try:
                medusa = MedusaHeadV1.load(config, prefix, weights)
            except Exception:
                # Fall back to V2 on any ordinary loading failure, but let
                # KeyboardInterrupt / SystemExit propagate instead of being
                # swallowed by a bare ``except``.
                medusa = MedusaHeadV2(config, prefix, weights)
        else:
            lm_head = TensorParallelHead.load(config, prefix, weights)
            medusa = None
        return SpeculativeHead(lm_head, medusa)

    def forward(
        self, input: torch.Tensor
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Return ``(logits, speculative_logits)``.

        With a Medusa head, its output is returned unchanged; otherwise the
        plain head's logits are returned with ``None`` as the second element.
        """
        if self.medusa is not None:
            return self.medusa(input)

        assert self.head is not None
        logits = self.head(input)
        return logits, None
510
619
511
620
512
621
class TensorParallelHead (SuperLayer ):
0 commit comments