@@ -507,12 +507,28 @@ class GeneralModelForCausalLMPipe(PipelinePretrainedModel, PipelineLayer):
507
507
transpose_weight_keys = None
508
508
_embed_cls = None
509
509
_rotary_emb_cls = None
510
+ _mtp_layer_pipe_cls = None
511
+ _embedding_pipe_cls = None
512
+ _decoder_layer_pipe_cls = None
513
+ _criterion_pipe_cls = None
514
+ _lmhead_pipe_cls = None
515
+ _rms_norm_pipe_cls = None
510
516
511
517
def __init__ (self , config : PretrainedConfig , ** kwargs ):
512
518
# dynamic inherit DecoderLayer
513
519
if self ._decoder_layer_cls is None :
514
520
raise ValueError ("_decoder_layer_cls must be set before init." )
515
- DecoderLayerPipe = make_decoder_layer_pipe (self ._decoder_layer_cls )
521
+
522
+ EmbeddingPipeCls = self ._embedding_pipe_cls if self ._embedding_pipe_cls is not None else Embedding
523
+
524
+ if self ._decoder_layer_pipe_cls is None :
525
+ DecoderLayerPipe = make_decoder_layer_pipe (self ._decoder_layer_cls )
526
+ else :
527
+ DecoderLayerPipe = self ._decoder_layer_pipe_cls
528
+
529
+ LMHeadPipeCls = self ._lmhead_pipe_cls if self ._lmhead_pipe_cls is not None else LMHeadPipe
530
+ MTPLayerPipeCls = self ._mtp_layer_pipe_cls if self ._mtp_layer_pipe_cls is not None else None
531
+ RMSNormPipeCls = self ._rms_norm_pipe_cls if self ._rms_norm_pipe_cls is not None else RMSNormPipe
516
532
517
533
new_initializer_range = math .sqrt (0.3333 / config .hidden_size )
518
534
logger .info (f"change initializer-range from { config .initializer_range } to { new_initializer_range } " )
@@ -559,7 +575,7 @@ def __init__(self, config: PretrainedConfig, **kwargs):
559
575
else :
560
576
self .add_sequential_layer (
561
577
LayerDesc (
562
- EmbeddingPipe , config = config , embed_cls = self ._embed_cls , rotary_emb_cls = self ._rotary_emb_cls
578
+ EmbeddingPipeCls , config = config , embed_cls = self ._embed_cls , rotary_emb_cls = self ._rotary_emb_cls
563
579
),
564
580
"model" ,
565
581
)
@@ -573,6 +589,12 @@ def __init__(self, config: PretrainedConfig, **kwargs):
573
589
),
574
590
f"model.layers.{ i } " ,
575
591
)
592
+ for i in range (config .num_nextn_predict_layers ):
593
+ if MTPLayerPipeCls is not None :
594
+ self .add_sequential_layer (
595
+ LayerDesc (MTPLayerPipeCls , config = config , layer_idx = config .num_hidden_layers + i ),
596
+ f"model.layers.{ config .num_hidden_layers + i } " ,
597
+ )
576
598
for i in range (config .add_tail_layers ):
577
599
self .add_sequential_layer (
578
600
LayerDesc (
@@ -582,22 +604,22 @@ def __init__(self, config: PretrainedConfig, **kwargs):
582
604
)
583
605
584
606
self .add_sequential_layer (
585
- LayerDesc (RMSNormPipe if config .use_rmsnorm else LayerNormPipe , config = config ),
607
+ LayerDesc (RMSNormPipeCls if config .use_rmsnorm else LayerNormPipe , config = config ),
586
608
"model.norm" ,
587
609
)
588
610
589
611
if config .tie_word_embeddings :
590
612
self .add_sequential_layer (
591
613
SharedLayerDesc (
592
614
"model_shared_weight" ,
593
- LMHeadPipe ,
615
+ LMHeadPipeCls ,
594
616
shared_weight_attr = "embedding_weight" ,
595
617
config = config ,
596
618
),
597
619
"lm_head" ,
598
620
)
599
621
else :
600
- self .add_sequential_layer (LayerDesc (LMHeadPipe , config = config ), "lm_head" )
622
+ self .add_sequential_layer (LayerDesc (LMHeadPipeCls , config = config ), "lm_head" )
601
623
recompute_interval = 0
602
624
603
625
seg_method = config .pp_seg_method if hasattr (config , "pp_seg_method" ) else "layer:DecoderLayer|EmptyLayer"
@@ -630,10 +652,12 @@ def __init__(self, config: PretrainedConfig, **kwargs):
630
652
)
631
653
632
654
def get_loss_fn (self , config ):
655
+ CriterionPipeCls = self ._criterion_pipe_cls if self ._criterion_pipe_cls is not None else CriterionLayerPipe
656
+
633
657
if config .get ("dpo_config" , None ) is not None :
634
- loss_fn = CriterionLayerPipe (config , use_infohub = True )
658
+ loss_fn = CriterionPipeCls (config , use_infohub = True )
635
659
else :
636
- loss_fn = CriterionLayerPipe (config )
660
+ loss_fn = CriterionPipeCls (config )
637
661
638
662
return loss_fn
639
663
0 commit comments