
Commit

Fix: Mitigate unused config attr tests by explicit usage.
kpoeppel committed Dec 21, 2024
1 parent 827133b commit f2e77c0
Showing 1 changed file with 35 additions and 1 deletion.
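For context on the commit message: transformers runs repo consistency checks that flag any config attribute never referenced in the corresponding modeling code. A helper such as `config.to_xlstm_block_config()` hides those accesses from a text-based check, whereas spelling out each `config.<attr>` keyword argument makes them visible. Below is a minimal sketch of that style of check; the function is hypothetical, not the actual checker from the transformers repo:

```python
import re


# Hypothetical checker: an attribute counts as "used" only if a literal
# `config.<attr>` access appears somewhere in the modeling source.
def find_unused_config_attributes(config_attributes, modeling_source):
    used = set(re.findall(r"config\.(\w+)", modeling_source))
    return sorted(set(config_attributes) - used)


old_source = "self.blocks = nn.ModuleList([mLSTMBlock(config.to_xlstm_block_config()) for _ in range(config.num_blocks)])"
print(find_unused_config_attributes(["num_blocks", "chunk_size", "gate_soft_cap"], old_source))
# ['chunk_size', 'gate_soft_cap'] -- these are consumed inside
# to_xlstm_block_config(), so a text-based check cannot see them.
```

After this commit, every attribute appears as an explicit `config.<attr>` access in `modeling_xlstm.py`, so a check of this kind passes without special-casing the model.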
src/transformers/models/xlstm/modeling_xlstm.py (35 additions, 1 deletion)
```diff
@@ -26,9 +26,11 @@
         mLSTMBlock,
         mLSTMStateType,
         soft_cap,
+        xLSTMLargeConfig,
     )
 else:
     mLSTMBlock = None
+    xLSTMLargeConfig = None
 
 from .configuration_xlstm import xLSTMConfig
 
@@ -213,7 +215,39 @@ def __init__(self, config):
         super().__init__(config)
 
         self.embeddings = nn.Embedding(config.vocab_size, config.embedding_dim)
-        self.blocks = nn.ModuleList([mLSTMBlock(config.to_xlstm_block_config()) for _ in range(config.num_blocks)])
+        # use config explicitly to mitigate unused variable tests
+        xlstm_block_config = xLSTMLargeConfig(
+            vocab_size=config.vocab_size,
+            embedding_dim=config.embedding_dim,
+            num_blocks=config.num_blocks,
+            num_heads=config.num_heads,
+            use_bias=config.use_bias,
+            add_out_norm=config.add_out_norm,
+            norm_eps=config.norm_eps,
+            norm_reduction_force_float32=config.norm_reduction_force_float32,
+            # mlstm_layer
+            qk_dim_factor=config.qk_dim_factor,
+            v_dim_factor=config.v_dim_factor,
+            # mlstm backend
+            chunkwise_kernel=config.chunkwise_kernel,
+            sequence_kernel=config.sequence_kernel,
+            step_kernel=config.step_kernel,
+            mode=config.mode,
+            chunk_size=config.chunk_size,
+            return_last_states=config.return_last_states,
+            autocast_kernel_dtype=config.autocast_kernel_dtype,
+            eps=config.eps,
+            inference_state_dtype=config.inference_state_dtype,
+            # feedforward
+            ffn_proj_factor=config.ffn_proj_factor,
+            ffn_round_up_to_multiple_of=config.ffn_round_up_to_multiple_of,
+            # capping
+            gate_soft_cap=config.gate_soft_cap,
+            output_logit_soft_cap=config.output_logit_soft_cap,
+            weight_mode=config.weight_mode,
+        )
+
+        self.blocks = nn.ModuleList([mLSTMBlock(xlstm_block_config) for _ in range(config.num_blocks)])
 
         self.gradient_checkpointing = False
         self.out_norm = RMSNorm(config.embedding_dim, eps=config.norm_eps)
```
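One detail worth noting in the diff: because the import of `xLSTMLargeConfig` is guarded (the `xlstm` package is an optional dependency), the newly imported name must also be bound to `None` in the `else` branch; otherwise merely referencing it in `__init__` would raise a `NameError` instead of a clear missing-dependency error. A minimal sketch of the pattern, using a plain try/except in place of whatever availability guard the file actually uses:

```python
# Soft-dependency sketch; modeling_xlstm.py gates on an availability check
# rather than a bare try/except, but the name-binding rule is the same.
try:
    from xlstm.xlstm_large.model import mLSTMBlock, xLSTMLargeConfig
except ImportError:
    # Bind every guarded name in the fallback branch as well, so the module
    # imports cleanly and failures surface as a helpful dependency error.
    mLSTMBlock = None
    xLSTMLargeConfig = None


def require_xlstm():
    # Call at model-construction time (e.g. at the top of __init__).
    if xLSTMLargeConfig is None or mLSTMBlock is None:
        raise ImportError("The optional `xlstm` package is required: pip install xlstm")
```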
