Commit 36366ef

Merge commit: 2 parents 82f9af9 + ddf2504

1 file changed

stable_audio_tools/models/transformer.py

Lines changed: 2 additions & 9 deletions
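
This merge removes the optional SoftMoE feed-forward path: the soft_moe_pytorch import, the number_of_experts argument in both affected constructors, and the branch that built a SoftMoE layer in place of the dense FeedForward.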
@@ -8,7 +8,6 @@
 from torch import nn, einsum
 from torch.cuda.amp import autocast
 from typing import Callable, Literal
-from soft_moe_pytorch import SoftMoE
 
 try:
     from flash_attn import flash_attn_func, flash_attn_kvpacked_func
@@ -578,7 +577,6 @@ def __init__(
         conformer = False,
         layer_ix = -1,
         remove_norms = False,
-        number_of_experts = 1,
         attn_kwargs = {},
         ff_kwargs = {},
         norm_kwargs = {}
@@ -613,10 +611,7 @@ def __init__(
             )
 
         self.ff_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else nn.Identity()
-        if number_of_experts > 1:
-            self.ff = SoftMoE(dim = dim, seq_len = 1500, num_experts = number_of_experts, geglu = True)
-        else:
-            self.ff = FeedForward(dim, zero_init_output=zero_init_branch_outputs, **ff_kwargs)
+        self.ff = FeedForward(dim, zero_init_output=zero_init_branch_outputs, **ff_kwargs)
 
         self.layer_ix = layer_ix
 
@@ -700,7 +695,6 @@ def __init__(
         use_sinusoidal_emb=False,
         use_abs_pos_emb=False,
         abs_pos_emb_max_length=10000,
-        number_of_experts = 1,
         **kwargs
     ):
 
@@ -739,7 +733,6 @@ def __init__(
                     zero_init_branch_outputs = zero_init_branch_outputs,
                     conformer=conformer,
                     layer_ix=i,
-                    number_of_experts=number_of_experts,
                     **kwargs
                 )
             )
@@ -787,4 +780,4 @@ def forward(
 
         x = self.project_out(x)
 
-        return x
+        return x
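
For reference, a minimal sketch of the pattern the deleted branch implemented: routing the block's feed-forward through lucidrains' soft_moe_pytorch SoftMoE when more than one expert is requested, falling back to a dense block otherwise. The FeedForward class and build_ff helper below are illustrative stand-ins, not the stable_audio_tools implementations; only the SoftMoE keyword arguments (dim, seq_len, num_experts, geglu) are taken verbatim from the removed line.

import torch
from torch import nn

try:
    # The package the removed import referenced; treated as optional here.
    from soft_moe_pytorch import SoftMoE
except ImportError:
    SoftMoE = None

class FeedForward(nn.Module):
    # Illustrative dense feed-forward stand-in for the repo's FeedForward.
    def __init__(self, dim, mult=4):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult),
            nn.GELU(),
            nn.Linear(dim * mult, dim),
        )

    def forward(self, x):
        return self.net(x)

def build_ff(dim, number_of_experts=1, seq_len=1500):
    # Mirrors the deleted conditional: a soft mixture-of-experts layer when
    # number_of_experts > 1, otherwise the dense block.
    if number_of_experts > 1 and SoftMoE is not None:
        return SoftMoE(dim=dim, seq_len=seq_len, num_experts=number_of_experts, geglu=True)
    return FeedForward(dim)

x = torch.randn(2, 1500, 512)            # (batch, seq_len, dim)
ff = build_ff(512, number_of_experts=4)  # falls back to FeedForward if SoftMoE is absent
print(ff(x).shape)                       # torch.Size([2, 1500, 512])

Note the hard-coded seq_len = 1500 in the removed call: soft_moe_pytorch derives its expert slot count from it, tying the layer to a fixed sequence length, a constraint the dense FeedForward does not have.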
