diff --git a/paddleformers/transformers/configuration_utils.py b/paddleformers/transformers/configuration_utils.py
index 4a658e9af97..c827c23497c 100644
--- a/paddleformers/transformers/configuration_utils.py
+++ b/paddleformers/transformers/configuration_utils.py
@@ -24,7 +24,6 @@
 import re
 import shutil
 import sys
-import warnings
 from dataclasses import field
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Tuple, Union
@@ -583,11 +582,67 @@ def __setitem__(self, key, value):
         if hasattr(self, key):
             setattr(self, key, value)
 
-    def __init__(self, **kwargs):
-        # Attributes with defaults
-        # map the old attr to new atr, eg: num_classes -> num_labels
+    def __init__(
+        self,
+        *,
+        output_hidden_states: bool = False,
+        output_attentions: bool = False,
+        return_dict: bool = True,
+        dtype: Optional[str] = None,
+        tie_word_embeddings: bool = True,
+        chunk_size_feed_forward: int = 0,
+        is_encoder_decoder: bool = False,
+        is_decoder: bool = False,
+        cross_attention_hidden_size: Optional[int] = None,
+        add_cross_attention: bool = False,
+        tie_encoder_decoder: bool = False,
+        # Fine-tuning task arguments
+        architectures: Optional[list[str]] = None,
+        finetuning_task: Optional[str] = None,
+        id2label: Optional[dict[int, str]] = None,
+        label2id: Optional[dict[str, int]] = None,
+        num_labels: Optional[int] = None,
+        task_specific_params: Optional[dict[str, Any]] = None,
+        problem_type: Optional[str] = None,
+        # Tokenizer kwargs
+        tokenizer_class: Optional[str] = None,
+        prefix: Optional[str] = None,
+        bos_token_id: Optional[int] = None,
+        pad_token_id: Optional[int] = None,
+        eos_token_id: Optional[int] = None,
+        sep_token_id: Optional[int] = None,
+        decoder_start_token_id: Optional[int] = None,
+        **kwargs,
+    ):
+
+        # Label Map
+        if label2id is not None and not isinstance(label2id, dict):
+            raise ValueError("Argument label2id should be a dictionary.")
+        if id2label is not None and not isinstance(id2label, dict):
+            raise ValueError("Argument id2label should be a dictionary.")
+        if num_labels is not None and id2label is not None and len(id2label) != num_labels:
+            logger.warning(
+                f"You passed `num_labels={num_labels}` which is incompatible to "
+                f"the `id2label` map of length `{len(id2label)}`."
+            )
+
+        # problem_type
+        if problem_type is not None and problem_type not in (
+            "regression",
+            "single_label_classification",
+            "multi_label_classification",
+        ):
+            raise ValueError(
+                f"The config parameter `problem_type` was not understood: received {problem_type} "
+                "but only 'regression', 'single_label_classification' and 'multi_label_classification' are valid."
+            )
+
+        if (torch_dtype := kwargs.pop("torch_dtype", None)) is not None:
+            dtype = dtype if dtype is not None else torch_dtype
+
         kwargs = attribute_map(self, kwargs=kwargs)
         kwargs.pop("transformers_version", None)
+
         llm_meta = LlmMetaConfig._get_defaults()
         self._unsavable_keys.update(LlmMetaConfig._get_unsavable_keys())
         self._unsavable_keys.remove("tensor_parallel_degree")
@@ -596,116 +651,86 @@ def __init__(self, **kwargs):
         self._unsavable_keys.add("_attn_implementation")
         kwargs = set_expected_keys(self, llm_meta, kwargs)
-        if self.sequence_parallel:
+
+        if hasattr(self, "sequence_parallel") and self.sequence_parallel:
             assert (
-                self.tensor_parallel_degree > 1
-            ), f"senquence-parallel only works in tensor parallel, got tensor parallel degree={self.tensor_parallel_degree}"
-
-        self.chunk_size_feed_forward = kwargs.pop("chunk_size_feed_forward", 0)
-        self.return_dict = kwargs.pop("return_dict", False)
-        self.output_hidden_states = kwargs.pop("output_hidden_states", False)
-        self.output_attentions = kwargs.pop("output_attentions", False)
-        self.dtype = kwargs.pop("dtype", None)
-        self.use_cache = kwargs.pop("use_cache", False)
-        self.tie_word_embeddings = kwargs.pop("tie_word_embeddings", True)
-
-        # for transformers fuse
+                getattr(self, "tensor_parallel_degree", 1) > 1
+            ), f"sequence-parallel only works in tensor parallel, got tensor parallel degree={getattr(self, 'tensor_parallel_degree', 1)}"
+
+        self.return_dict = return_dict
+        self.output_hidden_states = output_hidden_states
+        self.output_attentions = output_attentions
+        self.dtype = dtype
+        self.tie_word_embeddings = tie_word_embeddings
+        self.chunk_size_feed_forward = chunk_size_feed_forward
+        self.is_encoder_decoder = is_encoder_decoder
+        self.is_decoder = is_decoder
+        self.cross_attention_hidden_size = cross_attention_hidden_size
+        self.add_cross_attention = add_cross_attention
+        self.tie_encoder_decoder = tie_encoder_decoder
+
+        # Fine-tuning attributes
+        self.architectures = architectures
+        self.finetuning_task = finetuning_task
+        self.id2label = id2label
+        self.label2id = label2id
+        if self.id2label is None:
+            self.num_labels = num_labels if num_labels is not None else 2
+        else:
+            self.id2label = {int(key): value for key, value in self.id2label.items()}
+            self.num_labels = len(self.id2label)
+        self.task_specific_params = task_specific_params
+        self.problem_type = problem_type
+
+        # Tokenizer attributes
+        self.tokenizer_class = tokenizer_class
+        self.prefix = prefix
+        self.bos_token_id = bos_token_id
+        self.pad_token_id = pad_token_id
+        self.eos_token_id = eos_token_id
+        self.sep_token_id = sep_token_id
+        self.decoder_start_token_id = decoder_start_token_id
+
+        # For transformers fuse
         self.fuse_linear = kwargs.pop("fuse_linear", False)
         self.fuse_attention_qkv = kwargs.pop("fuse_attention_qkv", False)
         self.fuse_attention_ffn = kwargs.pop("fuse_attention_ffn", False)
 
-        # for general components
+        # For general components
         self._attn_implementation = kwargs.pop("_attn_implementation", "eager")
 
-        if "quantization_config" in kwargs and isinstance(kwargs["quantization_config"], Dict):
+        # Quantization Config
+        if "quantization_config" in kwargs and isinstance(kwargs["quantization_config"], dict):
             kwargs["quantization_config"] = QuantizationConfig.from_dict(kwargs["quantization_config"])
         self.quantization_config = kwargs.pop("quantization_config", QuantizationConfig())
 
         self.pruned_heads = kwargs.pop("pruned_heads", {})
 
-        # parameter for model dtype
-        if "torch_dtype" in kwargs:
-            self.dtype = kwargs.pop("torch_dtype")
-
-        # Is decoder is used in encoder-decoder models to differentiate encoder from decoder
-        self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", False)
-        self.is_decoder = kwargs.pop("is_decoder", False)
-        self.cross_attention_hidden_size = kwargs.pop("cross_attention_hidden_size", None)
-        self.add_cross_attention = kwargs.pop("add_cross_attention", False)
-        self.tie_encoder_decoder = kwargs.pop("tie_encoder_decoder", False)
-
-        # Retrocompatibility: Parameters for sequence generation. While we will keep the ability to load these
-        # parameters, saving them will be deprecated. In a distant future, we won't need to load them.
-        for parameter_name, default_value in self._get_generation_defaults().items():
-            setattr(self, parameter_name, kwargs.pop(parameter_name, default_value))
-
-        # Fine-tuning task arguments
-        self.architectures = kwargs.pop("architectures", None)
-        self.finetuning_task = kwargs.pop("finetuning_task", None)
-        self.id2label = kwargs.pop("id2label", None)
-        self.label2id = kwargs.pop("label2id", None)
-        if self.id2label is not None:
-            num_labels = kwargs.pop("num_labels", None)
-            if num_labels is not None and len(self.id2label) != num_labels:
-                logger.warning(
-                    f"You passed along `num_labels={num_labels}` with an incompatible id to label map: "
-                    f"{self.id2label}. The number of labels will be overwritten to {self.num_labels}."
-                )
-            self.id2label = dict((int(key), value) for key, value in self.id2label.items())
-            # Keys are always strings in JSON so convert ids to int here.
-        else:
-            self.num_labels = kwargs.pop("num_labels", 2)
         self.num_choices = kwargs.pop("num_choices", None)
-
         self.classifier_dropout = kwargs.pop("classifier_dropout", None)
-
         self.dpo_config = kwargs.pop("dpo_config", None)
         self.kto_config = kwargs.pop("kto_config", None)
+        # MoE specific
         self.moe_subbatch_token_num = kwargs.pop("moe_subbatch_token_num", 0)
         self.ep_communication_type = kwargs.pop("ep_communication_type", "deepep")
         self.use_unified_moe = kwargs.pop("use_unified_moe", False)
         self.using_fake_gate = kwargs.pop("using_fake_gate", False)
 
-        # Tokenizer arguments TODO: eventually tokenizer and models should share the same config
-        self.tokenizer_class = kwargs.pop("tokenizer_class", None)
-        self.prefix = kwargs.pop("prefix", None)
-        self.bos_token_id = kwargs.pop("bos_token_id", None)
-        self.pad_token_id = kwargs.pop("pad_token_id", None)
-        self.eos_token_id = kwargs.pop("eos_token_id", None)
-        self.sep_token_id = kwargs.pop("sep_token_id", None)
-
-        self.decoder_start_token_id = kwargs.pop("decoder_start_token_id", None)
-
-        # task specific arguments
-        self.task_specific_params = kwargs.pop("task_specific_params", None)
-
-        # regression / multi-label classification
-        self.problem_type = kwargs.pop("problem_type", None)
-        allowed_problem_types = ("regression", "single_label_classification", "multi_label_classification")
-        if self.problem_type is not None and self.problem_type not in allowed_problem_types:
-            raise ValueError(
-                f"The config parameter `problem_type` was not understood: received {self.problem_type} "
-                "but only 'regression', 'single_label_classification' and 'multi_label_classification' are valid."
-            )
-
-        # Name or path to the pretrained checkpoint
+        # Name or path
         self._name_or_path = str(kwargs.pop("name_or_path", ""))
-
-        # Drop the transformers version info
+        self._commit_hash = kwargs.pop("_commit_hash", None)
+        self._save_to_hf = kwargs.pop("save_to_hf", False)
+        self._unsavable_keys.add("_save_to_hf")
         self.paddleformers_version = kwargs.pop("paddleformers_version", None)
 
-        # Deal with gradient checkpointing
+        for parameter_name in self._get_generation_defaults().keys():
+            kwargs.pop(parameter_name, None)
+
         if kwargs.get("gradient_checkpointing", False):
-            warnings.warn(
-                "Passing `gradient_checkpointing` to a config initialization is deprecated and will be removed in v5 "
-                "Transformers. Using `model.gradient_checkpointing_enable()` instead, or if you are using the "
-                "`Trainer` API, pass `gradient_checkpointing=True` in your `TrainingArguments`."
-            )
-        self._save_to_hf = kwargs.pop("save_to_hf", False)
-        self._unsavable_keys.add("_save_to_hf")
+            logger.warning("Passing `gradient_checkpointing` to a config initialization is deprecated...")
+            kwargs.pop("gradient_checkpointing", None)
 
-        # Additional attributes without default values
         for key, value in kwargs.items():
             try:
                 setattr(self, key, value)
@@ -878,7 +903,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike],
                 f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                 f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )
-
        return cls.from_dict(config_dict, **kwargs)
 
     @classmethod
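Usage sketch for the refactored constructor. This is illustrative only: it assumes the class edited above is PretrainedConfig (the class name is not visible in these hunks) and that it can be instantiated directly. The behavior shown follows the added lines: keyword-only arguments, label-map normalization, problem_type validation, and the torch_dtype -> dtype fallback.

# Illustrative sketch, not part of the patch. Assumes the edited class is
# `PretrainedConfig` and that direct instantiation is supported.
from paddleformers.transformers.configuration_utils import PretrainedConfig

config = PretrainedConfig(
    dtype="bfloat16",
    id2label={"0": "negative", "1": "positive"},  # JSON-style string keys are coerced to int
    problem_type="single_label_classification",
)
print(config.num_labels)   # 2, derived from the length of id2label
print(config.id2label[0])  # "negative", keys converted to int

# Legacy `torch_dtype` is still accepted via **kwargs, but it only fills in
# `dtype` when `dtype` itself was not passed.
legacy = PretrainedConfig(torch_dtype="float16")
print(legacy.dtype)  # "float16"

# Unrecognized problem_type values now fail fast at construction time.
try:
    PretrainedConfig(problem_type="ranking")
except ValueError as exc:
    print(exc)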
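A second sketch, showing how a downstream config subclass interacts with the new keyword-only signature: model-specific fields stay on the subclass, while shared fields (token ids, dtype, label maps) are forwarded through **kwargs into the base __init__. MyModelConfig and its fields are hypothetical.

# Hypothetical subclass, shown only to illustrate forwarding into the new
# keyword-only base __init__; it is not part of paddleformers.
from paddleformers.transformers.configuration_utils import PretrainedConfig


class MyModelConfig(PretrainedConfig):
    model_type = "my_model"  # hypothetical model type

    def __init__(self, hidden_size: int = 768, num_hidden_layers: int = 12, **kwargs):
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        # Shared arguments must now be passed by keyword; positional calls into
        # the base __init__ are no longer possible because of the leading `*`.
        super().__init__(**kwargs)


cfg = MyModelConfig(hidden_size=1024, bos_token_id=1, eos_token_id=2, dtype="float16")
print(cfg.eos_token_id, cfg.dtype)  # 2 float16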