2 changes: 1 addition & 1 deletion paddleformers/cli/hparams/data_args.py
@@ -90,7 +90,7 @@ class DataArguments:
metadata={"help": "Number of candidate responses."},
)
mask_out_eos_token: bool = field(
default=True,
default=False,
metadata={"help": "Mask out eos token"},
)
random_shuffle: bool = field(
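The only change here flips the default of `mask_out_eos_token` to `False`, so the EOS token is learned by default. A minimal sketch of how such a flag is typically applied when labels are built (hedged: `build_labels` below is illustrative, not the PaddleFormers implementation, and the EOS-masking behavior is an assumption based on the flag's help text):

```python
# Illustrative only: how a mask_out_eos_token flag usually decides whether the
# EOS position contributes to the loss (-100 is ignored by cross-entropy).
def build_labels(token_ids, eos_token_id, mask_out_eos_token=False):
    labels = token_ids[1:] + [-100]  # shift for next-token prediction
    if mask_out_eos_token:
        labels = [-100 if tok == eos_token_id else tok for tok in labels]
    return labels


tokens = [101, 7, 8, 2]  # 2 = hypothetical eos_token_id
print(build_labels(tokens, eos_token_id=2))                           # [7, 8, 2, -100]
print(build_labels(tokens, eos_token_id=2, mask_out_eos_token=True))  # [7, 8, -100, -100]
```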
4 changes: 1 addition & 3 deletions paddleformers/cli/train/dpo/workflow.py
@@ -263,14 +263,12 @@ def run_dpo(
"mix_strategy": data_args.mix_strategy,
"encode_one_turn": data_args.encode_one_turn,
"stage": model_args.stage,
"is_valid": False,
"template_backend": data_args.template_backend,
}

dataset_config.update(
{
"template": data_args.template,
"train_on_prompt": False,
"tool_format": None,
"default_system": None,
"enable_thinking": True,
@@ -325,11 +323,11 @@ def run_dpo(
train_dataset = None

if training_args.do_eval and training_args.should_load_dataset:
dataset_config["is_valid"] = True
eval_dataset = create_dataset(
task_group=data_args.eval_dataset_path,
task_group_prob=data_args.eval_dataset_prob,
sub_dataset_type=data_args.eval_dataset_type,
is_valid=True,
**dataset_config,
)
else:
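The eval dataset now receives `is_valid=True` as an explicit keyword argument instead of toggling it inside the shared `dataset_config` dict (the SFT workflow below gets the same treatment). A minimal sketch of why that matters, with a hypothetical stand-in factory rather than the real `create_dataset`:

```python
# Hedged sketch with a hypothetical factory: passing is_valid as an explicit
# keyword keeps the shared config dict immutable, so building the eval dataset
# cannot leak is_valid=True into a later train-dataset build.
def create_dataset(task_group, is_valid=False, **dataset_config):
    return {"task_group": task_group, "is_valid": is_valid, **dataset_config}


shared_config = {"max_seq_len": 4096}
train_ds = create_dataset("train.jsonl", **shared_config)               # is_valid defaults to False
eval_ds = create_dataset("eval.jsonl", is_valid=True, **shared_config)  # only this call sees True
assert "is_valid" not in shared_config  # the shared dict was never mutated
```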
4 changes: 1 addition & 3 deletions paddleformers/cli/train/sft/workflow.py
@@ -336,15 +336,13 @@ def neft_post_hook(module, input, output):
"is_pretraining": True if model_args.stage.lower() == "pt" else False,
"truncate_packing": data_args.truncate_packing,
"stage": model_args.stage,
"is_valid": False,
"template_backend": data_args.template_backend,
"split_multi_turn": data_args.split_multi_turn,
}

dataset_config.update(
{
"template": data_args.template,
"train_on_prompt": False,
"tool_format": None,
"default_system": None,
"enable_thinking": True,
@@ -373,11 +371,11 @@ def neft_post_hook(module, input, output):
sub_dataset_type=data_args.train_dataset_type,
**dataset_config,
)
dataset_config["is_valid"] = True
eval_dataset = create_dataset_sft(
task_group=data_args.eval_dataset_path,
task_group_prob=data_args.eval_dataset_prob,
sub_dataset_type=data_args.eval_dataset_type,
is_valid=True,
**dataset_config,
)

38 changes: 32 additions & 6 deletions paddleformers/datasets/DPODataset.py
@@ -12,14 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from copy import deepcopy
from dataclasses import dataclass
from typing import List, Optional

import numpy as np
from paddle.io import IterableDataset

from paddleformers.datasets.data_utils import postprocess_fc_sequence
from paddleformers.datasets.data_utils import postprocess_fc_sequence, print_debug_info
from paddleformers.datasets.reader.mix_datasets import create_dataset_instance
from paddleformers.datasets.reader.multi_source_datasets import MultiSourceDataset
from paddleformers.transformers.tokenizer_utils import PretrainedTokenizer
@@ -240,9 +241,10 @@ def __postprocess_before_concat(self, example):
response_token_ids_list = []
response_label_ids_list = []
response_len_list = []
split_index = example["session_start_index"] // 2
for responses in [
chosen_encoded_messages[example["session_start_index"] // 2 :],
rejected_encoded_messages[example["session_start_index"] // 2 :],
chosen_encoded_messages[split_index:],
rejected_encoded_messages[split_index:],
]:
responses_token_ids = []
responses_label_ids = []
@@ -274,9 +276,9 @@ def __postprocess_before_concat(self, example):
cur_len += sum(map(len, response_token_ids_list))

# create at least one turn
turn_index = len(chosen_encoded_messages) - 1
turn_index = split_index
while turn_index >= 0:
if turn_index == len(chosen_encoded_messages) - 1:
if turn_index == split_index:
cur_turn_token = chosen_encoded_messages[turn_index][0]
else:
cur_turn_token = chosen_encoded_messages[turn_index][0] + chosen_encoded_messages[turn_index][1]
@@ -289,7 +291,7 @@ def __postprocess_before_concat(self, example):
turn_index -= 1

# at least one turn
if turn_index == len(chosen_encoded_messages) - 1:
if turn_index == split_index:
sub_src = example["chosen"]["messages"][0]["content"].strip()[:5]
global LOGGER_COUNT
LOGGER_COUNT += 1
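The repeated `example["session_start_index"] // 2` is hoisted into `split_index`, and the prompt-building loop now walks back from `split_index` rather than from the last encoded turn. A small illustration of the index mapping this relies on (assumption: `session_start_index` counts flat, role-alternating messages while the encoded lists hold (prompt, response) pairs, hence the division by two):

```python
# Hedged illustration of the assumed // 2 mapping; data shapes are made up.
messages = ["u0", "a0", "u1", "a1", "u2", "a2"]             # flat, alternating roles
encoded_pairs = [("u0", "a0"), ("u1", "a1"), ("u2", "a2")]  # (prompt, response) pairs

session_start_index = 4                       # responses under comparison start at "u2"
split_index = session_start_index // 2        # -> pair index 2
prompt_pairs = encoded_pairs[:split_index]    # shared history used as context
response_pairs = encoded_pairs[split_index:]  # chosen/rejected continuation
print(split_index, prompt_pairs, response_pairs)
```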
@@ -364,6 +366,30 @@ def _postprocess_sequence(self, example):
prompt_len : (prompt_len + chosen_len),
] = False
attn_mask_startend_row_indices = None

# print
enable_dataset_debug = os.getenv("FLAGS_enable_dataset_debug", "false").lower() in ("true", "1", "t")
if enable_dataset_debug:
logger.info("\n" + "=" * 50)
logger.info("[dataset debug] Debug mode enabled")
if hasattr(self, "tokenizer"):
print("========================================")
print_debug_info(self.tokenizer, input_ids, "input")
print("========================================\n")

filtered_labels = [x for x in chosen_labels if x != 0] # drop masked positions (0)
print("========================================")
print_debug_info(self.tokenizer, filtered_labels, "chosen_labels")
print("========================================\n")

filtered_labels = [x for x in rejected_labels if x != 0] # drop masked positions (0)
print("========================================")
print_debug_info(self.tokenizer, filtered_labels, "rejected_labels")
print("========================================\n")
else:
logger.info("[dataset debug] Tokenizer not available")
logger.info("=" * 50 + "\n")

# 2. return sequence
return Sequence(
token_ids=input_ids,
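The new debug block mirrors the one already present in `SFTDataset`: it is gated by the `FLAGS_enable_dataset_debug` environment variable and decodes the input and label sequences for inspection. The toggle itself reduces to this pattern, taken directly from the diff:

```python
# Env-var toggle used by the debug blocks above.
import os


def dataset_debug_enabled() -> bool:
    return os.getenv("FLAGS_enable_dataset_debug", "false").lower() in ("true", "1", "t")


# Usage (shell, hypothetical command line):
#   FLAGS_enable_dataset_debug=1 python train.py ...
if dataset_debug_enabled():
    print("dataset debug mode enabled")
```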
50 changes: 23 additions & 27 deletions paddleformers/datasets/SFTDataset.py
@@ -34,7 +34,6 @@ class Sequence:
token_ids: List[int]
position_ids: List[int]
labels: List[int]
loss_mask: List[int]
num_examples: int
images: List[str]
videos: List[str]
@@ -51,6 +50,7 @@ def __init__(self, **dataset_config):
self.template = dataset_config.get("template_instance", None)
self.template_backend = dataset_config.get("template_backend", "jinja")
self.use_template = dataset_config.get("use_template", True)
self.efficient_eos = True if not self.template else getattr(self.template, "efficient_eos", False)
self.split_multi_turn = dataset_config.get("split_multi_turn", False)
self.encode_one_turn = dataset_config.get("encode_one_turn", True)
self.is_pretraining = dataset_config.get("is_pretraining", False)
@@ -142,13 +142,11 @@ def __iter_func(self):

res_tokens = cut_tokens[:-1]
res_labels = cut_tokens[1:]
loss_mask = [1] * len(res_tokens)
pos_ids = list(range(len(res_tokens)))
sequence = Sequence(
token_ids=res_tokens,
position_ids=pos_ids,
labels=res_labels,
loss_mask=loss_mask,
num_examples=actual_example_num,
images=[],
videos=[],
@@ -172,13 +170,11 @@ def __iter_func(self):
cut_tokens = cut_tokens + [self.tokenizer.eos_token_id]
res_tokens = cut_tokens[:-1]
res_labels = cut_tokens[1:]
loss_mask = [1] * len(res_tokens)
pos_ids = list(range(len(res_tokens)))
sequence = Sequence(
token_ids=res_tokens,
position_ids=pos_ids,
labels=res_labels,
loss_mask=loss_mask,
num_examples=actual_example_num,
images=[],
videos=[],
@@ -313,13 +309,11 @@ def _postprocess_pretraining_sequence(self, example, actual_example_num):
tokens = self._encode_pretraining_example(example, actual_example_num)
res_tokens = tokens[:-1]
res_labels = tokens[1:]
loss_mask = [1] * len(res_tokens)
pos_ids = list(range(len(res_tokens)))
sequence = Sequence(
token_ids=res_tokens,
position_ids=pos_ids,
labels=res_labels,
loss_mask=loss_mask,
num_examples=actual_example_num,
images=[],
videos=[],
@@ -377,7 +371,7 @@ def _postprocess_sequence(self, example, actual_example_num):
turn_index = len(encoded_pairs) - 1

tokens = []
loss_mask = []
labels = []
while turn_index >= 0:
tokens_src, tokens_target = encoded_pairs[turn_index]
if len(tokens_target) == 0:
@@ -394,12 +388,16 @@ def _postprocess_sequence(self, example, actual_example_num):
reverse_len = self.max_seq_len + 1 - cur_len - num_reserved_tokens_for_each_turn - len(tokens_src)
tokens_target = tokens_target[:reverse_len]

if self.use_template and self.efficient_eos and turn_index != 0:
labels_src = [self.tokenizer.eos_token_id] + [-100] * (len(tokens_src) - 1)
else:
labels_src = [-100] * len(tokens_src)

labels_target = tokens_target
tokens = tokens_src + tokens_target + tokens
labels = labels_src + labels_target + labels

loss_mask = (
[0] * (len(tokens_src) - 1) + [example["label"][turn_index]] * (len(tokens_target) + 1) + loss_mask
)
assert len(tokens) == len(loss_mask), f"{len(tokens)}-{len(loss_mask)}"
assert len(tokens) == len(labels), f"{len(tokens)}-{len(labels)}"

cur_len = len(tokens)

@@ -424,41 +422,42 @@ def _postprocess_sequence(self, example, actual_example_num):
# Maybe left truncated, so need to add begin_token
if tokens[0] != self.begin_token_id:
tokens = [self.begin_token_id] + tokens
loss_mask = [0] + loss_mask
labels = [-100] + labels

if len(tokens) > self.max_seq_len:
raise RuntimeError(f"token_ids is too long: {len(tokens)}")

# Add EOS token at the end
del tokens[-1]
del loss_mask[-1]
labels = tokens[1:] + [self.tokenizer.eos_token_id]
del labels[-1]
if self.efficient_eos:
tokens = tokens + [self.tokenizer.eos_token_id]
labels = labels + [self.tokenizer.eos_token_id]
labels = labels[1:] + [-100]

# end_of_response is a special token that indicates the end of the turn.
# end_token is a special token that indicates the end of the answer.
labels = [
label if label != self.end_of_response_id else self.tokenizer.eos_token_id for label in labels
]
else:
tokens = tokens[:-1] + [self.tokenizer.eos_token_id]
labels = tokens[1:] + [-100]
if self.efficient_eos:
tokens = tokens + [self.tokenizer.eos_token_id]
labels = labels + [self.tokenizer.eos_token_id]
labels = labels[1:] + [-100]
if len(tokens) > self.max_seq_len:
raise RuntimeError(f"token_ids is too long: {len(tokens)}")
else:
oral_tokens = tokens
tokens = oral_tokens[:-1]
labels = oral_tokens[1:]
loss_mask = loss_mask[:-1]
labels = tokens[1:] + [-100]
if len(tokens) > self.max_seq_len:
raise RuntimeError(f"token_ids is too long: {len(tokens)}")

pos_ids = list(range(len(tokens)))

if sum(loss_mask) == 0:
if all(x == -100 for x in labels):
logger.warning(f"[SKIP] all labels set to 0: {example}")
return None

assert len(tokens) == len(loss_mask), f"{len(tokens)}-{len(loss_mask)}"
assert len(tokens) == len(labels), f"{len(tokens)}-{len(labels)}"

enable_dataset_debug = os.getenv("FLAGS_enable_dataset_debug", "false").lower() in ("true", "1", "t")
@@ -470,12 +469,10 @@ def _postprocess_sequence(self, example, actual_example_num):
print_debug_info(self.tokenizer, tokens, "input")
print("========================================\n")

filtered_labels = [label if mask == 1 else -100 for label, mask in zip(labels, loss_mask)]
filtered_labels = [x for x in filtered_labels if x != -100] # remove -100
filtered_labels = [x for x in labels if x != -100] # remove -100
print("========================================")
print_debug_info(self.tokenizer, filtered_labels, "labels")
print("========================================\n")
logger.info(f"[dataset debug] loss mask: {loss_mask}")
else:
logger.info("[dataset debug] Tokenizer not available")
logger.info("=" * 50 + "\n")
@@ -484,7 +481,6 @@ def _postprocess_sequence(self, example, actual_example_num):
token_ids=tokens,
position_ids=pos_ids,
labels=labels,
loss_mask=loss_mask,
num_examples=actual_example_num,
images=images,
videos=videos,
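The core of this file's change: the separate `loss_mask` is dropped and prompt positions are written straight into `labels` as `-100`, which the loss ignores. A condensed, hedged sketch of the per-turn construction (simplified: the real method also handles truncation, `end_of_response_id` replacement, and the final next-token shift of the labels):

```python
# Simplified sketch of building labels with -100 instead of a loss_mask.
IGNORE_INDEX = -100


def encode_turn(tokens_src, tokens_target, eos_token_id,
                efficient_eos=False, first_turn=True):
    """Return (tokens, labels) aligned 1:1, before the final next-token shift."""
    if efficient_eos and not first_turn:
        # Later turns carry an EOS label on their first prompt token, so that
        # after the shift the previous response learns to end with EOS.
        labels_src = [eos_token_id] + [IGNORE_INDEX] * (len(tokens_src) - 1)
    else:
        labels_src = [IGNORE_INDEX] * len(tokens_src)
    tokens = tokens_src + tokens_target
    labels = labels_src + tokens_target
    assert len(tokens) == len(labels)
    return tokens, labels


tokens, labels = encode_turn([11, 12, 13], [21, 22], eos_token_id=2)
print(tokens)  # [11, 12, 13, 21, 22]
print(labels)  # [-100, -100, -100, 21, 22]
```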
4 changes: 0 additions & 4 deletions paddleformers/datasets/collate.py
@@ -165,7 +165,6 @@ def collate_fn(batch: List[List[Sequence]], tokenizer, training_args, model_args
dict: Dictionary containing:
- input_ids: Padded token IDs
- labels: Shifted labels for prediction
- loss_mask: Mask for computing loss
"""
input_keys = ["input_ids", "labels", "position_ids"]
if training_args.num_nextn_predict_layers > 0:
@@ -180,15 +179,12 @@
for batch_sequence in batch:
original_token_ids = [seq.token_ids for seq in batch_sequence]
token_ids = [sum(original_token_ids, [])]
loss_mask = [sum([seq.loss_mask for seq in batch_sequence], [])]
labels = [sum([seq.labels for seq in batch_sequence], [])]
position_ids = [sum([seq.position_ids for seq in batch_sequence], [])]
# padding
padded_token_ids = pad_batch_data(token_ids, pad_idx=tokenizer.pad_token_id, max_seq_len=max_seq_len)
padded_labels = pad_batch_data(labels, pad_idx=tokenizer.pad_token_id, max_seq_len=max_seq_len)
padded_loss_mask = pad_batch_data(loss_mask, pad_idx=0, max_seq_len=max_seq_len)
padded_position_ids = pad_batch_data(position_ids, pad_idx=0, max_seq_len=max_seq_len)
padded_labels = np.where(padded_loss_mask == 1, padded_labels, -100)
return_list.append(
[
padded_token_ids,
2 changes: 1 addition & 1 deletion paddleformers/datasets/data_utils.py
@@ -285,7 +285,7 @@ def estimate_training(train_dataset, data_args, training_args, model_args):
"sharding_parallel_degree": int(training_args.sharding_parallel_degree),
"num_samples_each_epoch": data_args.num_samples_each_epoch,
"max_seq_len": int(data_args.max_seq_len),
"seed": data_args.seed,
"seed": training_args.seed,
"valid": False,
"train_samples": 0,
}
17 changes: 17 additions & 0 deletions paddleformers/datasets/reader/file_reader.py
@@ -103,9 +103,25 @@ def data_check(self, data):
if len(data["messages"]) == 0:
raise ValueError("Ignore example with empty messages.")

for index in range(len(data["messages"])):
# Fix the role names for tool call and tool response
if data["messages"][index]["role"] == "tool" or data["messages"][index]["role"] == "tool_response":
data["messages"][index]["role"] = "observation"
if data["messages"][index]["role"] == "tool_call" or data["messages"][index]["role"] == "tool_calls":
data["messages"][index]["role"] = "function"
# Convert the content of tool call and tool response into a string
if (
data["messages"][index]["role"] == "observation" or data["messages"][index]["role"] == "function"
) and not isinstance(data["messages"][index]["content"], str):
data["messages"][index]["content"] = json.dumps(data["messages"][index]["content"])
if "tool_calls" in data["messages"][index] and not isinstance(data["messages"][index]["tool_calls"], str):
data["messages"][index]["tool_calls"] = json.dumps(data["messages"][index]["tool_calls"])

# Convert the content of tool list into a string
if "tools" in data and not isinstance(data["tools"], str):
data["tools"] = json.dumps(data["tools"], ensure_ascii=False)

# If no label is input, it means each response needs to be learned.
if "label" not in data:
data["label"] = [
1 for turn in data["messages"] if ("assistant" in turn["role"] or "function" in turn["role"])
@@ -116,6 +132,7 @@ def data_check(self, data):
system = data["messages"][0]["content"]
if not isinstance(system, str):
raise ValueError("System field must be a string.")
data["messages"] = data["messages"][1:]
data["system"] = system

# Convert the relative paths of multimode data into absolute paths
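`data_check` now normalizes tool-related roles and serializes structured tool payloads before templating. A standalone sketch of that normalization (hedged: simplified to a single message; the real check also validates system messages, labels, and media paths):

```python
# Standalone sketch of the role normalization and JSON serialization added above.
import json

ROLE_MAP = {
    "tool": "observation", "tool_response": "observation",
    "tool_call": "function", "tool_calls": "function",
}


def normalize_message(msg: dict) -> dict:
    msg = dict(msg)
    msg["role"] = ROLE_MAP.get(msg["role"], msg["role"])
    if msg["role"] in ("observation", "function") and not isinstance(msg.get("content"), str):
        msg["content"] = json.dumps(msg["content"])
    if "tool_calls" in msg and not isinstance(msg["tool_calls"], str):
        msg["tool_calls"] = json.dumps(msg["tool_calls"])
    return msg


print(normalize_message({"role": "tool", "content": {"temperature": 21}}))
# {'role': 'observation', 'content': '{"temperature": 21}'}
```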