diff --git a/paddleformers/cli/hparams/data_args.py b/paddleformers/cli/hparams/data_args.py
index f49621a5156..b1e1372a94e 100644
--- a/paddleformers/cli/hparams/data_args.py
+++ b/paddleformers/cli/hparams/data_args.py
@@ -90,7 +90,7 @@ class DataArguments:
metadata={"help": "Number of candidate responses."},
)
mask_out_eos_token: bool = field(
- default=True,
+ default=False,
metadata={"help": "Mask out eos token"},
)
random_shuffle: bool = field(
diff --git a/paddleformers/cli/train/dpo/workflow.py b/paddleformers/cli/train/dpo/workflow.py
index ce42a9e6de6..8eeeef3be24 100644
--- a/paddleformers/cli/train/dpo/workflow.py
+++ b/paddleformers/cli/train/dpo/workflow.py
@@ -263,14 +263,12 @@ def run_dpo(
"mix_strategy": data_args.mix_strategy,
"encode_one_turn": data_args.encode_one_turn,
"stage": model_args.stage,
- "is_valid": False,
"template_backend": data_args.template_backend,
}
dataset_config.update(
{
"template": data_args.template,
- "train_on_prompt": False,
"tool_format": None,
"default_system": None,
"enable_thinking": True,
@@ -325,11 +323,11 @@ def run_dpo(
train_dataset = None
if training_args.do_eval and training_args.should_load_dataset:
- dataset_config["is_valid"] = True
eval_dataset = create_dataset(
task_group=data_args.eval_dataset_path,
task_group_prob=data_args.eval_dataset_prob,
sub_dataset_type=data_args.eval_dataset_type,
+ is_valid=True,
**dataset_config,
)
else:
diff --git a/paddleformers/cli/train/sft/workflow.py b/paddleformers/cli/train/sft/workflow.py
index 94ad5b48321..862756d292a 100644
--- a/paddleformers/cli/train/sft/workflow.py
+++ b/paddleformers/cli/train/sft/workflow.py
@@ -336,7 +336,6 @@ def neft_post_hook(module, input, output):
"is_pretraining": True if model_args.stage.lower() == "pt" else False,
"truncate_packing": data_args.truncate_packing,
"stage": model_args.stage,
- "is_valid": False,
"template_backend": data_args.template_backend,
"split_multi_turn": data_args.split_multi_turn,
}
@@ -344,7 +343,6 @@ def neft_post_hook(module, input, output):
dataset_config.update(
{
"template": data_args.template,
- "train_on_prompt": False,
"tool_format": None,
"default_system": None,
"enable_thinking": True,
@@ -373,11 +371,11 @@ def neft_post_hook(module, input, output):
sub_dataset_type=data_args.train_dataset_type,
**dataset_config,
)
- dataset_config["is_valid"] = True
eval_dataset = create_dataset_sft(
task_group=data_args.eval_dataset_path,
task_group_prob=data_args.eval_dataset_prob,
sub_dataset_type=data_args.eval_dataset_type,
+ is_valid=True,
**dataset_config,
)
diff --git a/paddleformers/datasets/DPODataset.py b/paddleformers/datasets/DPODataset.py
index 94f63fe8bd2..2ad2076431f 100644
--- a/paddleformers/datasets/DPODataset.py
+++ b/paddleformers/datasets/DPODataset.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import os
from copy import deepcopy
from dataclasses import dataclass
from typing import List, Optional
@@ -19,7 +20,7 @@
import numpy as np
from paddle.io import IterableDataset
-from paddleformers.datasets.data_utils import postprocess_fc_sequence
+from paddleformers.datasets.data_utils import postprocess_fc_sequence, print_debug_info
from paddleformers.datasets.reader.mix_datasets import create_dataset_instance
from paddleformers.datasets.reader.multi_source_datasets import MultiSourceDataset
from paddleformers.transformers.tokenizer_utils import PretrainedTokenizer
@@ -240,9 +241,10 @@ def __postprocess_before_concat(self, example):
response_token_ids_list = []
response_label_ids_list = []
response_len_list = []
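+ # encoded messages come in (prompt, response) pairs, so halve session_start_index to index the pair list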
+ split_index = example["session_start_index"] // 2
for responses in [
- chosen_encoded_messages[example["session_start_index"] // 2 :],
- rejected_encoded_messages[example["session_start_index"] // 2 :],
+ chosen_encoded_messages[split_index:],
+ rejected_encoded_messages[split_index:],
]:
responses_token_ids = []
responses_label_ids = []
@@ -274,9 +276,9 @@ def __postprocess_before_concat(self, example):
cur_len += sum(map(len, response_token_ids_list))
# create at least one turn
- turn_index = len(chosen_encoded_messages) - 1
+ turn_index = split_index
while turn_index >= 0:
- if turn_index == len(chosen_encoded_messages) - 1:
+ if turn_index == split_index:
cur_turn_token = chosen_encoded_messages[turn_index][0]
else:
cur_turn_token = chosen_encoded_messages[turn_index][0] + chosen_encoded_messages[turn_index][1]
@@ -289,7 +291,7 @@ def __postprocess_before_concat(self, example):
turn_index -= 1
# at least one turn
- if turn_index == len(chosen_encoded_messages) - 1:
+ if turn_index == split_index:
sub_src = example["chosen"]["messages"][0]["content"].strip()[:5]
global LOGGER_COUNT
LOGGER_COUNT += 1
@@ -364,6 +366,30 @@ def _postprocess_sequence(self, example):
prompt_len : (prompt_len + chosen_len),
] = False
attn_mask_startend_row_indices = None
+
+ # optionally dump the encoded DPO sequences for debugging
+ enable_dataset_debug = os.getenv("FLAGS_enable_dataset_debug", "false").lower() in ("true", "1", "t")
+ if enable_dataset_debug:
+ logger.info("\n" + "=" * 50)
+ logger.info("[dataset debug] Debug mode enabled")
+ if hasattr(self, "tokenizer"):
+ print("========================================")
+ print_debug_info(self.tokenizer, input_ids, "input")
+ print("========================================\n")
+
+ filtered_labels = [x for x in chosen_labels if x != 0] # drop ignored positions (label 0)
+ print("========================================")
+ print_debug_info(self.tokenizer, filtered_labels, "chosen_labels")
+ print("========================================\n")
+
+ filtered_labels = [x for x in rejected_labels if x != 0] # drop ignored positions (label 0)
+ print("========================================")
+ print_debug_info(self.tokenizer, filtered_labels, "rejected_labels")
+ print("========================================\n")
+ else:
+ logger.info("[dataset debug] Tokenizer not available")
+ logger.info("=" * 50 + "\n")
+
# 2. return sequence
return Sequence(
token_ids=input_ids,
diff --git a/paddleformers/datasets/SFTDataset.py b/paddleformers/datasets/SFTDataset.py
index 7dcfc79e7f4..39053a56810 100644
--- a/paddleformers/datasets/SFTDataset.py
+++ b/paddleformers/datasets/SFTDataset.py
@@ -34,7 +34,6 @@ class Sequence:
token_ids: List[int]
position_ids: List[int]
labels: List[int]
- loss_mask: List[int]
num_examples: int
images: List[str]
videos: List[str]
@@ -51,6 +50,7 @@ def __init__(self, **dataset_config):
self.template = dataset_config.get("template_instance", None)
self.template_backend = dataset_config.get("template_backend", "jinja")
self.use_template = dataset_config.get("use_template", True)
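+ # efficient_eos templates omit the trailing EOS in their slots, so the dataset appends it; without a template the dataset always appends EOS itself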
+ self.efficient_eos = getattr(self.template, "efficient_eos", False) if self.template else True
self.split_multi_turn = dataset_config.get("split_multi_turn", False)
self.encode_one_turn = dataset_config.get("encode_one_turn", True)
self.is_pretraining = dataset_config.get("is_pretraining", False)
@@ -142,13 +142,11 @@ def __iter_func(self):
res_tokens = cut_tokens[:-1]
res_labels = cut_tokens[1:]
- loss_mask = [1] * len(res_tokens)
pos_ids = list(range(len(res_tokens)))
sequence = Sequence(
token_ids=res_tokens,
position_ids=pos_ids,
labels=res_labels,
- loss_mask=loss_mask,
num_examples=actual_example_num,
images=[],
videos=[],
@@ -172,13 +170,11 @@ def __iter_func(self):
cut_tokens = cut_tokens + [self.tokenizer.eos_token_id]
res_tokens = cut_tokens[:-1]
res_labels = cut_tokens[1:]
- loss_mask = [1] * len(res_tokens)
pos_ids = list(range(len(res_tokens)))
sequence = Sequence(
token_ids=res_tokens,
position_ids=pos_ids,
labels=res_labels,
- loss_mask=loss_mask,
num_examples=actual_example_num,
images=[],
videos=[],
@@ -313,13 +309,11 @@ def _postprocess_pretraining_sequence(self, example, actual_example_num):
tokens = self._encode_pretraining_example(example, actual_example_num)
res_tokens = tokens[:-1]
res_labels = tokens[1:]
- loss_mask = [1] * len(res_tokens)
pos_ids = list(range(len(res_tokens)))
sequence = Sequence(
token_ids=res_tokens,
position_ids=pos_ids,
labels=res_labels,
- loss_mask=loss_mask,
num_examples=actual_example_num,
images=[],
videos=[],
@@ -377,7 +371,7 @@ def _postprocess_sequence(self, example, actual_example_num):
turn_index = len(encoded_pairs) - 1
tokens = []
- loss_mask = []
+ labels = []
while turn_index >= 0:
tokens_src, tokens_target = encoded_pairs[turn_index]
if len(tokens_target) == 0:
@@ -394,12 +388,16 @@ def _postprocess_sequence(self, example, actual_example_num):
reverse_len = self.max_seq_len + 1 - cur_len - num_reserved_tokens_for_each_turn - len(tokens_src)
tokens_target = tokens_target[:reverse_len]
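+ # efficient_eos: the first src position of each non-initial turn carries an EOS label, which the final shift moves onto the last token of the previous response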
+ if self.use_template and self.efficient_eos and turn_index != 0:
+ labels_src = [self.tokenizer.eos_token_id] + [-100] * (len(tokens_src) - 1)
+ else:
+ labels_src = [-100] * len(tokens_src)
+
+ # keep the per-turn supervision from the old loss_mask: turns labeled 0 are excluded from the loss
+ labels_target = tokens_target if example["label"][turn_index] else [-100] * len(tokens_target)
tokens = tokens_src + tokens_target + tokens
+ labels = labels_src + labels_target + labels
- loss_mask = (
- [0] * (len(tokens_src) - 1) + [example["label"][turn_index]] * (len(tokens_target) + 1) + loss_mask
- )
- assert len(tokens) == len(loss_mask), f"{len(tokens)}-{len(loss_mask)}"
+ assert len(tokens) == len(labels), f"{len(tokens)}-{len(labels)}"
cur_len = len(tokens)
@@ -424,15 +422,18 @@ def _postprocess_sequence(self, example, actual_example_num):
# Maybe left truncated, so need to add begin_token
if tokens[0] != self.begin_token_id:
tokens = [self.begin_token_id] + tokens
- loss_mask = [0] + loss_mask
+ labels = [-100] + labels
if len(tokens) > self.max_seq_len:
raise RuntimeError(f"token_ids is too long: {len(tokens)}")
# Add EOS token at the end
del tokens[-1]
- del loss_mask[-1]
- labels = tokens[1:] + [self.tokenizer.eos_token_id]
+ del labels[-1]
+ if self.efficient_eos:
+ tokens = tokens + [self.tokenizer.eos_token_id]
+ labels = labels + [self.tokenizer.eos_token_id]
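+ # shift labels left by one so position i predicts token i + 1; the final position gets -100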
+ labels = labels[1:] + [-100]
# end_of_response is a special token that indicates the end of the turn.
# end_token is a special token that indicates the end of the answer.
@@ -440,25 +441,23 @@ def _postprocess_sequence(self, example, actual_example_num):
label if label != self.end_of_response_id else self.tokenizer.eos_token_id for label in labels
]
else:
- tokens = tokens[:-1] + [self.tokenizer.eos_token_id]
- labels = tokens[1:] + [-100]
+ if self.efficient_eos:
+ tokens = tokens + [self.tokenizer.eos_token_id]
+ labels = labels + [self.tokenizer.eos_token_id]
+ labels = labels[1:] + [-100]
if len(tokens) > self.max_seq_len:
raise RuntimeError(f"token_ids is too long: {len(tokens)}")
else:
- oral_tokens = tokens
- tokens = oral_tokens[:-1]
- labels = oral_tokens[1:]
- loss_mask = loss_mask[:-1]
+ labels = tokens[1:] + [-100]
if len(tokens) > self.max_seq_len:
raise RuntimeError(f"token_ids is too long: {len(tokens)}")
pos_ids = list(range(len(tokens)))
- if sum(loss_mask) == 0:
- logger.warning(f"[SKIP] all labels set to 0: {example}")
+ if all(x == -100 for x in labels):
+ logger.warning(f"[SKIP] all labels are ignored (-100): {example}")
return None
- assert len(tokens) == len(loss_mask), f"{len(tokens)}-{len(loss_mask)}"
assert len(tokens) == len(labels), f"{len(tokens)}-{len(labels)}"
enable_dataset_debug = os.getenv("FLAGS_enable_dataset_debug", "false").lower() in ("true", "1", "t")
@@ -470,12 +469,10 @@ def _postprocess_sequence(self, example, actual_example_num):
print_debug_info(self.tokenizer, tokens, "input")
print("========================================\n")
- filtered_labels = [label if mask == 1 else -100 for label, mask in zip(labels, loss_mask)]
- filtered_labels = [x for x in filtered_labels if x != -100] # remove -100
+ filtered_labels = [x for x in labels if x != -100] # remove -100
print("========================================")
print_debug_info(self.tokenizer, filtered_labels, "labels")
print("========================================\n")
- logger.info(f"[dataset debug] loss mask: {loss_mask}")
else:
logger.info("[dataset debug] Tokenizer not available")
logger.info("=" * 50 + "\n")
@@ -484,7 +481,6 @@ def _postprocess_sequence(self, example, actual_example_num):
token_ids=tokens,
position_ids=pos_ids,
labels=labels,
- loss_mask=loss_mask,
num_examples=actual_example_num,
images=images,
videos=videos,
diff --git a/paddleformers/datasets/collate.py b/paddleformers/datasets/collate.py
index f414a287421..a0f72a1dfb5 100644
--- a/paddleformers/datasets/collate.py
+++ b/paddleformers/datasets/collate.py
@@ -165,7 +165,6 @@ def collate_fn(batch: List[List[Sequence]], tokenizer, training_args, model_args
dict: Dictionary containing:
- input_ids: Padded token IDs
- labels: Shifted labels for prediction
- - loss_mask: Mask for computing loss
"""
input_keys = ["input_ids", "labels", "position_ids"]
if training_args.num_nextn_predict_layers > 0:
@@ -180,15 +179,12 @@ def collate_fn(batch: List[List[Sequence]], tokenizer, training_args, model_args
for batch_sequence in batch:
original_token_ids = [seq.token_ids for seq in batch_sequence]
token_ids = [sum(original_token_ids, [])]
- loss_mask = [sum([seq.loss_mask for seq in batch_sequence], [])]
labels = [sum([seq.labels for seq in batch_sequence], [])]
position_ids = [sum([seq.position_ids for seq in batch_sequence], [])]
# padding
padded_token_ids = pad_batch_data(token_ids, pad_idx=tokenizer.pad_token_id, max_seq_len=max_seq_len)
- padded_labels = pad_batch_data(labels, pad_idx=tokenizer.pad_token_id, max_seq_len=max_seq_len)
- padded_loss_mask = pad_batch_data(loss_mask, pad_idx=0, max_seq_len=max_seq_len)
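+ # pad labels with the ignore index (-100) so padded positions never contribute to the loss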
+ padded_labels = pad_batch_data(labels, pad_idx=-100, max_seq_len=max_seq_len)
padded_position_ids = pad_batch_data(position_ids, pad_idx=0, max_seq_len=max_seq_len)
- padded_labels = np.where(padded_loss_mask == 1, padded_labels, -100)
return_list.append(
[
padded_token_ids,
diff --git a/paddleformers/datasets/data_utils.py b/paddleformers/datasets/data_utils.py
index 309f2a0b7c0..edb54ff4b2a 100644
--- a/paddleformers/datasets/data_utils.py
+++ b/paddleformers/datasets/data_utils.py
@@ -285,7 +285,7 @@ def estimate_training(train_dataset, data_args, training_args, model_args):
"sharding_parallel_degree": int(training_args.sharding_parallel_degree),
"num_samples_each_epoch": data_args.num_samples_each_epoch,
"max_seq_len": int(data_args.max_seq_len),
- "seed": data_args.seed,
+ "seed": training_args.seed,
"valid": False,
"train_samples": 0,
}
diff --git a/paddleformers/datasets/reader/file_reader.py b/paddleformers/datasets/reader/file_reader.py
index 7fb58e025c7..20f8caeb7d3 100644
--- a/paddleformers/datasets/reader/file_reader.py
+++ b/paddleformers/datasets/reader/file_reader.py
@@ -103,9 +103,25 @@ def data_check(self, data):
if len(data["messages"]) == 0:
raise ValueError("Ignore example with empty messages.")
+ for message in data["messages"]:
+ # Normalize the role names for tool calls and tool responses
+ if message["role"] in ("tool", "tool_response"):
+ message["role"] = "observation"
+ if message["role"] in ("tool_call", "tool_calls"):
+ message["role"] = "function"
+ # Serialize tool call and tool response content into strings
+ if message["role"] in ("observation", "function") and not isinstance(message["content"], str):
+ message["content"] = json.dumps(message["content"])
+ if "tool_calls" in message and not isinstance(message["tool_calls"], str):
+ message["tool_calls"] = json.dumps(message["tool_calls"])
+
+ # Convert the tools list into a string
if "tools" in data and not isinstance(data["tools"], str):
data["tools"] = json.dumps(data["tools"], ensure_ascii=False)
+ # If no label is provided, every assistant/function turn is trained on.
if "label" not in data:
data["label"] = [
1 for turn in data["messages"] if ("assistant" in turn["role"] or "function" in turn["role"])
@@ -116,6 +132,7 @@ def data_check(self, data):
system = data["messages"][0]["content"]
if not isinstance(system, str):
raise ValueError("System field must be a string.")
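+ # Move the leading system message out of the conversation and into its own field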
+ data["messages"] = data["messages"][1:]
data["system"] = system
# Convert the relative paths of multimode data into absolute paths
diff --git a/paddleformers/datasets/template/mm_plugin.py b/paddleformers/datasets/template/mm_plugin.py
index 934c4e177f5..2832a198220 100644
--- a/paddleformers/datasets/template/mm_plugin.py
+++ b/paddleformers/datasets/template/mm_plugin.py
@@ -563,10 +563,141 @@ def process_messages(
return messages
+@dataclass
+class GLM4VPlugin(Qwen2VLPlugin):
+ @override
+ def _get_mm_inputs(
+ self,
+ images,
+ videos,
+ audios,
+ processor,
+ ):
+ image_processor = getattr(processor, "image_processor", None)
+ video_processor = getattr(processor, "video_processor", None)
+ mm_inputs = {}
+ if len(images) != 0:
+ images = self._regularize_images(
+ images,
+ image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
+ image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
+ )["images"]
+ mm_inputs.update(image_processor(images, return_tensors="pt"))
+
+ if len(videos) != 0:
+ video_data = self._regularize_videos(
+ videos,
+ image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
+ image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
+ video_fps=getattr(processor, "video_fps", 2.0),
+ video_maxlen=getattr(processor, "video_maxlen", 128),
+ )
+ # prepare video metadata
+ video_metadata = [
+ {"fps": 2, "duration": len(video), "total_frames": len(video)} for video in video_data["videos"]
+ ]
+ mm_inputs.update(video_processor(images=None, videos=video_data["videos"], video_metadata=video_metadata))
+
+ return mm_inputs
+
+ @override
+ def process_messages(
+ self,
+ messages,
+ images,
+ videos,
+ audios,
+ processor,
+ ):
+ self._validate_input(processor, images, videos, audios)
+ self._validate_messages(messages, images, videos, audios)
+ num_image_tokens, num_video_tokens = 0, 0
+ messages = deepcopy(messages)
+ image_processor = getattr(processor, "image_processor")
+
+ merge_length = getattr(image_processor, "merge_size") ** 2
+ if self.expand_mm_tokens:
+ mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+ image_grid_thw = mm_inputs.get("image_grid_thw", [])
+ video_grid_thw = mm_inputs.get("video_grid_thw", [])
+ num_frames = video_grid_thw[0][0] if len(video_grid_thw) > 0 else 0 # hard code for now
+ timestamps = mm_inputs.get("timestamps", [])
+
+ if hasattr(timestamps, "tolist"):
+ timestamps = timestamps.tolist()
+
+ if not timestamps:
+ timestamps_list = []
+ elif isinstance(timestamps[0], list):
+ timestamps_list = timestamps[0]
+ else:
+ timestamps_list = timestamps
+
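+ # pad with the last timestamp (or 0) when the processor returns fewer timestamps than sampled frames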
+ selected_timestamps = timestamps_list[:num_frames]
+ while len(selected_timestamps) < num_frames:
+ selected_timestamps.append(selected_timestamps[-1] if selected_timestamps else 0)
+
+ else:
+ image_grid_thw = [None] * len(images)
+ video_grid_thw = [None] * len(videos)
+ num_frames = 0
+ selected_timestamps = [0]
+
+ for message in messages:
+ content = message["content"]
+ while IMAGE_PLACEHOLDER in content:
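+ # each image placeholder expands to grid_thw.prod() // merge_size**2 vision tokens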
+ image_seqlen = image_grid_thw[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
+ content = content.replace(
+ IMAGE_PLACEHOLDER, f"<|begin_of_image|>{self.image_token * image_seqlen}<|end_of_image|>", 1
+ )
+ num_image_tokens += 1
+
+ while VIDEO_PLACEHOLDER in content:
+ video_structure = ""
+ for frame_index in range(num_frames):
+ video_seqlen = (
+ video_grid_thw[num_video_tokens][1:].prod() // merge_length if self.expand_mm_tokens else 1
+ )
+ timestamp_sec = selected_timestamps[frame_index]
+ frame_structure = (
+ f"<|begin_of_image|>{self.image_token * video_seqlen}<|end_of_image|>{timestamp_sec}"
+ )
+ video_structure += frame_structure
+
+ if not self.expand_mm_tokens:
+ video_structure = self.video_token
+
+ content = content.replace(VIDEO_PLACEHOLDER, f"<|begin_of_video|>{video_structure}<|end_of_video|>", 1)
+ num_video_tokens += 1
+
+ message["content"] = content
+
+ return messages
+
+ @override
+ def get_mm_inputs(
+ self,
+ images,
+ videos,
+ audios,
+ imglens,
+ vidlens,
+ audlens,
+ batch_ids,
+ processor,
+ ):
+ self._validate_input(processor, images, videos, audios)
+ mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+ mm_inputs.pop("timestamps", None)
+ return mm_inputs
+
+
PLUGINS = {
"base": BasePlugin,
"qwen2_vl": Qwen2VLPlugin,
"qwen3_vl": Qwen3VLPlugin,
+ "glm4v": GLM4VPlugin,
}
diff --git a/paddleformers/datasets/template/template.py b/paddleformers/datasets/template/template.py
index 2604ae94e58..4051cc0827a 100644
--- a/paddleformers/datasets/template/template.py
+++ b/paddleformers/datasets/template/template.py
@@ -20,7 +20,7 @@
from copy import deepcopy
from dataclasses import dataclass
from enum import Enum, unique
-from typing import TYPE_CHECKING, Optional, Union
+from typing import TYPE_CHECKING, Optional
from typing_extensions import override
@@ -34,7 +34,6 @@
from .formatter import SLOTS, Formatter
from .mm_plugin import BasePlugin
- from .tool_utils import FunctionCall
@unique
@@ -60,7 +59,6 @@ class Template:
thought_words: tuple[str, str]
efficient_eos: bool
replace_eos: bool
- replace_jinja_template: bool
enable_thinking: Optional[bool]
mm_plugin: "BasePlugin"
@@ -91,18 +89,6 @@ def encode_multiturn(
encoded_messages = self._encode(tokenizer, messages, system, tools)
return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]
- def extract_tool(self, content: str) -> Union[str, list["FunctionCall"]]:
- r"""Extract tool message."""
- return self.format_tools.extract(content)
-
- def get_stop_token_ids(self, tokenizer: "PreTrainedTokenizer") -> list[int]:
- r"""Return stop token ids."""
- stop_token_ids = {tokenizer.eos_token_id}
- for token in self.stop_words:
- stop_token_ids.add(tokenizer.convert_tokens_to_ids(token))
-
- return list(stop_token_ids)
-
def add_thought(self, content: str = "") -> str:
r"""Add empty thought to assistant message."""
return f"{self.thought_words[0]}{self.thought_words[1]}" + content
@@ -162,6 +148,10 @@ def _encode(
elements += self.format_user.apply(content=message["content"], idx=str(i // 2))
elif message["role"] == Role.ASSISTANT:
elements += self.format_assistant.apply(content=message["content"])
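+ # assistant turns that carry tool_calls additionally emit the formatted function call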
+ if "tool_calls" in message:
+ elements += self.format_function.apply(
+ content=message["tool_calls"], thought_words=self.thought_words
+ )
elif message["role"] == Role.OBSERVATION:
elements += self.format_observation.apply(content=message["content"])
elif message["role"] == Role.FUNCTION:
@@ -215,70 +205,6 @@ def fix_special_tokens(self, tokenizer: "PreTrainedTokenizer") -> None:
if num_added_tokens > 0:
logger.warning("New tokens have been added, make sure `resize_vocab` is True.")
- @staticmethod
- def _jinja_escape(content: str) -> str:
- r"""Escape single quotes in content."""
- return content.replace("'", r"\'")
-
- @staticmethod
- def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content") -> str:
- r"""Convert slots to jinja template."""
- slot_items = []
- for slot in slots:
- if isinstance(slot, str):
- slot_pieces = slot.split("{{content}}")
- if slot_pieces[0]:
- slot_items.append("'" + Template._jinja_escape(slot_pieces[0]) + "'")
- if len(slot_pieces) > 1:
- slot_items.append(placeholder)
- if slot_pieces[1]:
- slot_items.append("'" + Template._jinja_escape(slot_pieces[1]) + "'")
- elif isinstance(slot, set): # do not use {{ eos_token }} since it may be replaced
- if "bos_token" in slot and tokenizer.bos_token_id is not None:
- slot_items.append("'" + tokenizer.bos_token + "'")
- elif "eos_token" in slot and tokenizer.eos_token_id is not None:
- slot_items.append("'" + tokenizer.eos_token + "'")
- elif isinstance(slot, dict):
- raise ValueError("Dict is not supported.")
-
- return " + ".join(slot_items)
-
- def _get_jinja_template(self, tokenizer: "PreTrainedTokenizer") -> str:
- r"""Return the jinja template."""
- prefix = self._convert_slots_to_jinja(self.format_prefix.apply(), tokenizer)
- system = self._convert_slots_to_jinja(self.format_system.apply(), tokenizer, placeholder="system_message")
- user = self._convert_slots_to_jinja(self.format_user.apply(), tokenizer)
- assistant = self._convert_slots_to_jinja(self.format_assistant.apply(), tokenizer)
- jinja_template = ""
- if prefix:
- jinja_template += "{{ " + prefix + " }}"
-
- if self.default_system:
- jinja_template += "{% set system_message = '" + self._jinja_escape(self.default_system) + "' %}"
-
- jinja_template += (
- "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}"
- "{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% endif %}"
- "{% if system_message is defined %}{{ " + system + " }}{% endif %}"
- "{% for message in loop_messages %}"
- "{% set content = message['content'] %}"
- "{% if message['role'] == 'user' %}"
- "{{ " + user + " }}"
- "{% elif message['role'] == 'assistant' %}"
- "{{ " + assistant + " }}"
- "{% endif %}"
- "{% endfor %}"
- )
- return jinja_template
-
- def fix_jinja_template(self, tokenizer: "PreTrainedTokenizer") -> None:
- r"""Replace the jinja template in the tokenizer."""
- if tokenizer.chat_template is None or self.replace_jinja_template:
- try:
- tokenizer.chat_template = self._get_jinja_template(tokenizer)
- except ValueError as e:
- logger.warning(f"Cannot add this chat template to tokenizer: {e}.")
-
@dataclass
class ReasoningTemplate(Template):
@@ -355,7 +281,6 @@ def register_template(
thought_words: Optional[tuple[str, str]] = None,
efficient_eos: bool = False,
replace_eos: bool = False,
- replace_jinja_template: bool = False,
enable_thinking: Optional[bool] = True,
mm_plugin: "BasePlugin" = get_mm_plugin(name="base"),
template_class: type["Template"] = Template,
@@ -406,7 +331,6 @@ def register_template(
thought_words=thought_words or ("<think>\n", "\n</think>\n\n"),
efficient_eos=efficient_eos,
replace_eos=replace_eos,
- replace_jinja_template=replace_jinja_template,
enable_thinking=enable_thinking,
mm_plugin=mm_plugin,
)
@@ -467,7 +391,6 @@ def find_diff(short_str: str, long_str: str) -> str:
thought_words=("<think>\n", "\n</think>\n\n"),
efficient_eos=False,
replace_eos=False,
- replace_jinja_template=False,
enable_thinking=True,
mm_plugin=get_mm_plugin(name="base"),
)
@@ -489,22 +412,16 @@ def get_template_and_fix_tokenizer(dataset_config) -> "Template":
template = TEMPLATES[dataset_config["template"]]
- if dataset_config["train_on_prompt"] and template.efficient_eos:
- raise ValueError("Current template does not support `train_on_prompt`.")
-
if dataset_config["tool_format"] is not None:
- # logger.info_rank0(f"Using tool format: {dataset_config['tool_format']}.")
default_slots = ["{{content}}"] if template.efficient_eos else ["{{content}}", {"eos_token"}]
template.format_function = FunctionFormatter(slots=default_slots, tool_format=dataset_config["tool_format"])
template.format_tools = ToolFormatter(tool_format=dataset_config["tool_format"])
if dataset_config["default_system"] is not None:
- # logger.info_rank0(f"Using default system message: {dataset_config['default_system']}.")
template.default_system = dataset_config["default_system"]
template.enable_thinking = dataset_config["enable_thinking"]
template.fix_special_tokens(tokenizer)
- template.fix_jinja_template(tokenizer)
return template
@@ -513,7 +430,6 @@ def get_template_and_fix_tokenizer(dataset_config) -> "Template":
format_user=StringFormatter(slots=["Human: {{content}}", {"eos_token"}, "\nAssistant:"]),
format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n"]),
format_system=StringFormatter(slots=["System: {{content}}", {"eos_token"}, "\n"]),
- replace_jinja_template=True,
)
register_template(
@@ -545,16 +461,16 @@ def get_template_and_fix_tokenizer(dataset_config) -> "Template":
register_template(
name="qwen",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
- format_assistant=StringFormatter(slots=["<think>\n\n</think>\n\n{{content}}<|im_end|>\n"]),
+ format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
format_observation=StringFormatter(
slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
),
format_tools=ToolFormatter(tool_format="qwen"),
- default_system="",
+ default_system="You are Qwen, created by Alibaba Cloud. You are a helpful assistant.",
stop_words=["<|im_end|>"],
- replace_eos=False,
+ replace_eos=True,
)
@@ -562,7 +478,7 @@ def get_template_and_fix_tokenizer(dataset_config) -> "Template":
register_template(
name="qwen3",
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
- format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
+ format_assistant=StringFormatter(slots=["{{content}}<|im_end|>"]),
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
format_observation=StringFormatter(
@@ -625,3 +541,72 @@ def get_template_and_fix_tokenizer(dataset_config) -> "Template":
replace_eos=True,
mm_plugin=get_mm_plugin(name="qwen3_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
)
+
+register_template(
+ name="glm4",
+ format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
+ format_assistant=StringFormatter(slots=["\n{{content}}"]),
+ format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
+ format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4"),
+ format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
+ format_tools=ToolFormatter(tool_format="glm4"),
+ format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
+ stop_words=["<|user|>", "<|observation|>"],
+ efficient_eos=True,
+)
+
+
+# copied from glm4 template
+register_template(
+ name="glm4_moe",
+ format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
+ format_assistant=StringFormatter(slots=["\n{{content}}"]),
+ format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
+ format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4_moe"),
+ format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
+ format_tools=ToolFormatter(tool_format="glm4_moe"),
+ format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
+ stop_words=["<|user|>", "<|observation|>"],
+ efficient_eos=True,
+ template_class=ReasoningTemplate,
+)
+
+
+# copied from glm4 template
+register_template(
+ name="glm4v",
+ format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
+ format_assistant=StringFormatter(slots=["\n{{content}}"]),
+ format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
+ format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4"),
+ format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
+ format_tools=ToolFormatter(tool_format="glm4"),
+ format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
+ stop_words=["<|user|>", "<|observation|>", "</answer>"],
+ efficient_eos=True,
+ mm_plugin=get_mm_plugin(name="glm4v", image_token="<|image|>", video_token="<|video|>"),
+ template_class=ReasoningTemplate,
+)
+
+
+# copied from glm4 template
+register_template(
+ name="glm4v_moe",
+ format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
+ format_assistant=StringFormatter(slots=["\n{{content}}"]),
+ format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
+ format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4_moe"),
+ format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
+ format_tools=ToolFormatter(tool_format="glm4_moe"),
+ format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
+ stop_words=["<|user|>", "<|observation|>", "</answer>"],
+ efficient_eos=True,
+ mm_plugin=get_mm_plugin(name="glm4v", image_token="<|image|>", video_token="<|video|>"),
+ template_class=ReasoningTemplate,
+)
+
+register_template(
+ name="deepseek3",
+ format_user=StringFormatter(slots=["<|User|>{{content}}<|Assistant|>"]),
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
+)
diff --git a/paddleformers/datasets/template/tool_utils.py b/paddleformers/datasets/template/tool_utils.py
index 5a04da57b05..5de63ce82dc 100644
--- a/paddleformers/datasets/template/tool_utils.py
+++ b/paddleformers/datasets/template/tool_utils.py
@@ -18,7 +18,6 @@
import json
-import re
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, NamedTuple, Union
@@ -50,6 +49,23 @@ class FunctionCall(NamedTuple):
)
+GLM4_TOOL_PROMPT = (
+ "你是一个名为 ChatGLM 的人工智能助手。你是基于智谱 AI 公司训练的语言模型 GLM-4 模型开发的," "你的任务是针对用户的问题和要求提供适当的答复和支持。\n\n# 可用工具{tool_text}"
+)
+
+GLM4_MOE_TOOL_PROMPT = (
+ "\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
+ "You are provided with function signatures within XML tags:\n{tool_text}"
+ "\n\n\nFor each function call, output the function name and arguments within the following XML format:"
+ "\n{{function-name}}"
+ "\n{{arg-key-1}}"
+ "\n{{arg-value-1}}"
+ "\n{{arg-key-2}}"
+ "\n{{arg-value-2}}"
+ "\n...\n\n"
+)
+
+
@dataclass
class ToolUtils(ABC):
"""Base class for tool utilities."""
@@ -66,15 +82,6 @@ def function_formatter(functions: list["FunctionCall"]) -> str:
r"""Generate the assistant message including all the tool calls."""
...
- @staticmethod
- @abstractmethod
- def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
- r"""Extract all the function calls from the assistant message.
-
- It should be an inverse function of `function_formatter`.
- """
- ...
-
class DefaultToolUtils(ToolUtils):
r"""Default tool using template."""
@@ -119,26 +126,6 @@ def tool_formatter(tools: list[dict[str, Any]]) -> str:
def function_formatter(functions: list["FunctionCall"]) -> str:
return "\n".join([f"Action: {name}\nAction Input: {arguments}" for name, arguments in functions])
- @override
- @staticmethod
- def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
- regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+)\s*Action Input:\s*(.+?)(?=\s*Action:|\s*$)", re.DOTALL)
- action_match: list[tuple[str, str]] = re.findall(regex, content)
- if not action_match:
- return content
-
- results = []
- for match in action_match:
- tool_name = match[0].strip()
- tool_input = match[1].strip().strip('"').strip("```")
- try:
- arguments = json.loads(tool_input)
- results.append(FunctionCall(tool_name, json.dumps(arguments, ensure_ascii=False)))
- except json.JSONDecodeError:
- return content
-
- return results
-
class QwenToolUtils(ToolUtils):
r"""Qwen 2.5 tool using template."""
@@ -162,32 +149,82 @@ def function_formatter(functions: list["FunctionCall"]) -> str:
]
return "\n".join([f"\n{text}\n" for text in function_texts])
+
+class GLM4ToolUtils(ToolUtils):
+ r"""GLM-4 tool using template."""
+
+ @override
+ @staticmethod
+ def tool_formatter(tools: list[dict[str, Any]]) -> str:
+ tool_text = ""
+ for tool in tools:
+ tool = tool.get("function", "") if tool.get("type") == "function" else tool
+ tool_text += "\n\n## {name}\n\n{body}\n在调用上述函数时,请使用 Json 格式表示调用的参数。".format(
+ name=tool["name"], body=json.dumps(tool, indent=4, ensure_ascii=False)
+ )
+
+ return GLM4_TOOL_PROMPT.format(tool_text=tool_text)
+
+ @override
+ @staticmethod
+ def function_formatter(functions: list["FunctionCall"]) -> str:
+ if len(functions) > 1:
+ raise ValueError("GLM-4 does not support parallel functions.")
+
+ return f"{functions[0].name}\n{functions[0].arguments}"
+
@override
@staticmethod
def tool_extractor(content: str) -> Union[str, list["FunctionCall"]]:
- regex = re.compile(r"(.+?)(?=\s*|\s*$)", re.DOTALL)
- tool_match: list[str] = re.findall(regex, content)
- if not tool_match:
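+ # GLM-4 renders a tool call as the function name, a newline, then JSON arguments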
+ if "\n" not in content:
return content
- results = []
- for tool in tool_match:
- try:
- tool = json.loads(tool.strip())
- except json.JSONDecodeError:
- return content
+ tool_name, tool_input = content.split("\n", maxsplit=1)
+ try:
+ arguments = json.loads(tool_input.strip())
+ except json.JSONDecodeError:
+ return content
+
+ return [FunctionCall(tool_name, json.dumps(arguments, ensure_ascii=False))]
+
- if "name" not in tool or "arguments" not in tool:
- return content
+class GLM4MOEToolUtils(QwenToolUtils):
+ r"""GLM-4-MOE tool using template."""
- results.append(FunctionCall(tool["name"], json.dumps(tool["arguments"], ensure_ascii=False)))
+ @override
+ @staticmethod
+ def tool_formatter(tools: list[dict[str, Any]]) -> str:
+ tool_text = ""
+ for tool in tools:
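+ # wrap bare tool schemas as {"type": "function", "function": {...}} before serializing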
+ wrapped_tool = tool if tool.get("type") == "function" else {"type": "function", "function": tool}
+ tool_text += "\n" + json.dumps(wrapped_tool, ensure_ascii=False)
+
+ return GLM4_MOE_TOOL_PROMPT.format(tool_text=tool_text)
+
+ @override
+ @staticmethod
+ def function_formatter(functions: list["FunctionCall"]) -> str:
+ function_json = [
+ {"func_name": name, "func_key_values": json.loads(arguments)} for name, arguments in functions
+ ]
+ function_texts = []
+ for func in function_json:
+ prompt = "\n" + func["func_name"]
+ for key, value in func["func_key_values"].items():
+ prompt += "\n" + key + ""
+ if not isinstance(value, str):
+ value = json.dumps(value, ensure_ascii=False)
+ prompt += "\n" + value + ""
+ function_texts.append(prompt)
- return results
+ return "\n".join(function_texts)
TOOLS = {
"default": DefaultToolUtils(),
"qwen": QwenToolUtils(),
+ "glm4": GLM4ToolUtils(),
+ "glm4_moe": GLM4MOEToolUtils(),
}
diff --git a/scripts/regression/test_sft_tiny-random-glm4moe.py b/scripts/regression/test_sft_tiny-random-glm4moe.py
index 9a3025c25d6..a5fba6f6f36 100644
--- a/scripts/regression/test_sft_tiny-random-glm4moe.py
+++ b/scripts/regression/test_sft_tiny-random-glm4moe.py
@@ -31,12 +31,12 @@
MAX_STEPS = 6
SAVE_STEPS = 4
-SFT_FULL_EXCEPTED_LOSS = 13.091749
-SFT_FULL_RESUME_EXCEPTED_LOSS = 13.080153
+SFT_FULL_EXCEPTED_LOSS = 13.043166
+SFT_FULL_RESUME_EXCEPTED_LOSS = 13.035461
SFT_FULL_EXCEPTED_RESULT = [[51172, 37927, 96130, 27654, 133362, 95331, 27654, 133362, 115845, 115845]]
-SFT_LORA_EXCEPTED_LOSS = 13.092138
-SFT_LORA_RESUME_EXCEPTED_LOSS = 13.081409
+SFT_LORA_EXCEPTED_LOSS = 13.04369
+SFT_LORA_RESUME_EXCEPTED_LOSS = 13.036311
SFT_LORA_EXCEPTED_RESULT = [[51172, 37927, 96130, 27654, 133362, 95331, 27654, 133362, 115845, 115845]]
SFT_FULL_TP_PP_EXCEPTED_LOSS = 11.92912
@@ -47,8 +47,8 @@
SFT_LORA_TP_PP_RESUME_EXCEPTED_LOSS = 11.929088
SFT_LORA_TP_PP_EXCEPTED_RESULT = [[51172, 37927, 96130, 27654, 133362, 95331, 27654, 133362, 115845, 115845]]
-SFT_FC_EXCEPTED_LOSS = 12.862782
-SFT_FC_RESUME_EXCEPTED_LOSS = 12.867558
+SFT_FC_EXCEPTED_LOSS = 12.859675
+SFT_FC_RESUME_EXCEPTED_LOSS = 12.863781
SFT_FC_EXCEPTED_RESULT = [[51172, 37927, 96130, 27654, 133362, 95331, 27654, 133362, 115845, 115845]]
os.environ["NVIDIA_TF32_OVERRIDE"] = "0"