refactor(data): rework the masking scheme; sharegpt supports finer-grained mask control #6498

Open · wants to merge 1 commit into base: main
27 changes: 20 additions & 7 deletions src/llamafactory/data/aligner.py
```diff
@@ -89,8 +89,8 @@ def convert_alpaca(
     prompt = []
     if dataset_attr.history and isinstance(example[dataset_attr.history], list):
         for old_prompt, old_response in example[dataset_attr.history]:
-            prompt.append({"role": Role.USER.value, "content": old_prompt})
-            prompt.append({"role": Role.ASSISTANT.value, "content": old_response})
+            prompt.append({"role": Role.USER.value, "content": old_prompt, "train": data_args.train_on_prompt})
+            prompt.append({"role": Role.ASSISTANT.value, "content": old_response, "train": not data_args.mask_history})

     query = []
     if dataset_attr.prompt and example[dataset_attr.prompt]:
@@ -99,7 +99,7 @@ def convert_alpaca(
     if dataset_attr.query and example[dataset_attr.query]:
         query.append(example[dataset_attr.query])

-    prompt.append({"role": Role.USER.value, "content": "\n".join(query)})  # "prompt\nquery"
+    prompt.append({"role": Role.USER.value, "content": "\n".join(query), "train": data_args.train_on_prompt})  # "prompt\nquery"

     if dataset_attr.kto_tag and isinstance(example[dataset_attr.kto_tag], bool):  # kto example
         response = [{"role": Role.ASSISTANT.value, "content": example[dataset_attr.response]}]
@@ -117,7 +117,7 @@ def convert_alpaca(
             {"role": Role.ASSISTANT.value, "content": example[dataset_attr.rejected]},
         ]
     elif dataset_attr.response and isinstance(example[dataset_attr.response], str):  # normal example
-        response = [{"role": Role.ASSISTANT.value, "content": example[dataset_attr.response]}]
+        response = [{"role": Role.ASSISTANT.value, "content": example[dataset_attr.response], "train": True}]
     else:  # unsupervised
         response = []
```
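In effect, `convert_alpaca` now folds the global `train_on_prompt` and `mask_history` flags into a per-message `train` flag at alignment time. A minimal sketch of the aligned output for one history turn plus the final query, assuming `train_on_prompt=False`, `mask_history=True`, and placeholder contents:

```python
# Hypothetical output of convert_alpaca (contents are placeholders):
# with train_on_prompt=False and mask_history=True, only the final
# response is marked trainable.
prompt = [
    {"role": "user", "content": "old question", "train": False},     # data_args.train_on_prompt
    {"role": "assistant", "content": "old answer", "train": False},  # not data_args.mask_history
    {"role": "user", "content": "prompt\nquery", "train": False},    # data_args.train_on_prompt
]
response = [{"role": "assistant", "content": "final answer", "train": True}]
```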

```diff
@@ -170,9 +170,22 @@ def convert_sharegpt(
             logger.warning_rank0(f"Invalid role tag in {messages}.")
             broken_data = True

-        aligned_messages.append(
-            {"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]}
-        )
+        if message[dataset_attr.role_tag] in even_tags:
+            if data_args.mask_history and turn_idx != len(messages) - 1:
+                train = False
+            else:
+                train = message.get(dataset_attr.train_tag, True)
+        else:
+            if data_args.train_on_prompt:
+                train = True
+            else:
+                train = message.get(dataset_attr.train_tag, False)
+
+        aligned_messages.append({
+            "role": tag_mapping[message[dataset_attr.role_tag]],
+            "content": message[dataset_attr.content_tag],
+            "train": train,
+        })

     if (not dataset_attr.ranking and len(aligned_messages) % 2 != 0) or (
         dataset_attr.ranking and len(aligned_messages) % 2 == 0
```
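For ShareGPT data this provides the finer-grained control the PR title describes: a dataset can mark individual messages as trainable via a per-message column instead of relying only on the global flags. A hypothetical record, assuming the default `from`/`value`/`train` tag names (all remappable through the dataset attributes):

```python
# Hypothetical ShareGPT record; key names are defaults and can be remapped
# via role_tag, content_tag, and the new train_tag.
sample = {
    "conversations": [
        {"from": "human", "value": "What is 2 + 2?"},            # user turn: False unless train_on_prompt
        {"from": "gpt", "value": "2 + 2 = 4.", "train": False},  # explicitly excluded from the loss
        {"from": "human", "value": "And 3 + 3?"},
        {"from": "gpt", "value": "3 + 3 = 6."},                  # assistant turn: defaults to train=True
    ]
}
```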
1 change: 1 addition & 0 deletions src/llamafactory/data/parser.py
```diff
@@ -63,6 +63,7 @@ class DatasetAttr:
     observation_tag: Optional[str] = "observation"
     function_tag: Optional[str] = "function_call"
     system_tag: Optional[str] = "system"
+    train_tag: Optional[str] = "train"

     def __repr__(self) -> str:
         return self.dataset_name
```
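The new `train_tag` presumably gets mapped from the dataset definition like the sibling `*_tag` attributes. A hypothetical `dataset_info.json` entry, shown as a Python dict; placing `train_tag` under `tags` is an assumption based on how `role_tag` and `content_tag` are declared, not something this diff shows:

```python
# Hypothetical dataset_info.json entry (as a Python dict). The "train_tag"
# placement under "tags" is an assumption mirroring the existing tag fields.
dataset_info = {
    "my_sharegpt_data": {
        "file_name": "my_sharegpt_data.json",
        "formatting": "sharegpt",
        "columns": {"messages": "conversations"},
        "tags": {"role_tag": "from", "content_tag": "value", "train_tag": "train"},
    }
}
```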
24 changes: 6 additions & 18 deletions src/llamafactory/data/processors/supervised.py
```diff
@@ -42,15 +42,11 @@ def _encode_supervised_example(
     tokenizer: "PreTrainedTokenizer",
     processor: Optional["ProcessorMixin"],
     cutoff_len: int,
-    train_on_prompt: bool,
-    mask_history: bool,
 ) -> Tuple[List[int], List[int]]:
     messages = template.mm_plugin.process_messages(prompt + response, images, videos, processor)
     input_ids, labels = template.mm_plugin.process_token_ids([], [], images, videos, tokenizer, processor)
     encoded_pairs = template.encode_multiturn(tokenizer, messages, system, tools)
     total_length = len(input_ids) + (1 if template.efficient_eos else 0)
-    if mask_history:
-        encoded_pairs = encoded_pairs[::-1]  # high priority for last turns

     for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs):
         if total_length >= cutoff_len:
@@ -61,24 +57,20 @@ def _encode_supervised_example(
         target_ids = target_ids[:target_len]
         total_length += source_len + target_len

-        if train_on_prompt:
+        if messages[turn_idx * 2]["train"]:
             source_label = source_ids
         elif template.efficient_eos:
             source_label = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (source_len - 1)
         else:
             source_label = [IGNORE_INDEX] * source_len

-        if mask_history and turn_idx != 0:  # train on the last turn only
-            target_label = [IGNORE_INDEX] * target_len
-        else:
+        if messages[turn_idx * 2 + 1]["train"]:
             target_label = target_ids
+        else:
+            target_label = [IGNORE_INDEX] * target_len

-        if mask_history:  # reversed sequences
-            input_ids = source_ids + target_ids + input_ids
-            labels = source_label + target_label + labels
-        else:
-            input_ids += source_ids + target_ids
-            labels += source_label + target_label
+        input_ids += source_ids + target_ids
+        labels += source_label + target_label

     if template.efficient_eos:
         input_ids += [tokenizer.eos_token_id]
```
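With per-message flags, `_encode_supervised_example` no longer needs the reversed-sequence special case for `mask_history`; labels are built in a single forward pass over the turns. A minimal sketch of the effect with toy token ids, omitting the `efficient_eos` and `cutoff_len` handling:

```python
# Toy reproduction of the new labeling rule: IGNORE_INDEX positions are
# excluded from the loss; only turns marked train=True contribute.
IGNORE_INDEX = -100

messages = [
    {"role": "user", "content": "q1", "train": False},
    {"role": "assistant", "content": "a1", "train": False},  # masked history turn
    {"role": "user", "content": "q2", "train": False},
    {"role": "assistant", "content": "a2", "train": True},   # only this turn is learned
]
encoded_pairs = [([1, 2], [3, 4]), ([5, 6], [7, 8])]  # toy (source_ids, target_ids) per turn

input_ids, labels = [], []
for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs):
    source_label = source_ids if messages[turn_idx * 2]["train"] else [IGNORE_INDEX] * len(source_ids)
    target_label = target_ids if messages[turn_idx * 2 + 1]["train"] else [IGNORE_INDEX] * len(target_ids)
    input_ids += source_ids + target_ids
    labels += source_label + target_label

print(input_ids)  # [1, 2, 3, 4, 5, 6, 7, 8]
print(labels)     # [-100, -100, -100, -100, -100, -100, 7, 8]
```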
```diff
@@ -115,8 +107,6 @@ def preprocess_supervised_dataset(
             tokenizer=tokenizer,
             processor=processor,
             cutoff_len=data_args.cutoff_len,
-            train_on_prompt=data_args.train_on_prompt,
-            mask_history=data_args.mask_history,
         )
         model_inputs["input_ids"].append(input_ids)
         model_inputs["attention_mask"].append([1] * len(input_ids))
@@ -159,8 +149,6 @@ def preprocess_packed_supervised_dataset(
             tokenizer=tokenizer,
             processor=processor,
             cutoff_len=data_args.cutoff_len - 1,  # reserved for the padding token
-            train_on_prompt=data_args.train_on_prompt,
-            mask_history=data_args.mask_history,
         )
         length = len(input_ids)
         if length > data_args.cutoff_len:
```