Commit: text_modeling_zamba2
pglorio committed Nov 9, 2024
1 parent 979b99b commit 9d9b2eb
Showing 6 changed files with 791 additions and 104 deletions.
112 changes: 56 additions & 56 deletions src/transformers/models/zamba2/configuration_zamba2.py
@@ -38,15 +38,48 @@ class Zamba2Config(PretrainedConfig):
vocab_size (`int`, *optional*, defaults to 32000):
Vocabulary size of the Zamba2 model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`Zamba2Model`]
max_position_embeddings (`int`, *optional*, defaults to 4096):
The maximum sequence length that this model might ever be used with.
tie_word_embeddings (`bool`, *optional*, defaults to `True`):
Whether the model's input and output word embeddings should be tied. Note that this is only relevant if the
model has an output word embedding layer.
hidden_size (`int`, *optional*, defaults to 2560):
Dimension of the hidden representations.
ffn_hidden_size (`int`, *optional*, defaults to 4 * hidden_size):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 54):
Number of hidden layers in the model.
layers_block_type (`list`, *optional*):
List of layer types, which can be either "mamba" or "hybrid".
mamba_d_state (`int`, *optional*, defaults to 64): shape of the state space latents.
mamba_d_conv (`int`, *optional*, defaults to 4): Size of the convolution kernel.
mamba_expand (`int`, *optional*, defaults to 2): Expanding factor used to determine the intermediate size.
mamba_ngroups (`int`, *optional*, defaults to 8):
Number of groups for the evolution matrices of mamba 2.
time_step_min (`float`, *optional*, defaults to 0.001):
Minimum `time_step` used to bound `dt_proj.bias`.
time_step_max (`float`, *optional*, defaults to 0.1):
Maximum `time_step` used to bound `dt_proj.bias`.
time_step_floor (`float`, *optional*, defaults to 0.0001):
Minimum clamping value of the `dt_proj.bias` layer initialization.
time_step_limit (`tuple`, *optional*):
Accepted range of time step values.
mamba_dt_rank (`Union[int,str]`, *optional*, defaults to `"auto"`):
Rank of the discretization projection matrix. `"auto"` means that it will default to `math.ceil(self.hidden_size / 16)`
n_mamba_heads (`int`, *optional*, defaults to 1):
Number of heads for the evolution matrices of mamba 2.
use_conv_bias (`bool`, *optional*, defaults to `True`):
Whether or not to use bias in the convolution layer of the mixer block.
mamba_proj_bias (`bool`, *optional*, defaults to `False`):
Whether or not to use bias in ["in_proj", "out_proj"] of the mixer block
hidden_mamba_act (`str`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) adjacent to the mamba conv.
chunk_size (`int`, *optional*, defaults to 256):
Size of the chunks that will comprise the sequence.
add_bias_linear (`bool`, *optional*, defaults to `False`):
Flag indicating whether or not to use bias in various layers
intermediate_size (`int`, *optional*, defaults to 4 * hidden_size):
Dimension of the MLP representations.
hidden_act (`str`, *optional*, defaults to `"gelu"`):
The non-linear activation function (function or string) in the MLP.
num_attention_heads (`int`, *optional*, defaults to 32):
Number of attention heads for each attention layer in the Transformer decoder.
num_key_value_heads (`int`, *optional*):
@@ -56,11 +89,23 @@ class Zamba2Config(PretrainedConfig):
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details checkout [this
paper](https://arxiv.org/pdf/2305.13245.pdf).
mamba_headdim (`int`, *optional*, defaults to 64):
Dimension of each Mamba2 head (the number of heads is set to 1 in this implementation).
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
num_mem_blocks (`int`, *optional*, defaults to 1):
Number of unshared transformer blocks.
use_shared_block_lora (`bool`, *optional*, defaults to `False`):
If True, unshared LoRAs will be added to the shared MLPs.
use_shared_attention_lora (`bool`, *optional*, defaults to `False`):
If True, unshared LoRAs will be added to the q, k, v projectors in the shared attention layers.
lora_rank (`int`, *optional*, defaults to 128):
Rank of the LoRA in the shared MLP and shared attention layers.
use_mem_rope (`bool`, *optional*, defaults to `False`):
If True, includes RoPE in the shared attention layers.
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-05):
rms_norm_eps (`float`, *optional*, defaults to 1e-5):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
@@ -77,29 +122,6 @@ class Zamba2Config(PretrainedConfig):
The id of the "beginning-of-sequence" token.
eos_token_id (`int`, *optional*, defaults to 2):
The id of the "end-of-sequence" token.
sliding_window (`int`, *optional*):
Sliding window attention window size. If not specified, will default to `None`.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
use_mamba_kernels (`bool`, *optional*, defaults to `True`):
Flag indicating whether or not to use the fast mamba kernels. These are available only if `mamba-ssm` and
`causal-conv1d` are installed, and the mamba modules are running on a CUDA device. Raises ValueError if
`True` and kernels are not available
state_size (`int`, *optional*, defaults to 16):
The dimension the mamba state space latents
mamba_d_conv (`int`, *optional*, defaults to 4):
The size of the mamba convolution kernel
mamba_expand (`int`, *optional*, defaults to 2):
Expanding factor (relative to hidden_size) used to determine the mamba intermediate size
add_bias_linear (`bool`, *optional*, defaults to `False`):
Flag indicating whether or not to use bias in various layers
gated_linear_unit (`bool`, *optional*, defaults to `True`):
Flag indicating whether or not to use gated MLP
use_shared_block_lora (`bool`, *optional*, defaults to `False`):
Flag indicating whether or not to add (unshared) LoRA modules to the first layer of the MLP
inside the shared transformer blocks
state_size (`int`, *optional*, defaults to 128):
The rank of the LoRA modules inside the MLP of the shared transformer blocks
"""

model_type = "zamba2"
@@ -116,37 +138,29 @@ def __init__(
mamba_d_state=64,
mamba_d_conv=4,
mamba_expand=2,
mamba_headdim=64,
mamba_ngroups=1,
time_step_min=0.001,
time_step_max=0.1,
time_step_floor=1e-4,
time_step_limit=(0.0, float("inf")),
time_step_limit=None,
mamba_dt_rank="auto",
n_mamba_heads=1,
mamba_conv_bias=True,
mamba_proj_bias=False,
hidden_mamba_act="silu",
use_mamba_kernels=True,
use_conv_bias=True,
chunk_size=256,
add_bias_linear=False,
intermediate_size=None,
gated_linear_unit=True,
hidden_act="gelu",
num_attention_heads=32,
num_key_value_heads=None,
sliding_window=None,
attention_dropout=0.0,
num_mem_blocks=1,
use_shared_block_lora=True,
use_shared_block_lora=False,
use_shared_attention_lora=False,
lora_rank=128,
use_mem_eff_path=True,
use_mem_rope=False,
rope_theta=10000,
attention_hidden_size=None,
attention_head_dim=None,
initializer_range=0.02,
rms_norm_eps=1e-5,
use_cache=True,
@@ -176,36 +190,26 @@ def __init__(
self.hidden_act = hidden_act
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.sliding_window = sliding_window
self.num_mem_blocks = num_mem_blocks
self.use_mem_rope = use_mem_rope
self.rope_theta = rope_theta
if attention_hidden_size is None:
self.attention_hidden_size = 2 * hidden_size
else:
self.attention_hidden_size = attention_hidden_size
if attention_head_dim is None:
self.attention_head_dim = 2 * self.hidden_size // self.num_attention_heads
else:
self.attention_head_dim = attention_head_dim
self.attention_hidden_size = 2 * hidden_size
self.attention_head_dim = 2 * self.hidden_size // self.num_attention_heads
self.attention_dropout = attention_dropout
self.mamba_d_state = mamba_d_state
self.mamba_d_conv = mamba_d_conv
self.mamba_expand = mamba_expand
self.mamba_dt_rank = math.ceil(self.hidden_size / 16) if mamba_dt_rank == "auto" else mamba_dt_rank
self.add_bias_linear = add_bias_linear
self.mamba_headdim = mamba_headdim
self.mamba_headdim = int(mamba_expand * hidden_size) // n_mamba_heads
self.mamba_ngroups = mamba_ngroups
self.n_mamba_heads = n_mamba_heads
self.mamba_conv_bias = mamba_conv_bias
self.mamba_proj_bias = mamba_proj_bias
self.hidden_mamba_act = hidden_mamba_act
self.use_mamba_kernels = use_mamba_kernels
self.use_conv_bias = use_conv_bias
self.chunk_size = chunk_size
self.time_step_limit = time_step_limit

self.gated_linear_unit = gated_linear_unit
self.use_shared_block_lora = use_shared_block_lora
self.use_shared_attention_lora = use_shared_attention_lora
self.lora_rank = lora_rank
@@ -227,12 +231,8 @@ def __init__(
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps

if intermediate_size is None:
self.ffn_hidden_size = 4 * self.hidden_size

self.use_cache = use_cache
self.num_logits_to_keep = num_logits_to_keep
self.use_mem_eff_path = use_mem_eff_path

# Below, "mamba" stands for mamba layer, "hybrid" stands for hybrid layer (composed by a shared transformer followed by mamba layer)
if layers_block_type is None:
@@ -246,4 +246,4 @@ def __init__(
+ ["mamba"] * 2
)
else:
self.layers_block_type = layers_block_type
self.layers_block_type = layers_block_type
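Taken together, the configuration changes in this file stop accepting `mamba_headdim` as an argument and derive it from `mamba_expand`, `hidden_size`, and `n_mamba_heads`, and they likewise fix `attention_hidden_size` and `attention_head_dim` instead of reading them from kwargs. A minimal sketch of the resulting behaviour, assuming a transformers build that includes this branch; the chosen values are examples only:

from transformers.models.zamba2.configuration_zamba2 import Zamba2Config

# Defaults documented above; n_mamba_heads chosen so the derived head dimension is 64.
config = Zamba2Config(hidden_size=2560, mamba_expand=2, n_mamba_heads=80, num_attention_heads=32)

# mamba_headdim is now derived rather than passed in.
assert config.mamba_headdim == int(config.mamba_expand * config.hidden_size) // config.n_mamba_heads  # 64

# attention_hidden_size and attention_head_dim are fixed to twice the usual per-head split.
assert config.attention_hidden_size == 2 * config.hidden_size                             # 5120
assert config.attention_head_dim == 2 * config.hidden_size // config.num_attention_heads  # 160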
61 changes: 43 additions & 18 deletions src/transformers/models/zamba2/modeling_zamba2.py
@@ -92,6 +92,9 @@ def __init__(
):
self.layers_block_type = config.layers_block_type
self.transformer_layers = []
self._modules = {}
self._parameters = {}
self._buffers = {}

self.has_previous_state = False
self.dtype = dtype
@@ -431,7 +434,6 @@ def forward(
bsz, q_len, _ = hidden_states.size()

if self.config.use_shared_attention_lora:
layer_idx = self.layer_dic[layer_idx]
lora_layer_idx = self.layer_dic[layer_idx]
linear_q_lora_A = self.linear_q_lora_A_list[lora_layer_idx]
linear_q_lora_B = self.linear_q_lora_B_list[lora_layer_idx]
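The rewritten lookup above selects one LoRA A/B pair per shared attention layer and applies it on top of the frozen base projection. A hedged sketch of that additive pattern for the query projection; the `lora_q_proj` helper and the shapes are assumptions for illustration, not code from this commit:

import torch
import torch.nn as nn

def lora_q_proj(hidden_states, q_proj, lora_A, lora_B):
    # Standard additive LoRA: frozen base projection plus a low-rank update B(A(x)).
    return q_proj(hidden_states) + lora_B(lora_A(hidden_states))

hidden_size, rank = 2560, 128
q_proj = nn.Linear(hidden_size, 2 * hidden_size, bias=False)  # attention_hidden_size = 2 * hidden_size
lora_A = nn.Linear(hidden_size, rank, bias=False)
lora_B = nn.Linear(rank, 2 * hidden_size, bias=False)

x = torch.randn(1, 4, hidden_size)
q = lora_q_proj(x, q_proj, lora_A, lora_B)  # shape (1, 4, 5120)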
@@ -810,7 +812,7 @@ def __init__(self, config: Zamba2Config, layer_idx: int = None):

self.n_groups = config.mamba_ngroups
self.head_dim = config.mamba_headdim
self.num_heads = self.intermediate_size // self.head_dim
self.num_heads = self.config.n_mamba_heads
self.chunk_size = config.chunk_size

self.time_step_limit = config.time_step_limit # add this with default (0.0, float("inf"))
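Since `mamba_headdim` is now `int(mamba_expand * hidden_size) // n_mamba_heads` in the config, reading `num_heads` from `config.n_mamba_heads` matches the old `intermediate_size // head_dim` computation whenever `n_mamba_heads` divides the expanded width evenly. A quick sanity check with assumed example values:

hidden_size, mamba_expand, n_mamba_heads = 2560, 2, 80  # example values: 80 heads of dimension 64
intermediate_size = int(mamba_expand * hidden_size)     # 5120
head_dim = intermediate_size // n_mamba_heads           # what the config now stores as mamba_headdim
assert intermediate_size // head_dim == n_mamba_heads   # old and new num_heads agree when the division is exact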
@@ -924,7 +926,7 @@ def cuda_kernels_forward(
# 1. Gated MLP's linear projection
projected_states = self.in_proj(hidden_states)
A = -torch.exp(self.A_log.float()) # (num_heads) or (intermediate_size, state_size)
dt_limit_kwargs = {} if self.time_step_limit == (0.0, float("inf")) else {"dt_limit": self.time_step_limit}
dt_limit_kwargs = {} if self.time_step_limit == None else {"dt_limit": self.time_step_limit}
if attention_mask is not None:
input_not_masked = torch.all(attention_mask == 1)
else:
@@ -1551,9 +1553,9 @@ def _init_weights(self, module):
module.A_log._no_weight_decay = True
module.D._no_weight_decay = True

num_heads = int(self.config.mamba_expand * self.config.hidden_size) // self.config.mamba_headdim
# num_heads = int(self.config.mamba_expand * self.config.hidden_size) // self.config.mamba_headdim
dt = torch.exp(
torch.rand(num_heads) * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min))
torch.rand(self.config.n_mamba_heads) * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min))
+ math.log(self.config.time_step_min)
).clamp(min=self.config.time_step_floor)
# # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
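The initialization lines above sample per-head `dt` values log-uniformly between `time_step_min` and `time_step_max`, clamp them at `time_step_floor`, and (per the linked comment) map them through the inverse of softplus before storing them as a bias. A standalone sketch of that recipe, assumed to follow the standard Mamba2 initialization rather than copied from this file:

import math
import torch

time_step_min, time_step_max, time_step_floor, n_mamba_heads = 0.001, 0.1, 1e-4, 80

# Log-uniform sample in [time_step_min, time_step_max], clamped from below.
dt = torch.exp(
    torch.rand(n_mamba_heads) * (math.log(time_step_max) - math.log(time_step_min))
    + math.log(time_step_min)
).clamp(min=time_step_floor)

# Inverse of softplus (see https://github.com/pytorch/pytorch/issues/72759),
# so that softplus(inv_dt) recovers dt when the stored bias is used in the forward pass.
inv_dt = dt + torch.log(-torch.expm1(-dt))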
@@ -1671,19 +1673,42 @@ def __init__(self, config: Zamba2Config):
self._tied_weights_keys = []
for layer_id, layer_type in enumerate(self.layers_block_type):
if layer_type == "hybrid":
prefix_name = f"layers.{layer_id}."
tied_keys = [
"shared_transf.self_attn.q_proj.weight",
"shared_transf.self_attn.k_proj.weight",
"shared_transf.self_attn.v_proj.weight",
"shared_transf.self_attn.o_proj.weight",
"shared_transf.feed_forward.gate_proj.weight",
"shared_transf.feed_forward.up_proj.weight",
"shared_transf.feed_forward.down_proj.weight",
"shared_transf.input_layernorm.weight",
"shared_transf.pre_ff_layernorm.weight",
]
self._tied_weights_keys = [*self._tied_weights_keys, *[prefix_name + key for key in tied_keys]]
block = next(blocks)
if config.num_mem_blocks * len(layer_type_list(config)) > 1:
prefix_name = f"layers.{layer_id}."
tied_keys = [
"shared_transf.self_attn.q_proj.weight",
"shared_transf.self_attn.k_proj.weight",
"shared_transf.self_attn.v_proj.weight",
"shared_transf.self_attn.o_proj.weight",
"shared_transf.feed_forward.gate_up_proj.weight",
"shared_transf.feed_forward.down_proj.weight",
"shared_transf.input_layernorm.weight",
"shared_transf.pre_ff_layernorm.weight",
]
self._tied_weights_keys = [*self._tied_weights_keys, *[prefix_name + key for key in tied_keys]]
if config.use_shared_block_lora:
tied_keys_lora = []
lora_id = 0
for _layer_type in self.layers_block_type:
if _layer_type == "hybrid" and lora_id % config.num_mem_blocks == block.block_id:
tied_keys_lora.append('shared_transf.feed_forward.gate_up_proj_lora_A_list.' + str(lora_id) + '.weight')
tied_keys_lora.append('shared_transf.feed_forward.gate_up_proj_lora_B_list.' + str(lora_id) + '.weight')
lora_id += 1
self._tied_weights_keys = [*self._tied_weights_keys, *tied_keys_lora]
if config.use_shared_attention_lora:
tied_keys_lora = []
lora_id = 0
for _layer_type in self.layers_block_type:
if _layer_type == "hybrid" and lora_id % config.num_mem_blocks == block.block_id:
tied_keys_lora.append('shared_transf.self_attn.linear_q_lora_A_list.' + str(lora_id) + '.weight')
tied_keys_lora.append('shared_transf.self_attn.linear_k_lora_A_list.' + str(lora_id) + '.weight')
tied_keys_lora.append('shared_transf.self_attn.linear_v_lora_A_list.' + str(lora_id) + '.weight')
tied_keys_lora.append('shared_transf.self_attn.linear_q_lora_B_list.' + str(lora_id) + '.weight')
tied_keys_lora.append('shared_transf.self_attn.linear_k_lora_B_list.' + str(lora_id) + '.weight')
tied_keys_lora.append('shared_transf.self_attn.linear_v_lora_B_list.' + str(lora_id) + '.weight')
lora_id += 1
self._tied_weights_keys = [*self._tied_weights_keys, *tied_keys_lora]
layers.append(Zamba2HybridLayer(next(blocks), next(linear_layers), next(mamba_layers)))
else:
layers.append(next(mamba_layers))
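The LoRA tying loops above use `lora_id % config.num_mem_blocks == block.block_id` to decide which shared transformer block a given per-layer adapter is tied to, i.e. adapters are distributed round-robin across the `num_mem_blocks` shared blocks. A small illustration of that modulo rule with hypothetical counts (not values from the commit):

# Round-robin: adapter i is tied to the shared block whose block_id equals i % num_mem_blocks.
num_mem_blocks = 2
num_adapters = 6  # hypothetical number of per-layer adapters
assignment = {
    block_id: [i for i in range(num_adapters) if i % num_mem_blocks == block_id]
    for block_id in range(num_mem_blocks)
}
print(assignment)  # {0: [0, 2, 4], 1: [1, 3, 5]}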
