 from torch import nn, einsum
 from torch.cuda.amp import autocast
 from typing import Callable, Literal
-from soft_moe_pytorch import SoftMoE
 
 try:
     from flash_attn import flash_attn_func, flash_attn_kvpacked_func
@@ -578,7 +577,6 @@ def __init__(
         conformer = False,
         layer_ix = -1,
         remove_norms = False,
-        number_of_experts = 1,
         attn_kwargs = {},
         ff_kwargs = {},
         norm_kwargs = {}
@@ -613,10 +611,7 @@ def __init__(
         )
 
         self.ff_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else nn.Identity()
-        if number_of_experts > 1:
-            self.ff = SoftMoE(dim = dim, seq_len = 1500, num_experts = number_of_experts, geglu = True)
-        else:
-            self.ff = FeedForward(dim, zero_init_output = zero_init_branch_outputs, **ff_kwargs)
+        self.ff = FeedForward(dim, zero_init_output = zero_init_branch_outputs, **ff_kwargs)
 
         self.layer_ix = layer_ix
 
@@ -700,7 +695,6 @@ def __init__(
         use_sinusoidal_emb = False,
         use_abs_pos_emb = False,
         abs_pos_emb_max_length = 10000,
-        number_of_experts = 1,
         **kwargs
     ):
 
@@ -739,7 +733,6 @@ def __init__(
                     zero_init_branch_outputs = zero_init_branch_outputs,
                     conformer = conformer,
                     layer_ix = i,
-                    number_of_experts = number_of_experts,
                     **kwargs
                 )
             )
@@ -787,4 +780,4 @@ def forward(
 
         x = self.project_out(x)
 
-        return x
+        return x
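
For orientation, below is a minimal, self-contained sketch (not the repository's actual classes) of the feed-forward wiring this diff leaves in place: with the SoftMoE branch and the number_of_experts knob removed, each block unconditionally builds a plain FeedForward after its norm. The FeedForward and Block classes here are simplified stand-ins written for illustration; only the ff_norm / ff construction pattern mirrors the lines above.

# Hypothetical stand-ins, assuming only torch; not the repo's FeedForward/LayerNorm.
import torch
from torch import nn

class FeedForward(nn.Module):
    """Simplified feed-forward stand-in (no GLU variant)."""
    def __init__(self, dim, mult=4, zero_init_output=True):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult),
            nn.GELU(),
            nn.Linear(dim * mult, dim),
        )
        if zero_init_output:
            # Zero-init the branch output so the residual path starts as identity.
            nn.init.zeros_(self.net[-1].weight)
            nn.init.zeros_(self.net[-1].bias)

    def forward(self, x):
        return self.net(x)

class Block(nn.Module):
    def __init__(self, dim, remove_norms=False, zero_init_branch_outputs=True, **ff_kwargs):
        super().__init__()
        # No number_of_experts switch: the feed-forward path is always a FeedForward.
        self.ff_norm = nn.LayerNorm(dim) if not remove_norms else nn.Identity()
        self.ff = FeedForward(dim, zero_init_output=zero_init_branch_outputs, **ff_kwargs)

    def forward(self, x):
        # Pre-norm residual feed-forward.
        return x + self.ff(self.ff_norm(x))

x = torch.randn(2, 16, 64)
print(Block(64)(x).shape)  # torch.Size([2, 16, 64])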