From b7779b6276af9da0cd5e8043a19e1e5d14bc3444 Mon Sep 17 00:00:00 2001
From: zzc <1378113190@qq.com>
Date: Tue, 31 Dec 2024 14:26:20 +0800
Subject: [PATCH] refactor(data): rework the masking scheme; sharegpt supports finer-grained mask control
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 src/llamafactory/data/aligner.py   | 27 ++++++++++++++-----
 src/llamafactory/data/parser.py    |  1 +
 .../data/processors/supervised.py  | 24 +++++------------
 3 files changed, 27 insertions(+), 25 deletions(-)

diff --git a/src/llamafactory/data/aligner.py b/src/llamafactory/data/aligner.py
index 82bbfafb2d..68bde401fb 100644
--- a/src/llamafactory/data/aligner.py
+++ b/src/llamafactory/data/aligner.py
@@ -89,8 +89,8 @@ def convert_alpaca(
     prompt = []
     if dataset_attr.history and isinstance(example[dataset_attr.history], list):
         for old_prompt, old_response in example[dataset_attr.history]:
-            prompt.append({"role": Role.USER.value, "content": old_prompt})
-            prompt.append({"role": Role.ASSISTANT.value, "content": old_response})
+            prompt.append({"role": Role.USER.value, "content": old_prompt, "train": data_args.train_on_prompt})
+            prompt.append({"role": Role.ASSISTANT.value, "content": old_response, "train": not data_args.mask_history})
 
     query = []
     if dataset_attr.prompt and example[dataset_attr.prompt]:
@@ -99,7 +99,7 @@
     if dataset_attr.query and example[dataset_attr.query]:
         query.append(example[dataset_attr.query])
 
-    prompt.append({"role": Role.USER.value, "content": "\n".join(query)})  # "prompt\nquery"
+    prompt.append({"role": Role.USER.value, "content": "\n".join(query), "train": data_args.train_on_prompt})  # "prompt\nquery"
 
     if dataset_attr.kto_tag and isinstance(example[dataset_attr.kto_tag], bool):  # kto example
         response = [{"role": Role.ASSISTANT.value, "content": example[dataset_attr.response]}]
@@ -117,7 +117,7 @@
             {"role": Role.ASSISTANT.value, "content": example[dataset_attr.rejected]},
         ]
     elif dataset_attr.response and isinstance(example[dataset_attr.response], str):  # normal example
-        response = [{"role": Role.ASSISTANT.value, "content": example[dataset_attr.response]}]
+        response = [{"role": Role.ASSISTANT.value, "content": example[dataset_attr.response], "train": True}]
     else:  # unsupervised
         response = []
 
@@ -170,9 +170,22 @@ def convert_sharegpt(
             logger.warning_rank0(f"Invalid role tag in {messages}.")
             broken_data = True
 
-        aligned_messages.append(
-            {"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]}
-        )
+        if message[dataset_attr.role_tag] in even_tags:
+            if data_args.mask_history and turn_idx != len(messages) - 1:
+                train = False
+            else:
+                train = message.get(dataset_attr.train_tag, True)
+        else:
+            if data_args.train_on_prompt:
+                train = True
+            else:
+                train = message.get(dataset_attr.train_tag, False)
+
+        aligned_messages.append({
+            "role": tag_mapping[message[dataset_attr.role_tag]],
+            "content": message[dataset_attr.content_tag],
+            "train": train,
+        })
 
     if (not dataset_attr.ranking and len(aligned_messages) % 2 != 0) or (
         dataset_attr.ranking and len(aligned_messages) % 2 == 0
diff --git a/src/llamafactory/data/parser.py b/src/llamafactory/data/parser.py
index 709d0c900c..a9bf90d2b6 100644
--- a/src/llamafactory/data/parser.py
+++ b/src/llamafactory/data/parser.py
@@ -63,6 +63,7 @@ class DatasetAttr:
     observation_tag: Optional[str] = "observation"
     function_tag: Optional[str] = "function_call"
     system_tag: Optional[str] = "system"
+    train_tag: Optional[str] = "train"
 
     def __repr__(self) -> str:
         return self.dataset_name
diff --git a/src/llamafactory/data/processors/supervised.py b/src/llamafactory/data/processors/supervised.py
index 83bd8ba2a7..83f3fe1d79 100644
--- a/src/llamafactory/data/processors/supervised.py
+++ b/src/llamafactory/data/processors/supervised.py
@@ -42,15 +42,11 @@ def _encode_supervised_example(
     tokenizer: "PreTrainedTokenizer",
     processor: Optional["ProcessorMixin"],
     cutoff_len: int,
-    train_on_prompt: bool,
-    mask_history: bool,
 ) -> Tuple[List[int], List[int]]:
     messages = template.mm_plugin.process_messages(prompt + response, images, videos, processor)
     input_ids, labels = template.mm_plugin.process_token_ids([], [], images, videos, tokenizer, processor)
     encoded_pairs = template.encode_multiturn(tokenizer, messages, system, tools)
     total_length = len(input_ids) + (1 if template.efficient_eos else 0)
-    if mask_history:
-        encoded_pairs = encoded_pairs[::-1]  # high priority for last turns
 
     for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs):
         if total_length >= cutoff_len:
@@ -61,24 +57,20 @@
         target_ids = target_ids[:target_len]
         total_length += source_len + target_len
 
-        if train_on_prompt:
+        if messages[turn_idx * 2]["train"]:
             source_label = source_ids
         elif template.efficient_eos:
             source_label = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (source_len - 1)
         else:
             source_label = [IGNORE_INDEX] * source_len
 
-        if mask_history and turn_idx != 0:  # train on the last turn only
-            target_label = [IGNORE_INDEX] * target_len
-        else:
+        if messages[turn_idx * 2 + 1]["train"]:
             target_label = target_ids
-
-        if mask_history:  # reversed sequences
-            input_ids = source_ids + target_ids + input_ids
-            labels = source_label + target_label + labels
         else:
-            input_ids += source_ids + target_ids
-            labels += source_label + target_label
+            target_label = [IGNORE_INDEX] * target_len
+
+        input_ids += source_ids + target_ids
+        labels += source_label + target_label
 
     if template.efficient_eos:
         input_ids += [tokenizer.eos_token_id]
@@ -115,8 +107,6 @@ def preprocess_supervised_dataset(
             tokenizer=tokenizer,
             processor=processor,
             cutoff_len=data_args.cutoff_len,
-            train_on_prompt=data_args.train_on_prompt,
-            mask_history=data_args.mask_history,
         )
         model_inputs["input_ids"].append(input_ids)
         model_inputs["attention_mask"].append([1] * len(input_ids))
@@ -159,8 +149,6 @@ def preprocess_packed_supervised_dataset(
            tokenizer=tokenizer,
            processor=processor,
            cutoff_len=data_args.cutoff_len - 1,  # reserved for the padding token
-            train_on_prompt=data_args.train_on_prompt,
-            mask_history=data_args.mask_history,
        )
        length = len(input_ids)
        if length > data_args.cutoff_len:
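
With this patch, every aligned message carries a per-message "train" flag, and the sharegpt converter reads it from the column named by DatasetAttr.train_tag (default "train"). The record below is an illustrative sketch, not from the patch itself: the flag values are hypothetical, and the remaining tag names ("conversations", "from", "value") are assumed to be the stock sharegpt defaults. Omitted flags fall back to the converter's defaults: user turns are untrained unless train_on_prompt is set, assistant turns are trained unless mask_history masks the non-final ones.

# Illustrative sharegpt-style record with per-message "train" flags (hypothetical values).
example = {
    "conversations": [
        {"from": "human", "value": "What is 2 + 2?", "train": False},  # user turn, masked
        {"from": "gpt", "value": "2 + 2 = 4.", "train": False},        # earlier answer, explicitly excluded from the loss
        {"from": "human", "value": "And 3 + 3?"},                      # no flag: falls back to train_on_prompt
        {"from": "gpt", "value": "3 + 3 = 6."},                        # no flag: assistant default is True, so it is trained
    ]
}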
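On the encoding side, _encode_supervised_example now consults messages[2 * turn_idx]["train"] for the prompt half of each turn and messages[2 * turn_idx + 1]["train"] for the response half, replacing the removed train_on_prompt and mask_history arguments. The snippet below is a minimal, self-contained sketch of that labeling rule only; cutoff handling, efficient_eos, and the multimodal plugin are omitted, and build_labels with its inputs are hypothetical names rather than LLaMA-Factory APIs.

from typing import List, Tuple

IGNORE_INDEX = -100  # label value ignored by the loss, same sentinel the patch uses


def build_labels(
    encoded_pairs: List[Tuple[List[int], List[int]]],  # (source_ids, target_ids) per turn
    train_flags: List[bool],  # one flag per message, flattened as [user0, assistant0, user1, ...]
) -> Tuple[List[int], List[int]]:
    """Keep tokens in the labels only where the corresponding message has train=True."""
    input_ids: List[int] = []
    labels: List[int] = []
    for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs):
        source_label = source_ids if train_flags[2 * turn_idx] else [IGNORE_INDEX] * len(source_ids)
        target_label = target_ids if train_flags[2 * turn_idx + 1] else [IGNORE_INDEX] * len(target_ids)
        input_ids += source_ids + target_ids
        labels += source_label + target_label
    return input_ids, labels


if __name__ == "__main__":
    pairs = [([1, 2], [3]), ([4, 5], [6])]  # two toy turns
    flags = [False, False, False, True]     # train only on the final assistant reply
    print(build_labels(pairs, flags))
    # -> ([1, 2, 3, 4, 5, 6], [-100, -100, -100, -100, -100, 6])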