Qwen2 converter #163

Merged
merged 12 commits on Mar 4, 2025
4 changes: 4 additions & 0 deletions fast_llm/models/gpt/config.py
@@ -35,6 +35,9 @@ class Starcoder2GPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat):
class LlamaGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat):
name: typing.ClassVar[str] = "llama"

class Qwen2GPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat):
name: typing.ClassVar[str] = "qwen2"


class MistralGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat):
name: typing.ClassVar[str] = "mistral"
@@ -98,6 +101,7 @@ class GPTModelConfig(FastLLMModelConfig):
AutoGPTHuggingfaceCheckpointFormat,
Starcoder2GPTHuggingfaceCheckpointFormat,
LlamaGPTHuggingfaceCheckpointFormat,
Qwen2GPTHuggingfaceCheckpointFormat,
MistralGPTHuggingfaceCheckpointFormat,
MixtralGPTHuggingfaceCheckpointFormat,
)
112 changes: 99 additions & 13 deletions fast_llm/models/gpt/conversion.py
@@ -1,10 +1,11 @@
import abc
import dataclasses
import logging
import typing

import torch

from fast_llm.config import DEFAULT
from fast_llm.config import DEFAULT, MISSING
from fast_llm.engine.checkpoint.config import CheckpointFormat
from fast_llm.engine.checkpoint.external import (
AutoStateDictCheckpointHandler,
@@ -23,11 +24,12 @@
from fast_llm.functional.config import ActivationType
from fast_llm.functional.rotary import convert_rotary_complex_to_real, convert_rotary_real_to_complex
from fast_llm.layers.common.config import NormalizationType
from fast_llm.layers.transformer.config import RotaryEmbeddingType, RoutingType
from fast_llm.layers.transformer.config import RotaryEmbeddingType, RoutingType, TransformerConfig
from fast_llm.models.gpt.config import (
GPTArchitectureConfig,
GPTModelConfig,
LlamaGPTHuggingfaceCheckpointFormat,
Qwen2GPTHuggingfaceCheckpointFormat,
MistralGPTHuggingfaceCheckpointFormat,
MixtralGPTHuggingfaceCheckpointFormat,
Starcoder2GPTHuggingfaceCheckpointFormat,
@@ -39,6 +41,8 @@
if typing.TYPE_CHECKING:
pass

logger = logging.getLogger(__name__)


class QueryWeightConverter(WeightConverter):
# Hf uses the real format for rotary embeddings.
@@ -156,11 +160,14 @@ def _create_config_converters(cls) -> list[ParamConverter]:
def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]:
pass

def _create_weight_converters(self) -> list[WeightConverter]:

def _create_weight_converters(
self,
) -> list[WeightConverter]:
converters = []
num_layers = self._model.config.base_model.transformer.num_layers
norm_bias: bool = self._model.config.base_model.transformer.normalization.type == NormalizationType.layer_norm
linear_bias: bool = self._model.config.base_model.transformer.add_linear_biases
transformer_config: TransformerConfig = self._model.config.base_model.transformer

# Embedding and output
if self._model.config.base_model.tie_word_embeddings:
@@ -180,17 +187,19 @@ def _create_weight_converters(self) -> list[WeightConverter]:
converters += self._get_weight_and_bias_converters(
f"layers.{i+1}.self_attn.query",
f"model.layers.{i}.self_attn.q_proj",
linear_bias,
transformer_config.add_attn_qkv_bias,
QueryWeightConverter,
)
converters += self._get_weight_and_bias_converters(
f"layers.{i+1}.self_attn.key_value",
(f"model.layers.{i}.self_attn.k_proj", f"model.layers.{i}.self_attn.v_proj"),
linear_bias,
transformer_config.add_attn_qkv_bias,
KeyValueWeightConverter,
)
converters += self._get_weight_and_bias_converters(
f"layers.{i+1}.self_attn.dense", f"model.layers.{i}.self_attn.o_proj", linear_bias
f"layers.{i+1}.self_attn.dense",
f"model.layers.{i}.self_attn.o_proj",
transformer_config.add_attn_dense_bias,
)

# Norm
@@ -256,13 +265,16 @@ def _create_config_converters(cls) -> list[ParamConverter]:
]

def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]:
linear_bias: bool = self._model.config.base_model.transformer.add_linear_biases
transformer_config: TransformerConfig = self._model.config.base_model.transformer
return [
*self._get_weight_and_bias_converters(
f"{fast_llm_prefix}.mlp.layer_1", f"{hf_prefix}.mlp.c_fc", linear_bias
f"{fast_llm_prefix}.mlp.layer_1", f"{hf_prefix}.mlp.c_fc", transformer_config.add_mlp_bias
),
*self._get_weight_and_bias_converters(
f"{fast_llm_prefix}.mlp.layer_2", f"{hf_prefix}.mlp.c_proj", linear_bias, MLPLayer2Converter
f"{fast_llm_prefix}.mlp.layer_2",
f"{hf_prefix}.mlp.c_proj",
transformer_config.add_mlp_bias,
MLPLayer2Converter,
),
]

@@ -352,18 +364,91 @@ def _create_config_converters(cls) -> list[ParamConverter]:
]

def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]:
linear_bias: bool = self._model.config.base_model.transformer.add_linear_biases
transformer_config: TransformerConfig = self._model.config.base_model.transformer
return [
*self._get_weight_and_bias_converters(
f"{fast_llm_prefix}.mlp.layer_1",
(f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"),
transformer_config.add_mlp_bias,
SplitWeightConverter,
),
*self._get_weight_and_bias_converters(
f"{fast_llm_prefix}.mlp.layer_2",
f"{hf_prefix}.mlp.down_proj",
transformer_config.add_mlp_bias,
MLPLayer2Converter,
),
]


@dataclasses.dataclass
class IgnoreImportQwen2SlidingWindowParamsConverter(ParamConverter):
Collaborator comment:
@bigximik this is fine, but can you please add a todo here that says that this is a temporary hack until we can load these params from the config?

def __post_init__(self):
Assert.eq(len(self.fast_llm_names), 0)
Assert.eq(len(self.export_names), 0)
self.export_names = (("use_sliding_window",), ("sliding_window",), ("max_window_layers",))

def export_params(self, fast_llm_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]:
return (MISSING, MISSING, MISSING)

def import_params(self, export_values: tuple[typing.Any, ...]) -> tuple[typing.Any, ...]:
# Default value for use_sliding_window in Qwen2 HF config is False
if export_values[0] != MISSING and export_values[0] == True:
logger.warning(
f"The configuration parameters `{self.export_names[0]}={export_values[0]}`,"
f" `{self.export_names[1]}={export_values[1]}`, `{self.export_names[2]}={export_values[2]}`"
f" are ignored during conversion."
f" If you intend to use them in Fast-LLM, make sure to set them explicitly in the model configuration."
)
return ()
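
For reference, a minimal sketch of how this converter behaves at conversion time (illustrative only, not part of the diff; the three example values are hypothetical):

# Import direction (HF -> Fast-LLM): the sliding-window fields are dropped,
# with the warning above when use_sliding_window is set.
converter = IgnoreImportQwen2SlidingWindowParamsConverter()
converter.import_params((True, 4096, 28))  # logs the warning, returns ()
# Export direction (Fast-LLM -> HF): the fields are reported as MISSING.
converter.export_params(())  # returns (MISSING, MISSING, MISSING)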


class Qwen2HuggingfaceCheckpointHandler(CommonHuggingfaceCheckpointHandler):
format: typing.ClassVar[type[CheckpointFormat]] = Qwen2GPTHuggingfaceCheckpointFormat

@classmethod
def _create_config_converters(cls) -> list[ParamConverter]:
return super()._create_config_converters() + [
ConstantExportParamConverter(export_names=(("architectures",),), export_value=["Qwen2ForCausalLM"]),
ConstantImportParamConverter(
fast_llm_names=(("transformer", "normalization", "type"),), fast_llm_value=NormalizationType.rms_norm
),
RenameParamConverter(
fast_llm_names=(("transformer", "normalization", "epsilon"),), export_names=(("rms_norm_eps",),)
),
ConstantImportParamConverter(fast_llm_names=(("transformer", "gated"),), fast_llm_value=True),
ConstantImportParamConverter(
fast_llm_names=(("transformer", "add_linear_biases"),), fast_llm_value="only_attn_qkv"
),
RopeScalingParamConverter(
fast_llm_names=(
("transformer", "rotary", "type"),
("transformer", "rotary", "scale_factor"),
("transformer", "rotary", "low_frequency_factor"),
("transformer", "rotary", "high_frequency_factor"),
("transformer", "rotary", "original_context_length"),
("transformer", "rotary", "attention_factor"),
("transformer", "rotary", "beta_fast"),
("transformer", "rotary", "beta_slow"),
),
export_names=(("rope_scaling",),),
),
IgnoreImportQwen2SlidingWindowParamsConverter(),
]

def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[WeightConverter]:
transformer_config: TransformerConfig = self._model.config.base_model.transformer
return [
*self._get_weight_and_bias_converters(
f"{fast_llm_prefix}.mlp.layer_1",
(f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"),
linear_bias,
transformer_config.add_mlp_bias,
SplitWeightConverter,
),
*self._get_weight_and_bias_converters(
f"{fast_llm_prefix}.mlp.layer_2",
f"{hf_prefix}.mlp.down_proj",
linear_bias,
transformer_config.add_mlp_bias,
MLPLayer2Converter,
),
]
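
To make the Qwen2 mapping above concrete, here is a hedged sketch of how a few Hugging Face config entries line up with Fast-LLM fields under these converters (values are illustrative, not taken from a real checkpoint; only keys that appear in the converters above are shown):

# Illustrative only: example HF-side values and the Fast-LLM fields they feed.
hf_config_excerpt = {
    "architectures": ["Qwen2ForCausalLM"],  # enforced by ConstantExportParamConverter on export
    "rms_norm_eps": 1e-6,  # -> transformer.normalization.epsilon (RMSNorm is forced on import)
    "rope_scaling": None,  # -> transformer.rotary.* via RopeScalingParamConverter
    "use_sliding_window": False,  # ignored on import (see IgnoreImportQwen2SlidingWindowParamsConverter)
}
# add_linear_biases is pinned to "only_attn_qkv" on import, so q/k/v projections keep
# their biases while the attention dense and MLP projections are converted without biases.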
@@ -439,6 +524,7 @@ class AutoGPTHuggingfaceCheckpointHandler(
handler_map = {
Starcoder2GPTHuggingfaceCheckpointFormat.name: Starcoder2HuggingfaceCheckpointHandler,
LlamaGPTHuggingfaceCheckpointFormat.name: LlamaHuggingfaceCheckpointHandler,
Qwen2GPTHuggingfaceCheckpointFormat.name: Qwen2HuggingfaceCheckpointHandler,
MistralGPTHuggingfaceCheckpointFormat.name: MistralHuggingfaceCheckpointHandler,
MixtralGPTHuggingfaceCheckpointFormat.name: MixtralHuggingfaceCheckpointHandler,
}
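
With the "qwen2" entry above in place, name-based dispatch resolves to the new handler; a minimal sketch, assuming only the handler_map shown here:

# Illustrative only: look up the handler class registered for the Qwen2 format name.
handler_cls = AutoGPTHuggingfaceCheckpointHandler.handler_map[Qwen2GPTHuggingfaceCheckpointFormat.name]
assert handler_cls is Qwen2HuggingfaceCheckpointHandler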
20 changes: 20 additions & 0 deletions tests/common.py
@@ -14,6 +14,7 @@
from fast_llm.data.dataset.gpt.sampled import GPTSample
from fast_llm.models.gpt.config import (
LlamaGPTHuggingfaceCheckpointFormat,
Qwen2GPTHuggingfaceCheckpointFormat,
MistralGPTHuggingfaceCheckpointFormat,
MixtralGPTHuggingfaceCheckpointFormat,
Starcoder2GPTHuggingfaceCheckpointFormat,
@@ -155,6 +156,18 @@
]
CONFIG_LLAMA3_COMMON = CONFIG_LLAMA3_FAST_LLM + ["model.distributed.training_dtype=bf16"]

# Megatron does not support per-sub-layer biases
CONFIG_QWEN2_MEGATRON = None
CONFIG_QWEN2_FAST_LLM = CONFIG_SC2_FAST_LLM + [
"model.base_model.transformer.gated=True",
"model.base_model.transformer.activation_type=silu",
"model.base_model.transformer.add_linear_biases=only_attn_qkv",
"model.base_model.transformer.normalization.type=rms_norm",
"model.base_model.transformer.ffn_hidden_size=1024",
"model.base_model.tie_word_embeddings=False",
]
CONFIG_QWEN2_COMMON = CONFIG_QWEN2_FAST_LLM + ["model.distributed.training_dtype=bf16"]

# Yarn-style Rotary Embeddings
CONFIG_LLAMA_YARN_MEGATRON = None
CONFIG_LLAMA_YARN_FAST_LLM = CONFIG_LLAMA_FAST_LLM + [
@@ -202,6 +215,13 @@
CONFIG_LLAMA3_COMMON,
LlamaGPTHuggingfaceCheckpointFormat,
),
"qwen2": (
"gpt",
CONFIG_QWEN2_FAST_LLM,
CONFIG_QWEN2_MEGATRON,
CONFIG_QWEN2_COMMON,
Qwen2GPTHuggingfaceCheckpointFormat,
),
"llama-yarn": (
"gpt",
CONFIG_LLAMA_YARN_FAST_LLM,