From 92b9fa6c86ce775c7464062f23a2b4231ccd4f1e Mon Sep 17 00:00:00 2001 From: Takuya Kato <17094694+Ktakuya332C@users.noreply.github.com> Date: Thu, 28 Nov 2024 09:54:58 +0900 Subject: [PATCH] =?UTF-8?q?choice=5Ftemplate=E3=81=8C=E7=A9=BA=E3=81=AE?= =?UTF-8?q?=E5=A0=B4=E5=90=88=E3=81=AB=E7=84=A1=E8=A6=96=E3=81=99=E3=82=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../multiple_choice_dataset/template_based.py | 7 +++--- .../en_multiple_choice/arc_challenge.jsonnet | 17 +++++++------- .../en_multiple_choice/arc_easy.jsonnet | 17 +++++++------- .../test_template_based.py | 23 +++++++++++++++++-- 4 files changed, 43 insertions(+), 21 deletions(-) diff --git a/flexeval/core/multiple_choice_dataset/template_based.py b/flexeval/core/multiple_choice_dataset/template_based.py index 2122016..66d609a 100644 --- a/flexeval/core/multiple_choice_dataset/template_based.py +++ b/flexeval/core/multiple_choice_dataset/template_based.py @@ -72,13 +72,14 @@ def __getitem__(self, i: int) -> MultipleChoiceInstance: inputs.update({k: v.render(**item) for k, v in self.input_templates.items()}) choices = [t.render(**item) for t in self.choices_templates] - if any(len(c) == 0 for c in choices): - msg = f"choices must be non-empty, but got {choices}" - raise ValueError(msg) + choices = list(filter(lambda x: len(x) > 0, choices)) if self.whitespace_before_choices: choices = [" " + c for c in choices] answer_index = int(self.answer_index_template.render(**item)) + if not (0 <= answer_index and answer_index < len(choices)): + msg = f"at least {answer_idx+1} choices required, but got {choices}" + raise ValueError(msg) return MultipleChoiceInstance( inputs=inputs, diff --git a/flexeval/preset_configs/EvalSetup/en_multiple_choice/arc_challenge.jsonnet b/flexeval/preset_configs/EvalSetup/en_multiple_choice/arc_challenge.jsonnet index d56698c..e83816c 100644 --- a/flexeval/preset_configs/EvalSetup/en_multiple_choice/arc_challenge.jsonnet +++ b/flexeval/preset_configs/EvalSetup/en_multiple_choice/arc_challenge.jsonnet @@ -11,15 +11,16 @@ References: local dataset_base_args = { path: 'allenai/ai2_arc', subset: 'ARC-Challenge', - choices_templates: ['{{ choices.text[0] }}', '{{ choices.text[1] }}', '{{ choices.text[2] }}', '{{ choices.text[3] }}'], - # answerKey is one of A, B, C, D, 1, 2, 3, 4 - answer_index_template: '{% if answerKey == "A" %}0{% elif answerKey == "B" %}1{% elif answerKey == "C" %}2{% elif answerKey == "D" %}3{% else %}{{ answerKey | int - 1 }}{% endif %}', + choices_templates: [ + '{% if choices.text | length > 0 %}{{ choices.text[0] }}{% endif %}', + '{% if choices.text | length > 1 %}{{ choices.text[1] }}{% endif %}', + '{% if choices.text | length > 2 %}{{ choices.text[2] }}{% endif %}', + '{% if choices.text | length > 3 %}{{ choices.text[3] }}{% endif %}', + '{% if choices.text | length > 4 %}{{ choices.text[4] }}{% endif %}', + ], + # answerKey is one of A, B, C, D, E, 1, 2, 3, 4 + answer_index_template: '{% if answerKey == "A" %}0{% elif answerKey == "B" %}1{% elif answerKey == "C" %}2{% elif answerKey == "D" %}3{% elif answerKey == "E" %}3{% else %}{{ answerKey | int - 1 }}{% endif %}', whitespace_before_choices: true, - remove_conditions: { - # Remove questions with 3 or 5 choices because the size of choices_template is fixed to 4. - '{{ choices.text | length }}': '3', - '{{ choices.label | length }}': '5', - }, }; { diff --git a/flexeval/preset_configs/EvalSetup/en_multiple_choice/arc_easy.jsonnet b/flexeval/preset_configs/EvalSetup/en_multiple_choice/arc_easy.jsonnet index b6fe114..9879eef 100644 --- a/flexeval/preset_configs/EvalSetup/en_multiple_choice/arc_easy.jsonnet +++ b/flexeval/preset_configs/EvalSetup/en_multiple_choice/arc_easy.jsonnet @@ -11,15 +11,16 @@ References: local dataset_base_args = { path: 'allenai/ai2_arc', subset: 'ARC-Easy', - choices_templates: ['{{ choices.text[0] }}', '{{ choices.text[1] }}', '{{ choices.text[2] }}', '{{ choices.text[3] }}'], - # answerKey is one of A, B, C, D, 1, 2, 3, 4 - answer_index_template: '{% if answerKey == "A" %}0{% elif answerKey == "B" %}1{% elif answerKey == "C" %}2{% elif answerKey == "D" %}3{% else %}{{ answerKey | int - 1 }}{% endif %}', + choices_templates: [ + '{% if choices.text | length > 0 %}{{ choices.text[0] }}{% endif %}', + '{% if choices.text | length > 1 %}{{ choices.text[1] }}{% endif %}', + '{% if choices.text | length > 2 %}{{ choices.text[2] }}{% endif %}', + '{% if choices.text | length > 3 %}{{ choices.text[3] }}{% endif %}', + '{% if choices.text | length > 4 %}{{ choices.text[4] }}{% endif %}', + ], + # answerKey is one of A, B, C, D, E, 1, 2, 3, 4 + answer_index_template: '{% if answerKey == "A" %}0{% elif answerKey == "B" %}1{% elif answerKey == "C" %}2{% elif answerKey == "D" %}3{% elif answerKey == "E" %}3{% else %}{{ answerKey | int - 1 }}{% endif %}', whitespace_before_choices: true, - remove_conditions: { - # Remove questions with 3 or 5 choices because the size of choices_template is fixed to 4. - '{{ choices.text | length }}': '3', - '{{ choices.label | length }}': '5', - }, }; { diff --git a/tests/core/multiple_choice_dataset/test_template_based.py b/tests/core/multiple_choice_dataset/test_template_based.py index b9e1089..9c14350 100644 --- a/tests/core/multiple_choice_dataset/test_template_based.py +++ b/tests/core/multiple_choice_dataset/test_template_based.py @@ -38,7 +38,11 @@ def test_template_multiple_choice_dataset( dataset = dataset_class( **kwargs, input_templates={"test_additional_input": "additional: {{ question }}"}, - choices_templates=["{{ answers[0] }}"], + choices_templates=[ + "{% if answers | length > 0 %}{{ answers[0] }}{% endif %}", + "{% if answers | length > 1 %}{{ answers[1] }}{% endif %}", + "{% if answers | length > 2 %}{{ answers[2] }}{% endif %}", + ], answer_index_template="0", ) @@ -50,9 +54,24 @@ def test_template_multiple_choice_dataset( "answers": ["Mount Everest", "Everest"], "test_additional_input": "additional: What is the highest mountain in the world.", } - assert item.choices == ["Mount Everest"] + assert item.choices == ["Mount Everest", "Everest"] assert item.answer_index == 0 + item = dataset[1] + assert item.inputs == { + "id": 1, + "question": "What is the chemical symbol for water?", + "answers": ["H2O"], + "test_additional_input": "additional: What is the chemical symbol for water?", + } + + item = dataset[4] + assert item.inputs == { + "id": 4, + "question": "Who wrote 'Romeo and Juliet'?", + "answers": ["William Shakespeare", "Shakespeare"], + "test_additional_input": "additional: Who wrote 'Romeo and Juliet'?", + } @pytest.mark.parametrize( ("dataset_class", "kwargs"),