Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ToolKit for new interface #340

Merged
merged 37 commits into from
Jun 21, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
9184ed0
feat: Add ToolKit for new interface
koxudaxi Jun 19, 2024
f586fbf
Merge branch 'v1' into add_toolkit_for_new_interface
koxudaxi Jun 20, 2024
83ce746
Update BaseTool import
koxudaxi Jun 20, 2024
cff7f84
Merge branch 'v1' into add_toolkit_for_new_interface
koxudaxi Jun 20, 2024
d2afc1d
implement create_tool
koxudaxi Jun 20, 2024
1c8d0cb
Fix the order in `__all__`
koxudaxi Jun 20, 2024
30cdb19
fix toolkit_tool decorator
koxudaxi Jun 20, 2024
aee68bc
Add unittest
koxudaxi Jun 20, 2024
32f592b
Fix toolkit logic
koxudaxi Jun 20, 2024
ddab430
Update unittest
koxudaxi Jun 20, 2024
4386a4a
Merge branch 'v1' into add_toolkit_for_new_interface
koxudaxi Jun 21, 2024
b6425b5
Fix unittest
koxudaxi Jun 21, 2024
baf5cf7
Merge branch 'v1' into add_toolkit_for_new_interface
koxudaxi Jun 21, 2024
3781a3e
Apply suggestions
koxudaxi Jun 21, 2024
fc95385
Improve error message
koxudaxi Jun 21, 2024
a34fed5
Fix unittest name
koxudaxi Jun 21, 2024
850ed69
Apply list comprehension
koxudaxi Jun 21, 2024
6ff6883
change has_self_or_cls to has_self
koxudaxi Jun 21, 2024
35716fd
Improve typing
koxudaxi Jun 21, 2024
dbc6876
Fix name
koxudaxi Jun 21, 2024
569f687
Small fixes
koxudaxi Jun 21, 2024
85d61f2
Apply format
koxudaxi Jun 21, 2024
f213943
Add namespace
koxudaxi Jun 21, 2024
c5a3eb3
Add testcase
koxudaxi Jun 21, 2024
1c9d3d8
Add testcase
koxudaxi Jun 21, 2024
9892a2b
Add namespace checking
koxudaxi Jun 21, 2024
4bf7b2c
Fix namespace
koxudaxi Jun 21, 2024
fe6f4b3
Fix namespace
koxudaxi Jun 21, 2024
62e6845
Format unittest
koxudaxi Jun 21, 2024
c6b1508
Merge branch 'v1' into add_toolkit_for_new_interface
koxudaxi Jun 21, 2024
275d397
Fix unittest
koxudaxi Jun 21, 2024
51420b8
Change _namespace to __namespace__
koxudaxi Jun 21, 2024
7cf391e
Format
koxudaxi Jun 21, 2024
1445caa
Apply format
koxudaxi Jun 21, 2024
469315a
Add docstring
koxudaxi Jun 21, 2024
95e7bd4
Fix __namespace__ argument
koxudaxi Jun 21, 2024
dfa318e
Improve docstring
koxudaxi Jun 21, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 18 additions & 7 deletions mirascope/core/_internal/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,17 @@
"""


def format_prompt_template(template: str, attrs: dict[str, Any]) -> str:
"""Formats the given prompt `template`"""
dedented_template = dedent(template).strip()
template_vars = [
var for _, var, _, _ in Formatter().parse(dedented_template) if var is not None
]
def get_template_variables(template: str) -> list[str]:
"""Returns the variables in the given template string."""
return [var for _, var, _, _ in Formatter().parse(template) if var is not None]


def get_template_values(
template_variables: list[str], attrs: dict[str, Any]
) -> dict[str, Any]:
"""Returns the values of the given `template_variables` from the provided `attrs`."""
values = {}
for var in template_vars:
for var in template_variables:
attr = attrs[var]
if isinstance(attr, list):
if len(attr) == 0:
Expand All @@ -49,6 +51,15 @@ def format_prompt_template(template: str, attrs: dict[str, Any]) -> str:
values[var] = "\n".join([str(item) for item in attr])
else:
values[var] = str(attr)
return values


def format_prompt_template(template: str, attrs: dict[str, Any]) -> str:
koxudaxi marked this conversation as resolved.
Show resolved Hide resolved
"""Formats the given prompt `template`"""
dedented_template = dedent(template).strip()
koxudaxi marked this conversation as resolved.
Show resolved Hide resolved
template_vars = get_template_variables(dedented_template)

values = get_template_values(template_vars, attrs)

return dedented_template.format(**values)

Expand Down
11 changes: 10 additions & 1 deletion mirascope/core/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,14 @@
from .prompt import BasePrompt, tags
from .tool import BaseTool
from .types import BaseCallResponse, MessageParam
from .toolkit import BaseToolKit, toolkit_tool

__all__ = ["BaseCallResponse", "BasePrompt", "BaseTool", "MessageParam", "tags"]
__all__ = [
"BaseCallResponse",
"BasePrompt",
"BaseTool",
"MessageParam",
"tags",
"BaseToolKit",
"toolkit_tool",
]
87 changes: 87 additions & 0 deletions mirascope/core/base/toolkit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
from __future__ import annotations

from abc import ABC
from textwrap import dedent
from typing import Callable, Optional, ClassVar, ParamSpec
from functools import wraps

from pydantic import BaseModel, ConfigDict
from typing_extensions import LiteralString

from .tool import BaseTool
from .._internal.utils import get_template_variables

_TOOLKIT_TOOL_METHOD_MARKER: LiteralString = "__toolkit_tool_method__"

P = ParamSpec("P")


def toolkit_tool(
method: Callable[[BaseToolKit, ...], str],
) -> Callable[[BaseToolKit, ...], str]:
# Mark the method as a toolkit tool
setattr(method, _TOOLKIT_TOOL_METHOD_MARKER, True)

# TODO: Validate first argument is self
@wraps(method)
def inner(*args, **kwargs):
return method(*args, **kwargs)

return inner


class BaseToolKit(BaseModel, ABC):
"""A class for defining tools for LLM call tools."""

model_config = ConfigDict(arbitrary_types_allowed=True)
_toolkit_tool_method: ClassVar[Callable[..., str]]
_toolkit_template_vars: ClassVar[list[str]]

def create_tool(self) -> type[BaseTool]:
koxudaxi marked this conversation as resolved.
Show resolved Hide resolved
"""The method to create the tools."""
formated_template = self._toolkit_tool_method.__doc__.format(
**{var: getattr(self, var) for var in self._toolkit_template_vars}
)
# TODO: Generate the tool class with the formatted template
...

def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
toolkit_tool_method: Optional[Callable] = None
for name, value in cls.__dict__.items():
if getattr(value, _TOOLKIT_TOOL_METHOD_MARKER, False):
if toolkit_tool_method:
raise ValueError("Only one toolkit_tool method is allowed")
koxudaxi marked this conversation as resolved.
Show resolved Hide resolved
toolkit_tool_method = value
if not toolkit_tool_method:
raise ValueError("No toolkit_tool method found")

# Validate the toolkit_tool_method
if (template := toolkit_tool_method.__doc__) is None:
raise ValueError("The toolkit_tool method must have a docstring")

dedented_template = dedent(template).strip()
if not (template_vars := get_template_variables(dedented_template)):
raise ValueError("The toolkit_tool method must have template variables")
koxudaxi marked this conversation as resolved.
Show resolved Hide resolved

if dedented_template != template:
koxudaxi marked this conversation as resolved.
Show resolved Hide resolved
toolkit_tool_method.__doc__ = dedented_template

for var in template_vars:
if not var.startswith("self."):
# Should be supported un-self variables?
raise ValueError(
"The toolkit_tool method must use self. prefix in template variables"
koxudaxi marked this conversation as resolved.
Show resolved Hide resolved
)

self_var = var[5:]
# Expecting pydantic model fields or class attribute and property
# TODO: Check attribute type such like callable, property, etc.
if self_var in cls.model_fields_set or hasattr(cls, self_var):
continue
raise ValueError(
f"The toolkit_tool method template variable {var} is not found in the class"
)

cls._toolkit_tool_method = toolkit_tool_method
willbakst marked this conversation as resolved.
Show resolved Hide resolved
cls._toolkit_template_vars = template_vars
11 changes: 11 additions & 0 deletions mirascope/core/openai/openai_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,17 @@ def decorator(fn: Callable[P, R]) -> Callable[P, OpenAICallResponse]:
def call(*args: P.args, **kwargs: P.kwargs) -> OpenAICallResponse:
prompt_template = inspect.getdoc(fn)
assert prompt_template is not None, "The function must have a docstring."

# Try to get the dictionary for tools from the function result
# tools = []
fn_result = fn(*args, **kwargs)
if isinstance(fn_result, dict):
if fn_result_tools := fn_result.get("tools"):
for fn_result_tool in fn_result_tools:
# TODO: Generate tools
# tools.append(generate(fn_result_tool))
pass
koxudaxi marked this conversation as resolved.
Show resolved Hide resolved

attrs = inspect.signature(fn).bind(*args, **kwargs).arguments
messages = utils.parse_prompt_messages(
roles=["system", "user", "assistant"],
Expand Down