 import torch
+import math
 from torch import nn
 from x_transformers import ContinuousTransformerWrapper, Decoder
+from functools import partial
 
 from mamba_ssm.utils.generation import InferenceParams
 from .transformer import ContinuousTransformer
+from .mambaplus.mamba import MambaPlus, MambaPlusConfig
 
 # Interface for backbone of a language model
 # Handles conditioning and cross-attention
@@ -253,4 +256,90 @@ def forward(self, x, mask=None, prepend_cond=None, prepend_cond_mask=None, cross
             self.cuda_graph.replay()
             return self.captured_logits.clone()
 
-        return self.model(x, inference_params=self.inference_params if use_cache else None)[:, prepend_length:, :]
+        return self.model(x, inference_params=self.inference_params if use_cache else None)[:, prepend_length:, :]
+
+
+def _init_weights(
+    module,
+    n_layer,
+    initializer_range=0.02,  # Now only used for embedding layer.
+    rescale_prenorm_residual=True,
+    n_residuals_per_layer=1,  # Change to 2 if we have MLP
+):
+    if isinstance(module, nn.Linear):
+        if module.bias is not None:
+            if not getattr(module.bias, "_no_reinit", False):
+                nn.init.zeros_(module.bias)
+    elif isinstance(module, nn.Embedding):
+        nn.init.normal_(module.weight, std=initializer_range)
+
+    if rescale_prenorm_residual:
+        # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
+        # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
+        # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
+        # > -- GPT-2 :: https://openai.com/blog/better-language-models/
+        #
+        # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
+        for name, p in module.named_parameters():
+            if name in ["out_proj.weight", "fc2.weight"]:
+                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
+                # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
+                # We need to reinit p since this code could be called multiple times
+                # Having just p *= scale would repeatedly scale it down
+                nn.init.kaiming_uniform_(p, a=math.sqrt(5))
+                with torch.no_grad():
+                    p /= math.sqrt(n_residuals_per_layer * n_layer)
+
+class MambaPlusAudioLMBackbone(AudioLMBackbone):
+    def __init__(self,
+                 embed_dim: int = 512,
+                 n_layers: int = 32,
+                 d_state: int = 1,
+                 bidirectional: bool = False,
+                 num_mod_groups: int = 128,
+                 cross_attn_cond_dim: int = 0,
+                 prepend_cond_dim: int = 0,
+                 **kwargs):
+        super().__init__(embed_dim=embed_dim)
+
+        self.config = MambaPlusConfig(d_model=embed_dim,
+                                      n_layers=n_layers,
+                                      d_state=d_state,
+                                      expand_factor=2,
+                                      num_mod_groups=num_mod_groups,
+                                      complex=True,
+                                      mamba_plus_enabled=True,
+                                      bidirectional=bidirectional,
+                                      **kwargs)
+
+        # Embeddings are done in the AudioLanguageModel, so the backbone operates on continuous inputs directly
+        self.model = MambaPlus(
+            config=self.config,
+            **kwargs
+        )
+        self.apply(
+            partial(
+                _init_weights,
+                n_layer=self.config.n_layers
+            )
+        )
+
+        if prepend_cond_dim > 0:
+            # Prepend conditioning
+            self.to_prepend_embed = nn.Sequential(
+                nn.Linear(prepend_cond_dim, embed_dim, bias=False)
+            )
+
+        assert cross_attn_cond_dim == 0, "Cross-attention conditioning not supported for MambaPlus"
+
+    def forward(self, x, mask=None, prepend_cond=None, prepend_cond_mask=None, cross_attn_cond=None, use_cache=False):
+
+        prepend_length = 0
+        if prepend_cond is not None:
+            # Project the prepend conditioning to the embedding dimension
+            prepend_cond = self.to_prepend_embed(prepend_cond)
+            prepend_length = prepend_cond.shape[1]
+
+            x = torch.cat([prepend_cond, x], dim=1)
+
+        return self.model(x)[:, prepend_length:, :]
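
For reference, the rescaling that `_init_weights` applies to parameters named `out_proj.weight` or `fc2.weight` is the GPT-2/Megatron scheme quoted in its comments: re-draw the weight with PyTorch's default Kaiming-uniform init, then divide by sqrt(n_residuals_per_layer * n_layer). A standalone sketch of the effect, not part of the commit; the `Block` wrapper below is made up purely so the parameter is reachable under a matching name:

import math
import torch
from torch import nn

# Hypothetical wrapper whose parameter is named "out_proj.weight",
# mirroring how the Mamba blocks expose their output projection.
class Block(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.out_proj = nn.Linear(d_model, d_model, bias=False)

n_layer = 32
block = Block(512)
before = block.out_proj.weight.std().item()

# The same two steps _init_weights performs on matching parameters:
nn.init.kaiming_uniform_(block.out_proj.weight, a=math.sqrt(5))
with torch.no_grad():
    block.out_proj.weight /= math.sqrt(1 * n_layer)  # n_residuals_per_layer = 1

after = block.out_proj.weight.std().item()
print(round(before / after, 2))  # roughly sqrt(32) ≈ 5.66, i.e. residual weights shrink by ~1/sqrt(n_layer)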
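
A minimal smoke-test sketch for the new backbone, assuming `MambaPlus` accepts inputs of shape (batch, seq_len, d_model) and returns same-shape features; its implementation lives in this repo's `.mambaplus.mamba` module and is not shown in this diff, so the shapes and the conditioning width below are illustrative only:

# Hypothetical usage; dimensions chosen for illustration.
backbone = MambaPlusAudioLMBackbone(
    embed_dim=512,
    n_layers=4,
    prepend_cond_dim=768,  # width of the prepend conditioning embedding
)

x = torch.randn(2, 100, 512)      # (batch, seq_len, embed_dim) continuous inputs
prepend = torch.randn(2, 1, 768)  # (batch, n_prepend, prepend_cond_dim)

out = backbone(x, prepend_cond=prepend)
# The prepended positions are sliced off before returning, so the output covers
# only the original sequence: (2, 100, 512) if MambaPlus preserves d_model.
print(out.shape)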