diff --git a/paddleformers/datasets/template/template.py b/paddleformers/datasets/template/template.py index 4051cc0827..8edd296456 100644 --- a/paddleformers/datasets/template/template.py +++ b/paddleformers/datasets/template/template.py @@ -135,8 +135,19 @@ def _encode( """ system = system or self.default_system encoded_messages = [] + elements = [] + last_mask = True for i, message in enumerate(messages): - elements = [] + if message["role"] == Role.USER or message["role"] == Role.OBSERVATION: + mask = True + elif message["role"] == Role.ASSISTANT or message["role"] == Role.FUNCTION: + mask = False + else: + raise NotImplementedError("Unexpected role: {}".format(message["role"])) + + if mask != last_mask: + encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements)) + elements = [] if i == 0: elements += self.format_prefix.apply() @@ -159,7 +170,9 @@ def _encode( else: raise NotImplementedError("Unexpected role: {}".format(message["role"])) - encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements)) + last_mask = mask + + encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements)) return encoded_messages