-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #120 from sbintuitions/impl_fixed_fewshot
Implement `FixedFewShotGenerator`
- Loading branch information
Showing
3 changed files
with
42 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
from .balanced import BalancedFewShotGenerator | ||
from .base import FewShotGenerator | ||
from .fixed import FixedFewShotGenerator | ||
from .rand import RandomFewShotGenerator |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
from __future__ import annotations | ||
|
||
from typing import Any | ||
|
||
from .base import ChatInstance, FewShotGenerator, GenerationInstance, Instance, MultipleChoiceInstance | ||
|
||
|
||
class FixedFewShotGenerator(FewShotGenerator): | ||
def __init__(self, instance_class: str, instance_params: list[dict[str, Any]]) -> None: | ||
super().__init__(num_trials_to_avoid_leak=0) | ||
|
||
if instance_class == "GenerationInstance": | ||
instance_init = GenerationInstance | ||
elif instance_class == "MultipleChoiceInstance": | ||
instance_init = MultipleChoiceInstance | ||
elif instance_class == "ChatInstance": | ||
instance_init = ChatInstance | ||
else: | ||
msg = f"Unknown instance class: {instance_class}" | ||
raise ValueError(msg) | ||
|
||
self.instances = [instance_init(**params) for params in instance_params] | ||
|
||
def _sample_instances(self, eval_inputs: list[dict[str, Any]] | dict[str, Any] | None = None) -> list[Instance]: | ||
return self.instances | ||
|
||
def __repr__(self) -> str: | ||
return f"{self.__class__.__name__}(instances={self.instances})" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
from flexeval import ChatInstance | ||
from flexeval.core.few_shot_generator.fixed import FixedFewShotGenerator | ||
|
||
|
||
def test_fixed_fewshot_generator() -> None: | ||
instance = ChatInstance( | ||
messages=[{"role": "user", "content": "Hello!"}, {"role": "assistant", "content": "Hi there!"}] | ||
) | ||
generator = FixedFewShotGenerator( | ||
instance_class="ChatInstance", | ||
instance_params=[{"messages": instance.messages} for _ in range(5)], | ||
) | ||
assert generator() == [instance for _ in range(5)] |