Skip to content

Commit 9284cc4

Browse files
committed
Implement FixedFewShotGenerator
1 parent 77d6a07 commit 9284cc4

File tree

3 files changed

+42
-0
lines changed

3 files changed

+42
-0
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
from .balanced import BalancedFewShotGenerator
22
from .base import FewShotGenerator
3+
from .fixed import FixedFewShotGenerator
34
from .rand import RandomFewShotGenerator
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
from __future__ import annotations
2+
3+
from typing import Any
4+
5+
from .base import ChatInstance, FewShotGenerator, GenerationInstance, Instance, MultipleChoiceInstance
6+
7+
8+
class FixedFewShotGenerator(FewShotGenerator):
9+
def __init__(self, instance_class: str, instance_params: list[dict[str, Any]]) -> None:
10+
super().__init__(num_trials_to_avoid_leak=0)
11+
12+
if instance_class == "GenerationInstance":
13+
instance_init = GenerationInstance
14+
elif instance_class == "MultipleChoiceInstance":
15+
instance_init = MultipleChoiceInstance
16+
elif instance_class == "ChatInstance":
17+
instance_init = ChatInstance
18+
else:
19+
msg = f"Unknown instance class: {instance_class}"
20+
raise ValueError(msg)
21+
22+
self.instances = [instance_init(**params) for params in instance_params]
23+
24+
def _sample_instances(self, eval_inputs: list[dict[str, Any]] | dict[str, Any] | None = None) -> list[Instance]:
25+
return self.instances
26+
27+
def __repr__(self) -> str:
28+
return f"{self.__class__.__name__}(instances={self.instances})"
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from flexeval import ChatInstance
2+
from flexeval.core.few_shot_generator.fixed import FixedFewShotGenerator
3+
4+
5+
def test_fixed_fewshot_generator() -> None:
6+
instance = ChatInstance(
7+
messages=[{"role": "user", "content": "Hello!"}, {"role": "assistant", "content": "Hi there!"}]
8+
)
9+
generator = FixedFewShotGenerator(
10+
instance_class="ChatInstance",
11+
instance_params=[{"messages": instance.messages} for _ in range(5)],
12+
)
13+
assert generator() == [instance for _ in range(5)]

0 commit comments

Comments
 (0)