Skip to content

Commit

Permalink
choice_templateが空の場合に無視する
Browse files Browse the repository at this point in the history
  • Loading branch information
Ktakuya332C committed Nov 28, 2024
1 parent f7652f7 commit 92b9fa6
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 21 deletions.
7 changes: 4 additions & 3 deletions flexeval/core/multiple_choice_dataset/template_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,14 @@ def __getitem__(self, i: int) -> MultipleChoiceInstance:
inputs.update({k: v.render(**item) for k, v in self.input_templates.items()})

choices = [t.render(**item) for t in self.choices_templates]
if any(len(c) == 0 for c in choices):
msg = f"choices must be non-empty, but got {choices}"
raise ValueError(msg)
choices = list(filter(lambda x: len(x) > 0, choices))
if self.whitespace_before_choices:
choices = [" " + c for c in choices]

answer_index = int(self.answer_index_template.render(**item))
if not (0 <= answer_index and answer_index < len(choices)):
msg = f"at least {answer_index+1} choices required, but got {choices}"
raise ValueError(msg)

return MultipleChoiceInstance(
inputs=inputs,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,16 @@ References:
local dataset_base_args = {
path: 'allenai/ai2_arc',
subset: 'ARC-Challenge',
choices_templates: ['{{ choices.text[0] }}', '{{ choices.text[1] }}', '{{ choices.text[2] }}', '{{ choices.text[3] }}'],
# answerKey is one of A, B, C, D, 1, 2, 3, 4
answer_index_template: '{% if answerKey == "A" %}0{% elif answerKey == "B" %}1{% elif answerKey == "C" %}2{% elif answerKey == "D" %}3{% else %}{{ answerKey | int - 1 }}{% endif %}',
choices_templates: [
'{% if choices.text | length > 0 %}{{ choices.text[0] }}{% endif %}',
'{% if choices.text | length > 1 %}{{ choices.text[1] }}{% endif %}',
'{% if choices.text | length > 2 %}{{ choices.text[2] }}{% endif %}',
'{% if choices.text | length > 3 %}{{ choices.text[3] }}{% endif %}',
'{% if choices.text | length > 4 %}{{ choices.text[4] }}{% endif %}',
],
# answerKey is one of A, B, C, D, E, 1, 2, 3, 4
answer_index_template: '{% if answerKey == "A" %}0{% elif answerKey == "B" %}1{% elif answerKey == "C" %}2{% elif answerKey == "D" %}3{% elif answerKey == "E" %}4{% else %}{{ answerKey | int - 1 }}{% endif %}',
whitespace_before_choices: true,
remove_conditions: {
# Remove questions with 3 or 5 choices because the size of choices_template is fixed to 4.
'{{ choices.text | length }}': '3',
'{{ choices.label | length }}': '5',
},
};

{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,16 @@ References:
local dataset_base_args = {
path: 'allenai/ai2_arc',
subset: 'ARC-Easy',
choices_templates: ['{{ choices.text[0] }}', '{{ choices.text[1] }}', '{{ choices.text[2] }}', '{{ choices.text[3] }}'],
# answerKey is one of A, B, C, D, 1, 2, 3, 4
answer_index_template: '{% if answerKey == "A" %}0{% elif answerKey == "B" %}1{% elif answerKey == "C" %}2{% elif answerKey == "D" %}3{% else %}{{ answerKey | int - 1 }}{% endif %}',
choices_templates: [
'{% if choices.text | length > 0 %}{{ choices.text[0] }}{% endif %}',
'{% if choices.text | length > 1 %}{{ choices.text[1] }}{% endif %}',
'{% if choices.text | length > 2 %}{{ choices.text[2] }}{% endif %}',
'{% if choices.text | length > 3 %}{{ choices.text[3] }}{% endif %}',
'{% if choices.text | length > 4 %}{{ choices.text[4] }}{% endif %}',
],
# answerKey is one of A, B, C, D, E, 1, 2, 3, 4
answer_index_template: '{% if answerKey == "A" %}0{% elif answerKey == "B" %}1{% elif answerKey == "C" %}2{% elif answerKey == "D" %}3{% elif answerKey == "E" %}4{% else %}{{ answerKey | int - 1 }}{% endif %}',
whitespace_before_choices: true,
remove_conditions: {
# Remove questions with 3 or 5 choices because the size of choices_template is fixed to 4.
'{{ choices.text | length }}': '3',
'{{ choices.label | length }}': '5',
},
};

{
Expand Down
23 changes: 21 additions & 2 deletions tests/core/multiple_choice_dataset/test_template_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,11 @@ def test_template_multiple_choice_dataset(
dataset = dataset_class(
**kwargs,
input_templates={"test_additional_input": "additional: {{ question }}"},
choices_templates=["{{ answers[0] }}"],
choices_templates=[
"{% if answers | length > 0 %}{{ answers[0] }}{% endif %}",
"{% if answers | length > 1 %}{{ answers[1] }}{% endif %}",
"{% if answers | length > 2 %}{{ answers[2] }}{% endif %}",
],
answer_index_template="0",
)

Expand All @@ -50,9 +54,24 @@ def test_template_multiple_choice_dataset(
"answers": ["Mount Everest", "Everest"],
"test_additional_input": "additional: What is the highest mountain in the world.",
}
assert item.choices == ["Mount Everest"]
assert item.choices == ["Mount Everest", "Everest"]
assert item.answer_index == 0

item = dataset[1]
assert item.inputs == {
"id": 1,
"question": "What is the chemical symbol for water?",
"answers": ["H2O"],
"test_additional_input": "additional: What is the chemical symbol for water?",
}

item = dataset[4]
assert item.inputs == {
"id": 4,
"question": "Who wrote 'Romeo and Juliet'?",
"answers": ["William Shakespeare", "Shakespeare"],
"test_additional_input": "additional: Who wrote 'Romeo and Juliet'?",
}

@pytest.mark.parametrize(
("dataset_class", "kwargs"),
Expand Down

0 comments on commit 92b9fa6

Please sign in to comment.