diff --git a/src/lmflow/args.py b/src/lmflow/args.py index dbe6fbb3..085d090b 100644 --- a/src/lmflow/args.py +++ b/src/lmflow/args.py @@ -238,8 +238,8 @@ class ModelArguments: metadata={ "help": "Merging ratio between the fine-tuned model and the original. This is controlled by a parameter called alpha in the paper."}, ) - lora_target_modules: List[str] = field( - default=None, metadata={"help": "Pretrained config name or path if not the same as model_name",} + lora_target_modules: str = field( + default=None, metadata={"help": "Model modules to apply LoRA to. Use comma to separate multiple modules."} ) lora_dropout: float = field( default=0.1, @@ -364,6 +364,9 @@ def __post_init__(self): if not is_flash_attn_available(): self.use_flash_attention = False logger.warning("Flash attention is not available in the current environment. Disabling flash attention.") + + if self.lora_target_modules is not None: + self.lora_target_modules: List[str] = split_args(self.lora_target_modules) @dataclass @@ -1464,3 +1467,7 @@ class AutoArguments: def get_pipeline_args_class(pipeline_name: str): return PIPELINE_ARGUMENT_MAPPING[pipeline_name] + + +def split_args(args): + return [elem.strip() for elem in args.split(",")] if isinstance(args, str) else args \ No newline at end of file diff --git a/src/lmflow/utils/constants.py b/src/lmflow/utils/constants.py index f0278985..04ad5696 100644 --- a/src/lmflow/utils/constants.py +++ b/src/lmflow/utils/constants.py @@ -386,12 +386,17 @@ DEFAULT_IM_END_TOKEN = "" # Lora -# NOTE: Be careful, when passing lora_target_modules through arg parser, the -# value should be like'--lora_target_modules q_proj, v_proj \', while specifying -# here, it should be in list format. +# NOTE: This works as a mapping for those models that the `peft` library doesn't support yet, and will be +# overwritten by peft.utils.constants.TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING +# if the model is supported (see hf_model_mixin.py). 
+# NOTE: When passing lora_target_modules through arg parser, the +# value should be a string. Use commas to separate the module names, e.g. +# "--lora_target_modules 'q_proj, v_proj'". +# However, when specifying here, they should be lists. LMFLOW_LORA_TARGET_MODULES_MAPPING = { 'qwen2': ["q_proj", "v_proj"], 'internlm2': ["wqkv"], + 'hymba': ["x_proj.0", "in_proj", "out_proj", "dt_proj.0"] } # vllm inference