diff --git a/paddleformers/cli/hparams/data_args.py b/paddleformers/cli/hparams/data_args.py
index f49621a5156..b1e1372a94e 100644
--- a/paddleformers/cli/hparams/data_args.py
+++ b/paddleformers/cli/hparams/data_args.py
@@ -90,7 +90,7 @@ class DataArguments:
         metadata={"help": "Number of candidate responses."},
     )
     mask_out_eos_token: bool = field(
-        default=True,
+        default=False,
         metadata={"help": "Mask out eos token"},
     )
     random_shuffle: bool = field(
diff --git a/paddleformers/cli/train/dpo/workflow.py b/paddleformers/cli/train/dpo/workflow.py
index ce42a9e6de6..8eeeef3be24 100644
--- a/paddleformers/cli/train/dpo/workflow.py
+++ b/paddleformers/cli/train/dpo/workflow.py
@@ -263,14 +263,12 @@ def run_dpo(
         "mix_strategy": data_args.mix_strategy,
         "encode_one_turn": data_args.encode_one_turn,
         "stage": model_args.stage,
-        "is_valid": False,
         "template_backend": data_args.template_backend,
     }
 
    dataset_config.update(
        {
            "template": data_args.template,
-            "train_on_prompt": False,
            "tool_format": None,
            "default_system": None,
            "enable_thinking": True,
@@ -325,11 +323,11 @@ def run_dpo(
         train_dataset = None
 
     if training_args.do_eval and training_args.should_load_dataset:
-        dataset_config["is_valid"] = True
         eval_dataset = create_dataset(
             task_group=data_args.eval_dataset_path,
             task_group_prob=data_args.eval_dataset_prob,
             sub_dataset_type=data_args.eval_dataset_type,
+            is_valid=True,
             **dataset_config,
         )
     else:
diff --git a/paddleformers/cli/train/sft/workflow.py b/paddleformers/cli/train/sft/workflow.py
index 94ad5b48321..862756d292a 100644
--- a/paddleformers/cli/train/sft/workflow.py
+++ b/paddleformers/cli/train/sft/workflow.py
@@ -336,7 +336,6 @@ def neft_post_hook(module, input, output):
         "is_pretraining": True if model_args.stage.lower() == "pt" else False,
         "truncate_packing": data_args.truncate_packing,
         "stage": model_args.stage,
-        "is_valid": False,
         "template_backend": data_args.template_backend,
         "split_multi_turn": data_args.split_multi_turn,
     }
@@ -344,7 +343,6 @@ def neft_post_hook(module, input, output):
     dataset_config.update(
         {
             "template": data_args.template,
-            "train_on_prompt": False,
             "tool_format": None,
             "default_system": None,
             "enable_thinking": True,
@@ -373,11 +371,11 @@ def neft_post_hook(module, input, output):
             sub_dataset_type=data_args.train_dataset_type,
             **dataset_config,
         )
-        dataset_config["is_valid"] = True
         eval_dataset = create_dataset_sft(
             task_group=data_args.eval_dataset_path,
             task_group_prob=data_args.eval_dataset_prob,
             sub_dataset_type=data_args.eval_dataset_type,
+            is_valid=True,
             **dataset_config,
         )
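
A minimal sketch (not part of this patch; `create_dataset` stands in for the real factory) of why the explicit keyword argument is safer than the old in-place flag:

    # Mutating a shared config dict leaks state across call sites.
    dataset_config = {"max_seq_len": 4096}

    # Old pattern: the mutation survives after eval-dataset creation, so any
    # later create_dataset(**dataset_config) silently inherits is_valid=True.
    dataset_config["is_valid"] = True
    # eval_dataset = create_dataset(**dataset_config)

    # New pattern: the override is scoped to this single call.
    # eval_dataset = create_dataset(is_valid=True, **dataset_config)
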
diff --git a/paddleformers/datasets/DPODataset.py b/paddleformers/datasets/DPODataset.py
index 94f63fe8bd2..2ad2076431f 100644
--- a/paddleformers/datasets/DPODataset.py
+++ b/paddleformers/datasets/DPODataset.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
 from copy import deepcopy
 from dataclasses import dataclass
 from typing import List, Optional
@@ -19,7 +20,7 @@
 import numpy as np
 from paddle.io import IterableDataset
 
-from paddleformers.datasets.data_utils import postprocess_fc_sequence
+from paddleformers.datasets.data_utils import postprocess_fc_sequence, print_debug_info
 from paddleformers.datasets.reader.mix_datasets import create_dataset_instance
 from paddleformers.datasets.reader.multi_source_datasets import MultiSourceDataset
 from paddleformers.transformers.tokenizer_utils import PretrainedTokenizer
@@ -240,9 +241,10 @@ def __postprocess_before_concat(self, example):
         response_token_ids_list = []
         response_label_ids_list = []
         response_len_list = []
+        split_index = example["session_start_index"] // 2
         for responses in [
-            chosen_encoded_messages[example["session_start_index"] // 2 :],
-            rejected_encoded_messages[example["session_start_index"] // 2 :],
+            chosen_encoded_messages[split_index:],
+            rejected_encoded_messages[split_index:],
         ]:
             responses_token_ids = []
             responses_label_ids = []
@@ -274,9 +276,9 @@ def __postprocess_before_concat(self, example):
         cur_len += sum(map(len, response_token_ids_list))
 
         # create at least one turn
-        turn_index = len(chosen_encoded_messages) - 1
+        turn_index = split_index
         while turn_index >= 0:
-            if turn_index == len(chosen_encoded_messages) - 1:
+            if turn_index == split_index:
                 cur_turn_token = chosen_encoded_messages[turn_index][0]
             else:
                 cur_turn_token = chosen_encoded_messages[turn_index][0] + chosen_encoded_messages[turn_index][1]
@@ -289,7 +291,7 @@ def __postprocess_before_concat(self, example):
             turn_index -= 1
 
         # at least one turn
-        if turn_index == len(chosen_encoded_messages) - 1:
+        if turn_index == split_index:
             sub_src = example["chosen"]["messages"][0]["content"].strip()[:5]
             global LOGGER_COUNT
             LOGGER_COUNT += 1
@@ -364,6 +366,30 @@ def _postprocess_sequence(self, example):
                 prompt_len : (prompt_len + chosen_len),
             ] = False
             attn_mask_startend_row_indices = None
+
+        # optional dataset debug output
+        enable_dataset_debug = os.getenv("FLAGS_enable_dataset_debug", "false").lower() in ("true", "1", "t")
+        if enable_dataset_debug:
+            logger.info("\n" + "=" * 50)
+            logger.info("[dataset debug] Debug mode enabled")
+            if hasattr(self, "tokenizer"):
+                print("========================================")
+                print_debug_info(self.tokenizer, input_ids, "input")
+                print("========================================\n")
+
+                filtered_labels = [x for x in chosen_labels if x != 0]  # drop ignored positions (0)
+                print("========================================")
+                print_debug_info(self.tokenizer, filtered_labels, "chosen_labels")
+                print("========================================\n")
+
+                filtered_labels = [x for x in rejected_labels if x != 0]  # drop ignored positions (0)
+                print("========================================")
+                print_debug_info(self.tokenizer, filtered_labels, "rejected_labels")
+                print("========================================\n")
+            else:
+                logger.info("[dataset debug] Tokenizer not available")
+            logger.info("=" * 50 + "\n")
+
+        # 2. return sequence
         return Sequence(
             token_ids=input_ids,
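
Both datasets gate this verbose path behind the same environment flag; a minimal usage sketch (flag name taken from the code above, everything else illustrative):

    import os

    # Enable verbose dataset debugging before the dataset is built; accepted
    # truthy spellings mirror the check above: "true", "1", "t".
    os.environ["FLAGS_enable_dataset_debug"] = "1"

    # dataset = create_dataset(...)  # subsequent _postprocess_sequence calls
    #                                # will decode and print inputs/labels
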
diff --git a/paddleformers/datasets/SFTDataset.py b/paddleformers/datasets/SFTDataset.py
index 7dcfc79e7f4..39053a56810 100644
--- a/paddleformers/datasets/SFTDataset.py
+++ b/paddleformers/datasets/SFTDataset.py
@@ -34,7 +34,6 @@ class Sequence:
     token_ids: List[int]
     position_ids: List[int]
     labels: List[int]
-    loss_mask: List[int]
     num_examples: int
     images: List[str]
     videos: List[str]
@@ -51,6 +50,7 @@ def __init__(self, **dataset_config):
         self.template = dataset_config.get("template_instance", None)
         self.template_backend = dataset_config.get("template_backend", "jinja")
         self.use_template = dataset_config.get("use_template", True)
+        self.efficient_eos = True if not self.template else getattr(self.template, "efficient_eos", False)
         self.split_multi_turn = dataset_config.get("split_multi_turn", False)
         self.encode_one_turn = dataset_config.get("encode_one_turn", True)
         self.is_pretraining = dataset_config.get("is_pretraining", False)
@@ -142,13 +142,11 @@ def __iter_func(self):
                     res_tokens = cut_tokens[:-1]
                     res_labels = cut_tokens[1:]
-                    loss_mask = [1] * len(res_tokens)
                     pos_ids = list(range(len(res_tokens)))
                     sequence = Sequence(
                         token_ids=res_tokens,
                         position_ids=pos_ids,
                         labels=res_labels,
-                        loss_mask=loss_mask,
                         num_examples=actual_example_num,
                         images=[],
                         videos=[],
@@ -172,13 +170,11 @@ def __iter_func(self):
                         cut_tokens = cut_tokens + [self.tokenizer.eos_token_id]
                     res_tokens = cut_tokens[:-1]
                     res_labels = cut_tokens[1:]
-                    loss_mask = [1] * len(res_tokens)
                     pos_ids = list(range(len(res_tokens)))
                     sequence = Sequence(
                         token_ids=res_tokens,
                         position_ids=pos_ids,
                         labels=res_labels,
-                        loss_mask=loss_mask,
                         num_examples=actual_example_num,
                         images=[],
                         videos=[],
@@ -313,13 +309,11 @@ def _postprocess_pretraining_sequence(self, example, actual_example_num):
         tokens = self._encode_pretraining_example(example, actual_example_num)
         res_tokens = tokens[:-1]
         res_labels = tokens[1:]
-        loss_mask = [1] * len(res_tokens)
         pos_ids = list(range(len(res_tokens)))
         sequence = Sequence(
             token_ids=res_tokens,
             position_ids=pos_ids,
             labels=res_labels,
-            loss_mask=loss_mask,
             num_examples=actual_example_num,
             images=[],
             videos=[],
@@ -377,7 +371,7 @@ def _postprocess_sequence(self, example, actual_example_num):
         turn_index = len(encoded_pairs) - 1
 
         tokens = []
-        loss_mask = []
+        labels = []
         while turn_index >= 0:
             tokens_src, tokens_target = encoded_pairs[turn_index]
             if len(tokens_target) == 0:
@@ -394,12 +388,16 @@ def _postprocess_sequence(self, example, actual_example_num):
                 reverse_len = self.max_seq_len + 1 - cur_len - num_reserved_tokens_for_each_turn - len(tokens_src)
                 tokens_target = tokens_target[:reverse_len]
 
+            if self.use_template and self.efficient_eos and turn_index != 0:
+                labels_src = [self.tokenizer.eos_token_id] + [-100] * (len(tokens_src) - 1)
+            else:
+                labels_src = [-100] * len(tokens_src)
+
+            labels_target = tokens_target
             tokens = tokens_src + tokens_target + tokens
+            labels = labels_src + labels_target + labels
 
-            loss_mask = (
-                [0] * (len(tokens_src) - 1) + [example["label"][turn_index]] * (len(tokens_target) + 1) + loss_mask
-            )
-            assert len(tokens) == len(loss_mask), f"{len(tokens)}-{len(loss_mask)}"
+            assert len(tokens) == len(labels), f"{len(tokens)}-{len(labels)}"
 
             cur_len = len(tokens)
 
@@ -424,15 +422,18 @@ def _postprocess_sequence(self, example, actual_example_num):
                 # Maybe left truncated, so need to add begin_token
                 if tokens[0] != self.begin_token_id:
                     tokens = [self.begin_token_id] + tokens
-                    loss_mask = [0] + loss_mask
+                    labels = [-100] + labels
                if len(tokens) > self.max_seq_len:
                    raise RuntimeError(f"token_ids is too long: {len(tokens)}")

                # Add EOS token at the end
                del tokens[-1]
-                del loss_mask[-1]
-                labels = tokens[1:] + [self.tokenizer.eos_token_id]
+                del labels[-1]
+                if self.efficient_eos:
+                    tokens = tokens + [self.tokenizer.eos_token_id]
+                    labels = labels + [self.tokenizer.eos_token_id]
+                labels = labels[1:] + [-100]

                # end_of_response is a special token that indicates the end of the turn.
                # end_token is a special token that indicates the end of the answer.
                labels = [
@@ -440,25 +441,23 @@ def _postprocess_sequence(self, example, actual_example_num):
                    label if label != self.end_of_response_id else self.tokenizer.eos_token_id for label in labels
                ]
            else:
-                tokens = tokens[:-1] + [self.tokenizer.eos_token_id]
-                labels = tokens[1:] + [-100]
+                if self.efficient_eos:
+                    tokens = tokens + [self.tokenizer.eos_token_id]
+                    labels = labels + [self.tokenizer.eos_token_id]
+                labels = labels[1:] + [-100]
                if len(tokens) > self.max_seq_len:
                    raise RuntimeError(f"token_ids is too long: {len(tokens)}")
        else:
-            oral_tokens = tokens
-            tokens = oral_tokens[:-1]
-            labels = oral_tokens[1:]
-            loss_mask = loss_mask[:-1]
+            labels = tokens[1:] + [-100]
            if len(tokens) > self.max_seq_len:
                raise RuntimeError(f"token_ids is too long: {len(tokens)}")

        pos_ids = list(range(len(tokens)))
-        if sum(loss_mask) == 0:
-            logger.warning(f"[SKIP] all labels set to 0: {example}")
+        if all(x == -100 for x in labels):
+            logger.warning(f"[SKIP] all labels ignored (-100): {example}")
            return None

-        assert len(tokens) == len(loss_mask), f"{len(tokens)}-{len(loss_mask)}"
        assert len(tokens) == len(labels), f"{len(tokens)}-{len(labels)}"

        enable_dataset_debug = os.getenv("FLAGS_enable_dataset_debug", "false").lower() in ("true", "1", "t")
@@ -470,12 +469,10 @@ def _postprocess_sequence(self, example, actual_example_num):
                print_debug_info(self.tokenizer, tokens, "input")
                print("========================================\n")

-                filtered_labels = [label if mask == 1 else -100 for label, mask in zip(labels, loss_mask)]
-                filtered_labels = [x for x in filtered_labels if x != -100]  # remove -100
+                filtered_labels = [x for x in labels if x != -100]  # remove -100
                print("========================================")
                print_debug_info(self.tokenizer, filtered_labels, "labels")
                print("========================================\n")
-                logger.info(f"[dataset debug] loss mask: {loss_mask}")
            else:
                logger.info("[dataset debug] Tokenizer not available")
            logger.info("=" * 50 + "\n")
@@ -484,7 +481,6 @@ def _postprocess_sequence(self, example, actual_example_num):
            token_ids=tokens,
            position_ids=pos_ids,
            labels=labels,
-            loss_mask=loss_mask,
            num_examples=actual_example_num,
            images=images,
            videos=videos,
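
A self-contained toy reconstruction (fake token ids) of the label layout built above when `efficient_eos` is set: prompt positions carry -100, and for every non-final turn the first source token supervises the EOS of the preceding reply:

    # Toy reconstruction of the labeling scheme; token ids are fake.
    EOS, IGN = 2, -100
    turns = [([11, 12, 13], [21, 22]),   # (prompt tokens, response tokens) turn 0
             ([14, 15], [23, 24, 25])]   # turn 1

    tokens, labels = [], []
    for turn_index in reversed(range(len(turns))):
        tokens_src, tokens_target = turns[turn_index]
        if turn_index != 0:  # efficient_eos: supervise previous turn's EOS
            labels_src = [EOS] + [IGN] * (len(tokens_src) - 1)
        else:
            labels_src = [IGN] * len(tokens_src)
        tokens = tokens_src + tokens_target + tokens
        labels = labels_src + tokens_target + labels

    assert len(tokens) == len(labels)
    print(labels)  # [-100, -100, -100, 21, 22, 2, -100, 23, 24, 25]
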
diff --git a/paddleformers/datasets/collate.py b/paddleformers/datasets/collate.py
index f414a287421..a0f72a1dfb5 100644
--- a/paddleformers/datasets/collate.py
+++ b/paddleformers/datasets/collate.py
@@ -165,7 +165,6 @@ def collate_fn(batch: List[List[Sequence]], tokenizer, training_args, model_args
         dict: Dictionary containing:
             - input_ids: Padded token IDs
             - labels: Shifted labels for prediction
-            - loss_mask: Mask for computing loss
     """
     input_keys = ["input_ids", "labels", "position_ids"]
     if training_args.num_nextn_predict_layers > 0:
@@ -180,15 +179,12 @@ def collate_fn(batch: List[List[Sequence]], tokenizer, training_args, model_args
     for batch_sequence in batch:
         original_token_ids = [seq.token_ids for seq in batch_sequence]
         token_ids = [sum(original_token_ids, [])]
-        loss_mask = [sum([seq.loss_mask for seq in batch_sequence], [])]
         labels = [sum([seq.labels for seq in batch_sequence], [])]
         position_ids = [sum([seq.position_ids for seq in batch_sequence], [])]
 
         # padding
         padded_token_ids = pad_batch_data(token_ids, pad_idx=tokenizer.pad_token_id, max_seq_len=max_seq_len)
-        padded_labels = pad_batch_data(labels, pad_idx=tokenizer.pad_token_id, max_seq_len=max_seq_len)
-        padded_loss_mask = pad_batch_data(loss_mask, pad_idx=0, max_seq_len=max_seq_len)
+        padded_labels = pad_batch_data(labels, pad_idx=-100, max_seq_len=max_seq_len)
         padded_position_ids = pad_batch_data(position_ids, pad_idx=0, max_seq_len=max_seq_len)
-        padded_labels = np.where(padded_loss_mask == 1, padded_labels, -100)
         return_list.append(
             [
                 padded_token_ids,
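
A small numpy sketch (toy values) of the simplification: padding labels with -100 directly subsumes the old loss-mask pass and the `np.where` rewrite:

    import numpy as np

    IGN = -100
    labels = [[5, 6, IGN, 7]]  # one packed sequence, already -100 where ignored
    max_seq_len = 6

    # pad_idx=-100: padded positions are ignored by the loss for free,
    # replacing the old loss_mask + np.where(padded_loss_mask == 1, ...) step.
    padded = np.array([row + [IGN] * (max_seq_len - len(row)) for row in labels])
    print(padded)  # [[   5    6 -100    7 -100 -100]]
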
diff --git a/paddleformers/datasets/data_utils.py b/paddleformers/datasets/data_utils.py
index 309f2a0b7c0..edb54ff4b2a 100644
--- a/paddleformers/datasets/data_utils.py
+++ b/paddleformers/datasets/data_utils.py
@@ -285,7 +285,7 @@ def estimate_training(train_dataset, data_args, training_args, model_args):
         "sharding_parallel_degree": int(training_args.sharding_parallel_degree),
         "num_samples_each_epoch": data_args.num_samples_each_epoch,
         "max_seq_len": int(data_args.max_seq_len),
-        "seed": data_args.seed,
+        "seed": training_args.seed,
         "valid": False,
         "train_samples": 0,
     }
diff --git a/paddleformers/datasets/reader/file_reader.py b/paddleformers/datasets/reader/file_reader.py
index 7fb58e025c7..20f8caeb7d3 100644
--- a/paddleformers/datasets/reader/file_reader.py
+++ b/paddleformers/datasets/reader/file_reader.py
@@ -103,9 +103,25 @@ def data_check(self, data):
         if len(data["messages"]) == 0:
             raise ValueError("Ignore example with empty messages.")
 
+        for index in range(len(data["messages"])):
+            # Normalize the role names used for tool calls and tool responses
+            if data["messages"][index]["role"] == "tool" or data["messages"][index]["role"] == "tool_response":
+                data["messages"][index]["role"] = "observation"
+            if data["messages"][index]["role"] == "tool_call" or data["messages"][index]["role"] == "tool_calls":
+                data["messages"][index]["role"] = "function"
+            # Serialize tool-call and tool-response content to strings
+            if (
+                data["messages"][index]["role"] == "observation" or data["messages"][index]["role"] == "function"
+            ) and not isinstance(data["messages"][index]["content"], str):
+                data["messages"][index]["content"] = json.dumps(data["messages"][index]["content"])
+            if "tool_calls" in data["messages"][index] and not isinstance(data["messages"][index]["tool_calls"], str):
+                data["messages"][index]["tool_calls"] = json.dumps(data["messages"][index]["tool_calls"])
+
+        # Serialize the tool list to a string
         if "tools" in data and not isinstance(data["tools"], str):
             data["tools"] = json.dumps(data["tools"], ensure_ascii=False)
 
+        # If no label is provided, every assistant/function turn is learned.
         if "label" not in data:
             data["label"] = [
                 1 for turn in data["messages"] if ("assistant" in turn["role"] or "function" in turn["role"])
@@ -116,6 +132,7 @@ def data_check(self, data):
             system = data["messages"][0]["content"]
             if not isinstance(system, str):
                 raise ValueError("System field must be a string.")
+            data["messages"] = data["messages"][1:]
             data["system"] = system
 
         # Convert the relative paths of multimode data into absolute paths
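
An illustrative before/after for the normalization in `data_check` (the message values are made up):

    # Input turn as it might arrive from an OpenAI-style dataset:
    raw = {"role": "tool", "content": {"temperature": 21}}

    # After data_check: role mapped to the template vocabulary and the
    # structured content serialized, matching what the formatters expect.
    normalized = {"role": "observation", "content": '{"temperature": 21}'}
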
diff --git a/paddleformers/datasets/template/mm_plugin.py b/paddleformers/datasets/template/mm_plugin.py
index 934c4e177f5..2832a198220 100644
--- a/paddleformers/datasets/template/mm_plugin.py
+++ b/paddleformers/datasets/template/mm_plugin.py
@@ -563,10 +563,141 @@ def process_messages(
         return messages
 
 
+@dataclass
+class GLM4VPlugin(Qwen2VLPlugin):
+    @override
+    def _get_mm_inputs(
+        self,
+        images,
+        videos,
+        audios,
+        processor,
+    ):
+        image_processor = getattr(processor, "image_processor", None)
+        video_processor = getattr(processor, "video_processor", None)
+        mm_inputs = {}
+        if len(images) != 0:
+            images = self._regularize_images(
+                images,
+                image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
+                image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
+            )["images"]
+            mm_inputs.update(image_processor(images, return_tensors="pt"))
+
+        if len(videos) != 0:
+            video_data = self._regularize_videos(
+                videos,
+                image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
+                image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
+                video_fps=getattr(processor, "video_fps", 2.0),
+                video_maxlen=getattr(processor, "video_maxlen", 128),
+            )
+            # prepare video metadata
+            video_metadata = [
+                {"fps": 2, "duration": len(video), "total_frames": len(video)} for video in video_data["videos"]
+            ]
+            mm_inputs.update(video_processor(images=None, videos=video_data["videos"], video_metadata=video_metadata))
+
+        return mm_inputs
+
+    @override
+    def process_messages(
+        self,
+        messages,
+        images,
+        videos,
+        audios,
+        processor,
+    ):
+        self._validate_input(processor, images, videos, audios)
+        self._validate_messages(messages, images, videos, audios)
+        num_image_tokens, num_video_tokens = 0, 0
+        messages = deepcopy(messages)
+        image_processor = getattr(processor, "image_processor")
+
+        merge_length = getattr(image_processor, "merge_size") ** 2
+        if self.expand_mm_tokens:
+            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+            image_grid_thw = mm_inputs.get("image_grid_thw", [])
+            video_grid_thw = mm_inputs.get("video_grid_thw", [])
+            num_frames = video_grid_thw[0][0] if len(video_grid_thw) > 0 else 0  # hard code for now
+            timestamps = mm_inputs.get("timestamps", [])
+
+            if hasattr(timestamps, "tolist"):
+                timestamps = timestamps.tolist()
+
+            if not timestamps:
+                timestamps_list = []
+            elif isinstance(timestamps[0], list):
+                timestamps_list = timestamps[0]
+            else:
+                timestamps_list = timestamps
+
+            unique_timestamps = timestamps_list.copy()
+            selected_timestamps = unique_timestamps[:num_frames]
+            while len(selected_timestamps) < num_frames:
+                selected_timestamps.append(selected_timestamps[-1] if selected_timestamps else 0)
+
+        else:
+            image_grid_thw = [None] * len(images)
+            video_grid_thw = [None] * len(videos)
+            num_frames = 0
+            selected_timestamps = [0]
+
+        for message in messages:
+            content = message["content"]
+            while IMAGE_PLACEHOLDER in content:
+                image_seqlen = image_grid_thw[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
+                content = content.replace(
+                    IMAGE_PLACEHOLDER, f"<|begin_of_image|>{self.image_token * image_seqlen}<|end_of_image|>", 1
+                )
+                num_image_tokens += 1
+
+            while VIDEO_PLACEHOLDER in content:
+                video_structure = ""
+                for frame_index in range(num_frames):
+                    video_seqlen = (
+                        video_grid_thw[num_video_tokens][1:].prod() // merge_length if self.expand_mm_tokens else 1
+                    )
+                    timestamp_sec = selected_timestamps[frame_index]
+                    frame_structure = (
+                        f"<|begin_of_image|>{self.image_token * video_seqlen}<|end_of_image|>{timestamp_sec}"
+                    )
+                    video_structure += frame_structure
+
+                if not self.expand_mm_tokens:
+                    video_structure = self.video_token
+
+                content = content.replace(VIDEO_PLACEHOLDER, f"<|begin_of_video|>{video_structure}<|end_of_video|>", 1)
+                num_video_tokens += 1
+
+            message["content"] = content
+
+        return messages
+
+    @override
+    def get_mm_inputs(
+        self,
+        images,
+        videos,
+        audios,
+        imglens,
+        vidlens,
+        audlens,
+        batch_ids,
+        processor,
+    ):
+        self._validate_input(processor, images, videos, audios)
+        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+        mm_inputs.pop("timestamps", None)
+        return mm_inputs
+
+
 PLUGINS = {
     "base": BasePlugin,
     "qwen2_vl": Qwen2VLPlugin,
     "qwen3_vl": Qwen3VLPlugin,
+    "glm4v": GLM4VPlugin,
 }
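
A toy illustration of the placeholder expansion done in `process_messages` above (grid values and the placeholder string are illustrative; the real constants come from `mm_plugin`):

    # image_grid_thw entry (t, h, w) = (1, 4, 4), merge_size = 2 -> 16 // 4 = 4 tokens
    IMAGE_PLACEHOLDER, image_token = "<image>", "<|image|>"
    image_seqlen = (1 * 4 * 4) // (2 ** 2)

    content = f"Describe {IMAGE_PLACEHOLDER} briefly."
    content = content.replace(
        IMAGE_PLACEHOLDER, f"<|begin_of_image|>{image_token * image_seqlen}<|end_of_image|>", 1
    )
    print(content)
    # Describe <|begin_of_image|><|image|><|image|><|image|><|image|><|end_of_image|> briefly.
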
diff --git a/paddleformers/datasets/template/template.py b/paddleformers/datasets/template/template.py
index 2604ae94e58..4051cc0827a 100644
--- a/paddleformers/datasets/template/template.py
+++ b/paddleformers/datasets/template/template.py
@@ -20,7 +20,7 @@
 from copy import deepcopy
 from dataclasses import dataclass
 from enum import Enum, unique
-from typing import TYPE_CHECKING, Optional, Union
+from typing import TYPE_CHECKING, Optional
 
 from typing_extensions import override
 
@@ -34,7 +34,6 @@
     from .formatter import SLOTS, Formatter
     from .mm_plugin import BasePlugin
-    from .tool_utils import FunctionCall
 
 
 @unique
@@ -60,7 +59,6 @@ class Template:
     thought_words: tuple[str, str]
     efficient_eos: bool
     replace_eos: bool
-    replace_jinja_template: bool
     enable_thinking: Optional[bool]
     mm_plugin: "BasePlugin"
 
@@ -91,18 +89,6 @@ def encode_multiturn(
         encoded_messages = self._encode(tokenizer, messages, system, tools)
         return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]
 
-    def extract_tool(self, content: str) -> Union[str, list["FunctionCall"]]:
-        r"""Extract tool message."""
-        return self.format_tools.extract(content)
-
-    def get_stop_token_ids(self, tokenizer: "PreTrainedTokenizer") -> list[int]:
-        r"""Return stop token ids."""
-        stop_token_ids = {tokenizer.eos_token_id}
-        for token in self.stop_words:
-            stop_token_ids.add(tokenizer.convert_tokens_to_ids(token))
-
-        return list(stop_token_ids)
-
     def add_thought(self, content: str = "") -> str:
         r"""Add empty thought to assistant message."""
         return f"{self.thought_words[0]}{self.thought_words[1]}" + content
@@ -162,6 +148,10 @@ def _encode(
                 elements += self.format_user.apply(content=message["content"], idx=str(i // 2))
             elif message["role"] == Role.ASSISTANT:
                 elements += self.format_assistant.apply(content=message["content"])
+                if "tool_calls" in message:
+                    elements += self.format_function.apply(
+                        content=message["tool_calls"], thought_words=self.thought_words
+                    )
             elif message["role"] == Role.OBSERVATION:
                 elements += self.format_observation.apply(content=message["content"])
             elif message["role"] == Role.FUNCTION:
@@ -215,70 +205,6 @@ def fix_special_tokens(self, tokenizer: "PreTrainedTokenizer") -> None:
         if num_added_tokens > 0:
             logger.warning("New tokens have been added, make sure `resize_vocab` is True.")
 
-    @staticmethod
-    def _jinja_escape(content: str) -> str:
-        r"""Escape single quotes in content."""
-        return content.replace("'", r"\'")
-
-    @staticmethod
-    def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content") -> str:
-        r"""Convert slots to jinja template."""
-        slot_items = []
-        for slot in slots:
-            if isinstance(slot, str):
-                slot_pieces = slot.split("{{content}}")
-                if slot_pieces[0]:
-                    slot_items.append("'" + Template._jinja_escape(slot_pieces[0]) + "'")
-                if len(slot_pieces) > 1:
-                    slot_items.append(placeholder)
-                    if slot_pieces[1]:
-                        slot_items.append("'" + Template._jinja_escape(slot_pieces[1]) + "'")
-            elif isinstance(slot, set):  # do not use {{ eos_token }} since it may be replaced
-                if "bos_token" in slot and tokenizer.bos_token_id is not None:
-                    slot_items.append("'" + tokenizer.bos_token + "'")
-                elif "eos_token" in slot and tokenizer.eos_token_id is not None:
-                    slot_items.append("'" + tokenizer.eos_token + "'")
-            elif isinstance(slot, dict):
-                raise ValueError("Dict is not supported.")
-
-        return " + ".join(slot_items)
-
-    def _get_jinja_template(self, tokenizer: "PreTrainedTokenizer") -> str:
-        r"""Return the jinja template."""
-        prefix = self._convert_slots_to_jinja(self.format_prefix.apply(), tokenizer)
-        system = self._convert_slots_to_jinja(self.format_system.apply(), tokenizer, placeholder="system_message")
-        user = self._convert_slots_to_jinja(self.format_user.apply(), tokenizer)
-        assistant = self._convert_slots_to_jinja(self.format_assistant.apply(), tokenizer)
-        jinja_template = ""
-        if prefix:
-            jinja_template += "{{ " + prefix + " }}"
-
-        if self.default_system:
-            jinja_template += "{% set system_message = '" + self._jinja_escape(self.default_system) + "' %}"
-
-        jinja_template += (
-            "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}"
-            "{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% endif %}"
-            "{% if system_message is defined %}{{ " + system + " }}{% endif %}"
-            "{% for message in loop_messages %}"
-            "{% set content = message['content'] %}"
-            "{% if message['role'] == 'user' %}"
-            "{{ " + user + " }}"
-            "{% elif message['role'] == 'assistant' %}"
-            "{{ " + assistant + " }}"
-            "{% endif %}"
-            "{% endfor %}"
-        )
-        return jinja_template
-
-    def fix_jinja_template(self, tokenizer: "PreTrainedTokenizer") -> None:
-        r"""Replace the jinja template in the tokenizer."""
-        if tokenizer.chat_template is None or self.replace_jinja_template:
-            try:
-                tokenizer.chat_template = self._get_jinja_template(tokenizer)
-            except ValueError as e:
-                logger.warning(f"Cannot add this chat template to tokenizer: {e}.")
-
 
 @dataclass
 class ReasoningTemplate(Template):
@@ -355,7 +281,6 @@ def register_template(
     thought_words: Optional[tuple[str, str]] = None,
     efficient_eos: bool = False,
     replace_eos: bool = False,
-    replace_jinja_template: bool = False,
     enable_thinking: Optional[bool] = True,
     mm_plugin: "BasePlugin" = get_mm_plugin(name="base"),
     template_class: type["Template"] = Template,
@@ -406,7 +331,6 @@ def register_template(
         thought_words=thought_words or ("<think>\n", "\n</think>\n\n"),
         efficient_eos=efficient_eos,
         replace_eos=replace_eos,
-        replace_jinja_template=replace_jinja_template,
         enable_thinking=enable_thinking,
         mm_plugin=mm_plugin,
     )
@@ -467,7 +391,6 @@ def find_diff(short_str: str, long_str: str) -> str:
     thought_words=("<think>\n", "\n</think>\n\n"),
     efficient_eos=False,
     replace_eos=False,
-    replace_jinja_template=False,
     enable_thinking=True,
     mm_plugin=get_mm_plugin(name="base"),
 )
 
@@ -489,22 +412,16 @@ def get_template_and_fix_tokenizer(dataset_config) -> "Template":
 
     template = TEMPLATES[dataset_config["template"]]
 
-    if dataset_config["train_on_prompt"] and template.efficient_eos:
-        raise ValueError("Current template does not support `train_on_prompt`.")
-
     if dataset_config["tool_format"] is not None:
-        # logger.info_rank0(f"Using tool format: {dataset_config['tool_format']}.")
         default_slots = ["{{content}}"] if template.efficient_eos else ["{{content}}", {"eos_token"}]
         template.format_function = FunctionFormatter(slots=default_slots, tool_format=dataset_config["tool_format"])
         template.format_tools = ToolFormatter(tool_format=dataset_config["tool_format"])
 
     if dataset_config["default_system"] is not None:
-        # logger.info_rank0(f"Using default system message: {dataset_config['default_system']}.")
         template.default_system = dataset_config["default_system"]
 
     template.enable_thinking = dataset_config["enable_thinking"]
     template.fix_special_tokens(tokenizer)
-    template.fix_jinja_template(tokenizer)
 
     return template
 
 
@@ -513,7 +430,6 @@ def get_template_and_fix_tokenizer(dataset_config) -> "Template":
     format_user=StringFormatter(slots=["Human: {{content}}", {"eos_token"}, "\nAssistant:"]),
     format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n"]),
     format_system=StringFormatter(slots=["System: {{content}}", {"eos_token"}, "\n"]),
-    replace_jinja_template=True,
 )
 
 register_template(
@@ -545,16 +461,16 @@ def get_template_and_fix_tokenizer(dataset_config) -> "Template":
 register_template(
     name="qwen",
     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
-    format_assistant=StringFormatter(slots=["<think>\n\n</think>\n\n{{content}}<|im_end|>\n"]),
+    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
     format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
     format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
     format_observation=StringFormatter(
         slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
     ),
     format_tools=ToolFormatter(tool_format="qwen"),
-    default_system="",
+    default_system="You are Qwen, created by Alibaba Cloud. You are a helpful assistant.",
     stop_words=["<|im_end|>"],
-    replace_eos=False,
+    replace_eos=True,
 )
 
 
@@ -562,7 +478,7 @@ def get_template_and_fix_tokenizer(dataset_config) -> "Template":
 register_template(
     name="qwen3",
     format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
-    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>"]),
     format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
     format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
     format_observation=StringFormatter(
         slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
@@ -625,3 +541,72 @@ def get_template_and_fix_tokenizer(dataset_config) -> "Template":
     replace_eos=True,
     mm_plugin=get_mm_plugin(name="qwen3_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
 )
+
+register_template(
+    name="glm4",
+    format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
+    format_assistant=StringFormatter(slots=["\n{{content}}"]),
+    format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
+    format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4"),
+    format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
+    format_tools=ToolFormatter(tool_format="glm4"),
+    format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
+    stop_words=["<|user|>", "<|observation|>"],
+    efficient_eos=True,
+)
+
+
+# copied from glm4 template
+register_template(
+    name="glm4_moe",
+    format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
+    format_assistant=StringFormatter(slots=["\n{{content}}"]),
+    format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
+    format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4_moe"),
+    format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
+    format_tools=ToolFormatter(tool_format="glm4_moe"),
+    format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
+    stop_words=["<|user|>", "<|observation|>"],
+    efficient_eos=True,
+    template_class=ReasoningTemplate,
+)
+
+
+# copied from glm4 template
+register_template(
+    name="glm4v",
+    format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
+    format_assistant=StringFormatter(slots=["\n{{content}}"]),
+    format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
+    format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4"),
+    format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
+    format_tools=ToolFormatter(tool_format="glm4"),
+    format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
+    stop_words=["<|user|>", "<|observation|>", "</answer>"],
+    efficient_eos=True,
+    mm_plugin=get_mm_plugin(name="glm4v", image_token="<|image|>", video_token="<|video|>"),
+    template_class=ReasoningTemplate,
+)
+
+
+# copied from glm4 template
+register_template(
+    name="glm4v_moe",
+    format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
+    format_assistant=StringFormatter(slots=["\n{{content}}"]),
+    format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
+    format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4_moe"),
+    format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
+    format_tools=ToolFormatter(tool_format="glm4_moe"),
+    format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
+    stop_words=["<|user|>", "<|observation|>", "</answer>"],
+    efficient_eos=True,
+    mm_plugin=get_mm_plugin(name="glm4v", image_token="<|image|>", video_token="<|video|>"),
+    template_class=ReasoningTemplate,
+)
+
+register_template(
+    name="deepseek3",
+    format_user=StringFormatter(slots=["<|User|>{{content}}<|Assistant|>"]),
+    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+)
diff --git a/paddleformers/datasets/template/tool_utils.py b/paddleformers/datasets/template/tool_utils.py
index 5a04da57b05..5de63ce82dc 100644
--- a/paddleformers/datasets/template/tool_utils.py
+++ b/paddleformers/datasets/template/tool_utils.py
@@ -18,7 +18,6 @@
 
 import json
-import re
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from typing import Any, NamedTuple, Union
@@ -50,6 +49,23 @@ class FunctionCall(NamedTuple):
 )
 
 
+GLM4_TOOL_PROMPT = (
+    "你是一个名为 ChatGLM 的人工智能助手。你是基于智谱 AI 公司训练的语言模型 GLM-4 模型开发的,"
+    "你的任务是针对用户的问题和要求提供适当的答复和支持。\n\n# 可用工具{tool_text}"
+)
+
+GLM4_MOE_TOOL_PROMPT = (
+    "\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
+    "You are provided with function signatures within <tools></tools> XML tags:\n<tools>{tool_text}"
+    "\n</tools>\n\nFor each function call, output the function name and arguments within the following XML format:"
+    "\n<tool_call>{{function-name}}"
+    "\n<arg_key>{{arg-key-1}}</arg_key>"
+    "\n<arg_value>{{arg-value-1}}</arg_value>"
+    "\n<arg_key>{{arg-key-2}}</arg_key>"
+    "\n<arg_value>{{arg-value-2}}</arg_value>"
+    "\n...\n</tool_call>\n"
+)
+
+
 @dataclass
 class ToolUtils(ABC):
     """Base class for tool utilities."""
@@ -66,15 +82,6 @@ def function_formatter(functions: list["FunctionCall"]) -> str:
         r"""Generate the assistant message including all the tool calls."""
         ...
 
-    @staticmethod
-    @abstractmethod
-    def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
-        r"""Extract all the function calls from the assistant message.
-
-        It should be an inverse function of `function_formatter`.
-        """
-        ...
-
 
 class DefaultToolUtils(ToolUtils):
     r"""Default tool using template."""
@@ -119,26 +126,6 @@ def tool_formatter(tools: list[dict[str, Any]]) -> str:
     def function_formatter(functions: list["FunctionCall"]) -> str:
         return "\n".join([f"Action: {name}\nAction Input: {arguments}" for name, arguments in functions])
 
-    @override
-    @staticmethod
-    def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
-        regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+)\s*Action Input:\s*(.+?)(?=\s*Action:|\s*$)", re.DOTALL)
-        action_match: list[tuple[str, str]] = re.findall(regex, content)
-        if not action_match:
-            return content
-
-        results = []
-        for match in action_match:
-            tool_name = match[0].strip()
-            tool_input = match[1].strip().strip('"').strip("```")
-            try:
-                arguments = json.loads(tool_input)
-                results.append(FunctionCall(tool_name, json.dumps(arguments, ensure_ascii=False)))
-            except json.JSONDecodeError:
-                return content
-
-        return results
-
 
 class QwenToolUtils(ToolUtils):
     r"""Qwen 2.5 tool using template."""
@@ -162,32 +149,82 @@ def function_formatter(functions: list["FunctionCall"]) -> str:
         ]
         return "\n".join([f"<tool_call>\n{text}\n</tool_call>" for text in function_texts])
 
+
+class GLM4ToolUtils(ToolUtils):
+    r"""GLM-4 tool using template."""
+
+    @override
+    @staticmethod
+    def tool_formatter(tools: list[dict[str, Any]]) -> str:
+        tool_text = ""
+        for tool in tools:
+            tool = tool.get("function", "") if tool.get("type") == "function" else tool
+            tool_text += "\n\n## {name}\n\n{body}\n在调用上述函数时,请使用 Json 格式表示调用的参数。".format(
+                name=tool["name"], body=json.dumps(tool, indent=4, ensure_ascii=False)
+            )
+
+        return GLM4_TOOL_PROMPT.format(tool_text=tool_text)
+
+    @override
+    @staticmethod
+    def function_formatter(functions: list["FunctionCall"]) -> str:
+        if len(functions) > 1:
+            raise ValueError("GLM-4 does not support parallel functions.")
+
+        return f"{functions[0].name}\n{functions[0].arguments}"
+
     @override
     @staticmethod
     def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
-        regex = re.compile(r"<tool_call>(.+?)</tool_call>(?=\s*<tool_call>|\s*$)", re.DOTALL)
-        tool_match: list[str] = re.findall(regex, content)
-        if not tool_match:
+        if "\n" not in content:
             return content
 
-        results = []
-        for tool in tool_match:
-            try:
-                tool = json.loads(tool.strip())
-            except json.JSONDecodeError:
-                return content
+        tool_name, tool_input = content.split("\n", maxsplit=1)
+        try:
+            arguments = json.loads(tool_input.strip())
+        except json.JSONDecodeError:
+            return content
+
+        return [FunctionCall(tool_name, json.dumps(arguments, ensure_ascii=False))]
+
 
-            if "name" not in tool or "arguments" not in tool:
-                return content
+class GLM4MOEToolUtils(QwenToolUtils):
+    r"""GLM-4-MOE tool using template."""
 
-            results.append(FunctionCall(tool["name"], json.dumps(tool["arguments"], ensure_ascii=False)))
+    @override
+    @staticmethod
+    def tool_formatter(tools: list[dict[str, Any]]) -> str:
+        tool_text = ""
+        for tool in tools:
+            wrapped_tool = tool if tool.get("type") == "function" else {"type": "function", "function": tool}
+            tool_text += "\n" + json.dumps(wrapped_tool, ensure_ascii=False)
+
+        return GLM4_MOE_TOOL_PROMPT.format(tool_text=tool_text)
+
+    @override
+    @staticmethod
+    def function_formatter(functions: list["FunctionCall"]) -> str:
+        function_json = [
+            {"func_name": name, "func_key_values": json.loads(arguments)} for name, arguments in functions
+        ]
+        function_texts = []
+        for func in function_json:
+            prompt = "\n<tool_call>" + func["func_name"]
+            for key, value in func["func_key_values"].items():
+                prompt += "\n<arg_key>" + key + "</arg_key>"
+                if not isinstance(value, str):
+                    value = json.dumps(value, ensure_ascii=False)
+                prompt += "\n<arg_value>" + value + "</arg_value>"
+            function_texts.append(prompt)
 
-        return results
+        return "\n".join(function_texts)
 
 
 TOOLS = {
     "default": DefaultToolUtils(),
     "qwen": QwenToolUtils(),
+    "glm4": GLM4ToolUtils(),
+    "glm4_moe": GLM4MOEToolUtils(),
 }
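
A round-trip sketch of the GLM-4 call format implemented above (toy stand-ins, not the library API): `function_formatter` emits `name` + newline + JSON arguments, and `tool_extractor` inverts it:

    import json

    # Toy stand-ins mirroring the GLM-4 utilities above.
    def glm4_format(name, arguments):          # like function_formatter, one call
        return f"{name}\n{arguments}"

    def glm4_extract(content):                 # like tool_extractor
        if "\n" not in content:
            return content
        tool_name, tool_input = content.split("\n", maxsplit=1)
        try:
            arguments = json.loads(tool_input.strip())
        except json.JSONDecodeError:
            return content
        return [(tool_name, json.dumps(arguments, ensure_ascii=False))]

    msg = glm4_format("get_weather", '{"city": "Beijing"}')
    print(glm4_extract(msg))  # [('get_weather', '{"city": "Beijing"}')]
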
diff --git a/scripts/regression/test_sft_tiny-random-glm4moe.py b/scripts/regression/test_sft_tiny-random-glm4moe.py
index 9a3025c25d6..a5fba6f6f36 100644
--- a/scripts/regression/test_sft_tiny-random-glm4moe.py
+++ b/scripts/regression/test_sft_tiny-random-glm4moe.py
@@ -31,12 +31,12 @@
 MAX_STEPS = 6
 SAVE_STEPS = 4
 
-SFT_FULL_EXCEPTED_LOSS = 13.091749
-SFT_FULL_RESUME_EXCEPTED_LOSS = 13.080153
+SFT_FULL_EXCEPTED_LOSS = 13.043166
+SFT_FULL_RESUME_EXCEPTED_LOSS = 13.035461
 SFT_FULL_EXCEPTED_RESULT = [[51172, 37927, 96130, 27654, 133362, 95331, 27654, 133362, 115845, 115845]]
 
-SFT_LORA_EXCEPTED_LOSS = 13.092138
-SFT_LORA_RESUME_EXCEPTED_LOSS = 13.081409
+SFT_LORA_EXCEPTED_LOSS = 13.04369
+SFT_LORA_RESUME_EXCEPTED_LOSS = 13.036311
 SFT_LORA_EXCEPTED_RESULT = [[51172, 37927, 96130, 27654, 133362, 95331, 27654, 133362, 115845, 115845]]
 
 SFT_FULL_TP_PP_EXCEPTED_LOSS = 11.92912
@@ -47,8 +47,8 @@
 SFT_LORA_TP_PP_RESUME_EXCEPTED_LOSS = 11.929088
 SFT_LORA_TP_PP_EXCEPTED_RESULT = [[51172, 37927, 96130, 27654, 133362, 95331, 27654, 133362, 115845, 115845]]
 
-SFT_FC_EXCEPTED_LOSS = 12.862782
-SFT_FC_RESUME_EXCEPTED_LOSS = 12.867558
+SFT_FC_EXCEPTED_LOSS = 12.859675
+SFT_FC_RESUME_EXCEPTED_LOSS = 12.863781
 SFT_FC_EXCEPTED_RESULT = [[51172, 37927, 96130, 27654, 133362, 95331, 27654, 133362, 115845, 115845]]
 
 os.environ["NVIDIA_TF32_OVERRIDE"] = "0"