diff --git a/src/transformers/models/xlstm/modeling_xlstm.py b/src/transformers/models/xlstm/modeling_xlstm.py
index 135ea4f97fb0a3..de1fe232ddc7e8 100644
--- a/src/transformers/models/xlstm/modeling_xlstm.py
+++ b/src/transformers/models/xlstm/modeling_xlstm.py
@@ -26,9 +26,11 @@
         mLSTMBlock,
         mLSTMStateType,
         soft_cap,
+        xLSTMLargeConfig,
     )
 else:
     mLSTMBlock = None
+    xLSTMLargeConfig = None

 from .configuration_xlstm import xLSTMConfig

@@ -213,7 +215,39 @@ def __init__(self, config):
         super().__init__(config)

         self.embeddings = nn.Embedding(config.vocab_size, config.embedding_dim)
-        self.blocks = nn.ModuleList([mLSTMBlock(config.to_xlstm_block_config()) for _ in range(config.num_blocks)])
+        # use config explicitly to mitigate unused variable tests
+        xlstm_block_config = xLSTMLargeConfig(
+            vocab_size=config.vocab_size,
+            embedding_dim=config.embedding_dim,
+            num_blocks=config.num_blocks,
+            num_heads=config.num_heads,
+            use_bias=config.use_bias,
+            add_out_norm=config.add_out_norm,
+            norm_eps=config.norm_eps,
+            norm_reduction_force_float32=config.norm_reduction_force_float32,
+            # mlstm_layer
+            qk_dim_factor=config.qk_dim_factor,
+            v_dim_factor=config.v_dim_factor,
+            # mlstm backend
+            chunkwise_kernel=config.chunkwise_kernel,
+            sequence_kernel=config.sequence_kernel,
+            step_kernel=config.step_kernel,
+            mode=config.mode,
+            chunk_size=config.chunk_size,
+            return_last_states=config.return_last_states,
+            autocast_kernel_dtype=config.autocast_kernel_dtype,
+            eps=config.eps,
+            inference_state_dtype=config.inference_state_dtype,
+            # feedforward
+            ffn_proj_factor=config.ffn_proj_factor,
+            ffn_round_up_to_multiple_of=config.ffn_round_up_to_multiple_of,
+            # capping
+            gate_soft_cap=config.gate_soft_cap,
+            output_logit_soft_cap=config.output_logit_soft_cap,
+            weight_mode=config.weight_mode,
+        )
+
+        self.blocks = nn.ModuleList([mLSTMBlock(xlstm_block_config) for _ in range(config.num_blocks)])
         self.gradient_checkpointing = False
         self.out_norm = RMSNorm(config.embedding_dim, eps=config.norm_eps)
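
For context, the hunk above maps every field of the Hugging Face xLSTMConfig onto the xlstm library's xLSTMLargeConfig by keyword, replacing the previous config.to_xlstm_block_config() call. A minimal sketch of the same mapping, factored into a helper, is shown below; it is only an illustration and not part of the patch. The helper name build_xlstm_block_config is hypothetical, and the direct import path for xLSTMLargeConfig is assumed from the xlstm package layout (in the patch it comes from the guarded import at the top of modeling_xlstm.py).

    # Hypothetical helper, not part of this patch: build the xlstm-library block
    # config from the HF xLSTMConfig using the same attribute names as the diff above.
    # Import path assumed from the xlstm package layout.
    from xlstm.xlstm_large.model import xLSTMLargeConfig

    def build_xlstm_block_config(config):
        shared_fields = [
            # general
            "vocab_size", "embedding_dim", "num_blocks", "num_heads",
            "use_bias", "add_out_norm", "norm_eps", "norm_reduction_force_float32",
            # mlstm_layer
            "qk_dim_factor", "v_dim_factor",
            # mlstm backend
            "chunkwise_kernel", "sequence_kernel", "step_kernel",
            "mode", "chunk_size", "return_last_states",
            "autocast_kernel_dtype", "eps", "inference_state_dtype",
            # feedforward
            "ffn_proj_factor", "ffn_round_up_to_multiple_of",
            # capping
            "gate_soft_cap", "output_logit_soft_cap", "weight_mode",
        ]
        return xLSTMLargeConfig(**{name: getattr(config, name) for name in shared_fields})

The patch instead spells out every keyword in __init__, which keeps each field name grep-able and, as its inline comment notes, ensures every xLSTMConfig attribute is referenced explicitly.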