Skip to content

Commit

Permalink
feat(pipeline): use generic to specify output type
Browse files Browse the repository at this point in the history
  • Loading branch information
zhudotexe committed Apr 5, 2024
1 parent 36d3a33 commit ccf20d4
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 7 deletions.
2 changes: 1 addition & 1 deletion kani/engines/huggingface/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __init__(
self,
model_id: str,
max_context_size: int,
prompt_pipeline: PromptPipeline = None,
prompt_pipeline: PromptPipeline[str | torch.Tensor] = None,
*,
token=None,
device: str | None = None,
Expand Down
2 changes: 1 addition & 1 deletion kani/engines/llamacpp/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def __init__(
repo_id: str,
filename: str = None,
max_context_size: int = 0,
prompt_pipeline: PromptPipeline = LLAMA2_PIPELINE,
prompt_pipeline: PromptPipeline[str | list[int]] = LLAMA2_PIPELINE,
*,
model_load_kwargs: dict = None,
**hyperparams,
Expand Down
13 changes: 8 additions & 5 deletions kani/prompts/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import operator
import pprint
import time
from typing import overload
from typing import Generic, TypeVar, overload

from kani.models import ChatMessage, ChatRole
from kani.prompts.base import PipelineStep
Expand All @@ -30,8 +30,11 @@
# sad to use a generic typevar here
Self = "PromptPipeline"

# use a generic to specify the return type of the pipeline
T = TypeVar("T")

class PromptPipeline:

class PromptPipeline(Generic[T]):
r"""
This class creates a reproducible pipeline for translating a list of :class:`.ChatMessage` into an engine-specific
format using fluent-style chaining.
Expand Down Expand Up @@ -250,7 +253,7 @@ def conversation_fmt(
# FUNCTION messages (if not specified, defaults to user)
function_prefix: str = None,
function_suffix: str = None,
) -> Self: ...
) -> "PromptPipeline[str]": ...

def conversation_fmt(self, **kwargs):
"""
Expand Down Expand Up @@ -281,14 +284,14 @@ def conversation_fmt(self, **kwargs):
return self

# ==== eval ====
def __call__(self, msgs: list[ChatMessage]):
def __call__(self, msgs: list[ChatMessage]) -> T:
"""
Apply the pipeline to a list of kani messages. The return type will vary based on the steps in the pipeline;
if no steps are defined the return type will be a copy of the input messages.
"""
return self.execute(msgs)

def execute(self, msgs: list[ChatMessage], *, deepcopy=False, for_measurement=False):
def execute(self, msgs: list[ChatMessage], *, deepcopy=False, for_measurement=False) -> T:
"""
Apply the pipeline to a list of kani messages. The return type will vary based on the steps in the pipeline;
if no steps are defined the return type will be a copy of the input messages.
Expand Down

0 comments on commit ccf20d4

Please sign in to comment.