200 changes: 112 additions & 88 deletions paddleformers/transformers/configuration_utils.py
@@ -24,7 +24,6 @@
import re
import shutil
import sys
import warnings
from dataclasses import field
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
@@ -583,11 +582,67 @@ def __setitem__(self, key, value):
if hasattr(self, key):
setattr(self, key, value)

def __init__(self, **kwargs):
# Attributes with defaults
# map the old attr to new atr, eg: num_classes -> num_labels
def __init__(
self,
*,
output_hidden_states: bool = False,
output_attentions: bool = False,
return_dict: bool = True,
dtype: Optional[str] = None,
tie_word_embeddings: bool = True,
chunk_size_feed_forward: int = 0,
is_encoder_decoder: bool = False,
is_decoder: bool = False,
cross_attention_hidden_size: Optional[int] = None,
add_cross_attention: bool = False,
tie_encoder_decoder: bool = False,
# Fine-tuning task arguments
architectures: Optional[list[str]] = None,
finetuning_task: Optional[str] = None,
id2label: Optional[dict[int, str]] = None,
label2id: Optional[dict[str, int]] = None,
num_labels: Optional[int] = None,
task_specific_params: Optional[dict[str, Any]] = None,
problem_type: Optional[str] = None,
# Tokenizer kwargs
tokenizer_class: Optional[str] = None,
prefix: Optional[str] = None,
bos_token_id: Optional[int] = None,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
sep_token_id: Optional[int] = None,
decoder_start_token_id: Optional[int] = None,
**kwargs,
):

# Label Map
if label2id is not None and not isinstance(label2id, dict):
raise ValueError("Argument label2id should be a dictionary.")
if id2label is not None and not isinstance(id2label, dict):
raise ValueError("Argument id2label should be a dictionary.")
if num_labels is not None and id2label is not None and len(id2label) != num_labels:
logger.warning(
f"You passed `num_labels={num_labels}` which is incompatible to "
f"the `id2label` map of length `{len(id2label)}`."
)

# problem_type
if problem_type is not None and problem_type not in (
"regression",
"single_label_classification",
"multi_label_classification",
):
raise ValueError(
f"The config parameter `problem_type` was not understood: received {problem_type} "
"but only 'regression', 'single_label_classification' and 'multi_label_classification' are valid."
)

if (torch_dtype := kwargs.pop("torch_dtype", None)) is not None:
dtype = dtype if dtype is not None else torch_dtype

kwargs = attribute_map(self, kwargs=kwargs)
kwargs.pop("transformers_version", None)

llm_meta = LlmMetaConfig._get_defaults()
self._unsavable_keys.update(LlmMetaConfig._get_unsavable_keys())
self._unsavable_keys.remove("tensor_parallel_degree")
@@ -596,116 +651,86 @@ def __init__(self, **kwargs):
self._unsavable_keys.add("_attn_implementation")

kwargs = set_expected_keys(self, llm_meta, kwargs)
if self.sequence_parallel:

if hasattr(self, "sequence_parallel") and self.sequence_parallel:
assert (
self.tensor_parallel_degree > 1
), f"senquence-parallel only works in tensor parallel, got tensor parallel degree={self.tensor_parallel_degree}"

self.chunk_size_feed_forward = kwargs.pop("chunk_size_feed_forward", 0)
self.return_dict = kwargs.pop("return_dict", False)
self.output_hidden_states = kwargs.pop("output_hidden_states", False)
self.output_attentions = kwargs.pop("output_attentions", False)
self.dtype = kwargs.pop("dtype", None)
self.use_cache = kwargs.pop("use_cache", False)
self.tie_word_embeddings = kwargs.pop("tie_word_embeddings", True)

# for transformers fuse
getattr(self, "tensor_parallel_degree", 1) > 1
), f"sequence-parallel only works in tensor parallel, got tensor parallel degree={getattr(self, 'tensor_parallel_degree', 1)}"

self.return_dict = return_dict
self.output_hidden_states = output_hidden_states
self.output_attentions = output_attentions
self.dtype = dtype
self.tie_word_embeddings = tie_word_embeddings
self.chunk_size_feed_forward = chunk_size_feed_forward
self.is_encoder_decoder = is_encoder_decoder
self.is_decoder = is_decoder
self.cross_attention_hidden_size = cross_attention_hidden_size
self.add_cross_attention = add_cross_attention
self.tie_encoder_decoder = tie_encoder_decoder

# Fine-tuning attributes
self.architectures = architectures
self.finetuning_task = finetuning_task
self.id2label = id2label
self.label2id = label2id
if self.id2label is None:
self.num_labels = num_labels if num_labels is not None else 2
else:
self.id2label = {int(key): value for key, value in self.id2label.items()}
self.num_labels = len(self.id2label)
self.task_specific_params = task_specific_params
self.problem_type = problem_type

# Tokenizer attributes
self.tokenizer_class = tokenizer_class
self.prefix = prefix
self.bos_token_id = bos_token_id
self.pad_token_id = pad_token_id
self.eos_token_id = eos_token_id
self.sep_token_id = sep_token_id
self.decoder_start_token_id = decoder_start_token_id

# For transformers fuse
self.fuse_linear = kwargs.pop("fuse_linear", False)
self.fuse_attention_qkv = kwargs.pop("fuse_attention_qkv", False)
self.fuse_attention_ffn = kwargs.pop("fuse_attention_ffn", False)

# for general components
# For general components
self._attn_implementation = kwargs.pop("_attn_implementation", "eager")

if "quantization_config" in kwargs and isinstance(kwargs["quantization_config"], Dict):
# Quantization Config
if "quantization_config" in kwargs and isinstance(kwargs["quantization_config"], dict):
kwargs["quantization_config"] = QuantizationConfig.from_dict(kwargs["quantization_config"])
self.quantization_config = kwargs.pop("quantization_config", QuantizationConfig())

self.pruned_heads = kwargs.pop("pruned_heads", {})

# parameter for model dtype
if "torch_dtype" in kwargs:
self.dtype = kwargs.pop("torch_dtype")

# Is decoder is used in encoder-decoder models to differentiate encoder from decoder
self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", False)
self.is_decoder = kwargs.pop("is_decoder", False)
self.cross_attention_hidden_size = kwargs.pop("cross_attention_hidden_size", None)
self.add_cross_attention = kwargs.pop("add_cross_attention", False)
self.tie_encoder_decoder = kwargs.pop("tie_encoder_decoder", False)

# Retrocompatibility: Parameters for sequence generation. While we will keep the ability to load these
# parameters, saving them will be deprecated. In a distant future, we won't need to load them.
for parameter_name, default_value in self._get_generation_defaults().items():
setattr(self, parameter_name, kwargs.pop(parameter_name, default_value))

# Fine-tuning task arguments
self.architectures = kwargs.pop("architectures", None)
self.finetuning_task = kwargs.pop("finetuning_task", None)
self.id2label = kwargs.pop("id2label", None)
self.label2id = kwargs.pop("label2id", None)
if self.id2label is not None:
num_labels = kwargs.pop("num_labels", None)
if num_labels is not None and len(self.id2label) != num_labels:
logger.warning(
f"You passed along `num_labels={num_labels}` with an incompatible id to label map: "
f"{self.id2label}. The number of labels will be overwritten to {self.num_labels}."
)
self.id2label = dict((int(key), value) for key, value in self.id2label.items())
# Keys are always strings in JSON so convert ids to int here.
else:
self.num_labels = kwargs.pop("num_labels", 2)
self.num_choices = kwargs.pop("num_choices", None)

self.classifier_dropout = kwargs.pop("classifier_dropout", None)

self.dpo_config = kwargs.pop("dpo_config", None)
self.kto_config = kwargs.pop("kto_config", None)

# MoE specific
self.moe_subbatch_token_num = kwargs.pop("moe_subbatch_token_num", 0)
self.ep_communication_type = kwargs.pop("ep_communication_type", "deepep")
self.use_unified_moe = kwargs.pop("use_unified_moe", False)
self.using_fake_gate = kwargs.pop("using_fake_gate", False)

# Tokenizer arguments TODO: eventually tokenizer and models should share the same config
self.tokenizer_class = kwargs.pop("tokenizer_class", None)
self.prefix = kwargs.pop("prefix", None)
self.bos_token_id = kwargs.pop("bos_token_id", None)
self.pad_token_id = kwargs.pop("pad_token_id", None)
self.eos_token_id = kwargs.pop("eos_token_id", None)
self.sep_token_id = kwargs.pop("sep_token_id", None)

self.decoder_start_token_id = kwargs.pop("decoder_start_token_id", None)

# task specific arguments
self.task_specific_params = kwargs.pop("task_specific_params", None)

# regression / multi-label classification
self.problem_type = kwargs.pop("problem_type", None)
allowed_problem_types = ("regression", "single_label_classification", "multi_label_classification")
if self.problem_type is not None and self.problem_type not in allowed_problem_types:
raise ValueError(
f"The config parameter `problem_type` was not understood: received {self.problem_type} "
"but only 'regression', 'single_label_classification' and 'multi_label_classification' are valid."
)

# Name or path to the pretrained checkpoint
# Name or path
self._name_or_path = str(kwargs.pop("name_or_path", ""))

# Drop the transformers version info
self._commit_hash = kwargs.pop("_commit_hash", None)
self._save_to_hf = kwargs.pop("save_to_hf", False)
self._unsavable_keys.add("_save_to_hf")
self.paddleformers_version = kwargs.pop("paddleformers_version", None)

# Deal with gradient checkpointing
for parameter_name in self._get_generation_defaults().keys():
kwargs.pop(parameter_name, None)

if kwargs.get("gradient_checkpointing", False):
warnings.warn(
"Passing `gradient_checkpointing` to a config initialization is deprecated and will be removed in v5 "
"Transformers. Using `model.gradient_checkpointing_enable()` instead, or if you are using the "
"`Trainer` API, pass `gradient_checkpointing=True` in your `TrainingArguments`."
)
self._save_to_hf = kwargs.pop("save_to_hf", False)
self._unsavable_keys.add("_save_to_hf")
logger.warning("Passing `gradient_checkpointing` to a config initialization is deprecated...")
kwargs.pop("gradient_checkpointing", None)

# Additional attributes without default values
for key, value in kwargs.items():
try:
setattr(self, key, value)
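
For readers scanning the hunk above, the sketch below re-implements just the new label-map and `problem_type` handling in isolation, outside the class, purely for illustration. The helper names here are invented for this note and are not part of the PR.

# Illustrative sketch only -- mirrors the label-map / problem_type logic added to __init__ above.
from typing import Dict, Optional, Tuple

_ALLOWED_PROBLEM_TYPES = ("regression", "single_label_classification", "multi_label_classification")

def resolve_labels(id2label: Optional[Dict[int, str]], num_labels: Optional[int]) -> Tuple[Optional[Dict[int, str]], int]:
    """id2label wins over num_labels; JSON keys arrive as strings and are cast back to int."""
    if id2label is None:
        return None, num_labels if num_labels is not None else 2
    id2label = {int(key): value for key, value in id2label.items()}
    return id2label, len(id2label)

def check_problem_type(problem_type: Optional[str]) -> None:
    if problem_type is not None and problem_type not in _ALLOWED_PROBLEM_TYPES:
        raise ValueError(f"The config parameter `problem_type` was not understood: received {problem_type}")

# Example: three entries in id2label override num_labels=2.
id2label, num_labels = resolve_labels({"0": "neg", "1": "neu", "2": "pos"}, num_labels=2)
assert num_labels == 3 and id2label[0] == "neg"
check_problem_type("single_label_classification")
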
@@ -878,7 +903,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike],
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
)

return cls.from_dict(config_dict, **kwargs)

@classmethod
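
As a usage sketch of the refactored keyword-only constructor: the subclass below is hypothetical, and the import path and class name are assumed from this file's location (paddleformers/transformers/configuration_utils.py); none of it is code from this PR.

# Hypothetical subclass exercising the new explicit keyword arguments (assumed import path).
from paddleformers.transformers.configuration_utils import PretrainedConfig

class ToyConfig(PretrainedConfig):
    model_type = "toy"

    def __init__(self, hidden_size: int = 768, **kwargs):
        self.hidden_size = hidden_size
        super().__init__(**kwargs)

config = ToyConfig(
    hidden_size=1024,
    dtype="bfloat16",                         # `torch_dtype=...` is also accepted and folded into `dtype`
    tie_word_embeddings=False,
    id2label={0: "negative", 1: "positive"},
    label2id={"negative": 0, "positive": 1},
    problem_type="single_label_classification",
)
assert config.num_labels == 2                 # derived from id2label by the new __init__
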