Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ToolKit for new interface #340

Merged
merged 37 commits into from
Jun 21, 2024
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
9184ed0
feat: Add ToolKit for new interface
koxudaxi Jun 19, 2024
f586fbf
Merge branch 'v1' into add_toolkit_for_new_interface
koxudaxi Jun 20, 2024
83ce746
Update BaseTool import
koxudaxi Jun 20, 2024
cff7f84
Merge branch 'v1' into add_toolkit_for_new_interface
koxudaxi Jun 20, 2024
d2afc1d
implement create_tool
koxudaxi Jun 20, 2024
1c8d0cb
Fix the order in `__all__`
koxudaxi Jun 20, 2024
30cdb19
fix toolkit_tool decorator
koxudaxi Jun 20, 2024
aee68bc
Add unittest
koxudaxi Jun 20, 2024
32f592b
Fix toolkit logic
koxudaxi Jun 20, 2024
ddab430
Update unittest
koxudaxi Jun 20, 2024
4386a4a
Merge branch 'v1' into add_toolkit_for_new_interface
koxudaxi Jun 21, 2024
b6425b5
Fix unittest
koxudaxi Jun 21, 2024
baf5cf7
Merge branch 'v1' into add_toolkit_for_new_interface
koxudaxi Jun 21, 2024
3781a3e
Apply suggestions
koxudaxi Jun 21, 2024
fc95385
Improve error message
koxudaxi Jun 21, 2024
a34fed5
Fix unittest name
koxudaxi Jun 21, 2024
850ed69
Apply list comprehension
koxudaxi Jun 21, 2024
6ff6883
change has_self_or_cls to has_self
koxudaxi Jun 21, 2024
35716fd
Improve typing
koxudaxi Jun 21, 2024
dbc6876
Fix name
koxudaxi Jun 21, 2024
569f687
Small fixes
koxudaxi Jun 21, 2024
85d61f2
Apply format
koxudaxi Jun 21, 2024
f213943
Add namespace
koxudaxi Jun 21, 2024
c5a3eb3
Add testcase
koxudaxi Jun 21, 2024
1c9d3d8
Add testcase
koxudaxi Jun 21, 2024
9892a2b
Add namespace checking
koxudaxi Jun 21, 2024
4bf7b2c
Fix namespace
koxudaxi Jun 21, 2024
fe6f4b3
Fix namespace
koxudaxi Jun 21, 2024
62e6845
Format unittest
koxudaxi Jun 21, 2024
c6b1508
Merge branch 'v1' into add_toolkit_for_new_interface
koxudaxi Jun 21, 2024
275d397
Fix unittest
koxudaxi Jun 21, 2024
51420b8
Change _namespace to __namespace__
koxudaxi Jun 21, 2024
7cf391e
Format
koxudaxi Jun 21, 2024
1445caa
Apply format
koxudaxi Jun 21, 2024
469315a
Add docstring
koxudaxi Jun 21, 2024
95e7bd4
Fix __namespace__ argument
koxudaxi Jun 21, 2024
dfa318e
Improve docstring
koxudaxi Jun 21, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions mirascope/core/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@
from .streams import BaseAsyncStream, BaseStream
from .structured_streams import BaseAsyncStructuredStream, BaseStructuredStream
from .tools import BaseTool
from .toolkit import BaseToolKit, toolkit_tool

__all__ = [
"BaseAsyncStream",
"BaseAsyncStructuredStream",
"BaseCallParams",
"BaseCallResponse",
"BaseCallResponseChunk",
Expand All @@ -23,7 +23,9 @@
"BaseStream",
"BaseStructuredStream",
"BaseTool",
"_partial",
"BaseToolKit",
"tags",
"toolkit_tool",
"_partial",
"_utils",
]
61 changes: 39 additions & 22 deletions mirascope/core/base/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from abc import update_abstractmethods
from enum import Enum
from string import Formatter
from textwrap import dedent
from typing import (
Annotated,
Any,
Expand All @@ -32,17 +31,19 @@
"""


def format_prompt_template(template: str, attrs: dict[str, Any]) -> str:
"""Formats the given prompt `template`"""
dedented_template = dedent(template).strip()
template_vars = [
var for _, var, _, _ in Formatter().parse(dedented_template) if var is not None
]
def get_template_variables(template: str) -> list[str]:
"""Returns the variables in the given template string."""
return [var for _, var, _, _ in Formatter().parse(template) if var is not None]


def get_template_values(
template_variables: list[str], attrs: dict[str, Any]
) -> dict[str, Any]:
"""Returns the values of the given `template_variables` from the provided `attrs`."""
values = {}
if "self" in attrs:
values["self"] = attrs.get("self")
for var in template_vars:
for var in template_variables:
if var.startswith("self"):
continue
attr = attrs[var]
Expand All @@ -57,6 +58,15 @@ def format_prompt_template(template: str, attrs: dict[str, Any]) -> str:
values[var] = "\n".join([str(item) for item in attr])
else:
values[var] = str(attr)
return values


def format_template(template: str, attrs: dict[str, Any]) -> str:
willbakst marked this conversation as resolved.
Show resolved Hide resolved
"""Formats the given prompt `template`"""
dedented_template = inspect.cleandoc(template).strip()
template_vars = get_template_variables(dedented_template)

values = get_template_values(template_vars, attrs)

return dedented_template.format(**values)

Expand Down Expand Up @@ -94,14 +104,14 @@ def parse_prompt_messages(
)
messages += attr
else:
content = format_prompt_template(match.group(2), attrs)
content = format_template(match.group(2), attrs)
if content:
messages.append({"role": role, "content": content})
if len(messages) == 0:
messages.append(
{
"role": "user",
"content": format_prompt_template(template, attrs),
"content": format_template(template, attrs),
}
)
return messages
Expand All @@ -113,7 +123,7 @@ def parse_prompt_messages(


def convert_function_to_base_tool(
fn: Callable, base: type[BaseToolT]
fn: Callable, base: type[BaseToolT], __doc__: str | None = None
koxudaxi marked this conversation as resolved.
Show resolved Hide resolved
) -> type[BaseToolT]:
"""Constructst a `BaseToolT` type from the given function.

Expand All @@ -124,6 +134,7 @@ def convert_function_to_base_tool(
Args:
fn: The function to convert.
base: The `BaseToolT` type to which the function is converted.
__doc__: The docstring to use for the constructed `BaseToolT` type.

Returns:
The constructed `BaseToolT` type.
Expand All @@ -136,13 +147,16 @@ def convert_function_to_base_tool(
doesn't have a docstring description.
"""
docstring = None
if fn.__doc__:
docstring = parse(fn.__doc__)
func_doc = __doc__ or fn.__doc__
if func_doc:
docstring = parse(func_doc)

field_definitions = {}
hints = get_type_hints(fn)
has_self_or_cls = False
for i, parameter in enumerate(inspect.signature(fn).parameters.values()):
if parameter.name == "self" or parameter.name == "cls":
has_self_or_cls = True
willbakst marked this conversation as resolved.
Show resolved Hide resolved
continue
if parameter.annotation == inspect.Parameter.empty:
raise ValueError("All parameters must have a type annotation.")
Expand Down Expand Up @@ -181,20 +195,23 @@ def convert_function_to_base_tool(
model = create_model(
fn.__name__,
koxudaxi marked this conversation as resolved.
Show resolved Hide resolved
__base__=base,
__doc__=inspect.cleandoc(fn.__doc__) if fn.__doc__ else DEFAULT_TOOL_DOCSTRING,
__doc__=inspect.cleandoc(func_doc) if func_doc else DEFAULT_TOOL_DOCSTRING,
**cast(dict[str, Any], field_definitions),
)

def call(self: base):
return fn(
**{
str(
self.model_fields[field_name].alias
if self.model_fields[field_name].alias
else field_name
): getattr(self, field_name)
for field_name in self.model_dump(exclude={"tool_call"})
}
**(
({"self": self} if has_self_or_cls else {})
| {
str(
self.model_fields[field_name].alias
if self.model_fields[field_name].alias
else field_name
): getattr(self, field_name)
for field_name in self.model_dump(exclude={"tool_call"})
}
)
)

setattr(model, "call", call)
Expand Down
2 changes: 1 addition & 1 deletion mirascope/core/base/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class BookRecommendationPrompt(BasePrompt):

def __str__(self) -> str:
"""Returns the formatted template."""
return _utils.format_prompt_template(self.prompt_template, self.model_dump())
return _utils.format_template(self.prompt_template, self.model_dump())

def message_params(self) -> list[BaseMessageParam]:
"""Returns the template as a formatted list of `Message` instances."""
Expand Down
81 changes: 81 additions & 0 deletions mirascope/core/base/toolkit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
from __future__ import annotations

import inspect
from abc import ABC
from typing import Callable, ClassVar, ParamSpec, NamedTuple

from pydantic import BaseModel, ConfigDict
from typing_extensions import LiteralString

from . import BaseTool
from ._utils import convert_function_to_base_tool, get_template_variables

_TOOLKIT_TOOL_METHOD_MARKER: LiteralString = "__toolkit_tool_method__"

P = ParamSpec("P")


def toolkit_tool(
method: Callable[[BaseToolKit, ...], str],
) -> Callable[[BaseToolKit, ...], str]:
# Mark the method as a toolkit tool
setattr(method, _TOOLKIT_TOOL_METHOD_MARKER, True)

return method


class TookKitToolMethod(NamedTuple):
willbakst marked this conversation as resolved.
Show resolved Hide resolved
method: Callable[[BaseToolKit, ...], str]
template_vars: list[str]
template: str


class BaseToolKit(BaseModel, ABC):
"""A class for defining tools for LLM call tools."""

model_config = ConfigDict(arbitrary_types_allowed=True)
_toolkit_tool_methods: ClassVar[list[TookKitToolMethod]]

def create_tools(self) -> list[type[BaseTool]]:
"""The method to create the tools."""
tools: list[type[BaseTool]] = []
for method, template_vars, template in self._toolkit_tool_methods:
willbakst marked this conversation as resolved.
Show resolved Hide resolved
formatted_template = template.format(self=self)
tool = convert_function_to_base_tool(method, BaseTool, formatted_template)
tools.append(tool)
return tools

@classmethod
def __pydantic_init_subclass__(cls, **kwargs):
cls._toolkit_tool_methods = []
for attr in cls.__dict__.values():
if not getattr(attr, _TOOLKIT_TOOL_METHOD_MARKER, False):
continue
# Validate the toolkit_tool_method
if (template := attr.__doc__) is None:
raise ValueError("The toolkit_tool method must have a docstring")

dedented_template = inspect.cleandoc(template)
template_vars = get_template_variables(dedented_template)

for var in template_vars:
if not var.startswith("self."):
# Should be supported un-self variables?
raise ValueError(
"The toolkit_tool method must use self. prefix in template variables"
)

self_var = var[5:]

# Expecting pydantic model fields or class attribute and property
if self_var in cls.model_fields or hasattr(cls, self_var):
continue
raise ValueError(
f"The toolkit_tool method template variable {var} is not found in the class"
)

cls._toolkit_tool_methods.append(
TookKitToolMethod(attr, template_vars, dedented_template)
)
if not cls._toolkit_tool_methods:
raise ValueError("No toolkit_tool method found")
35 changes: 35 additions & 0 deletions tests/core/base/test_toolkit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
"""Tests for the `toolkit` module."""
from typing import Literal

from mirascope.core.base import BaseToolKit, toolkit_tool


def test_toolkit() -> None:
"""Tests the `BaseToolKit` class and the `toolkit_tool` decorator."""

class BookRecommendationToolKit(BaseToolKit):
"""A toolkit for recommending books."""

reading_level: Literal["beginner", "advanced"]

@toolkit_tool
def format_book(self, title: str, author: str) -> str:
"""Returns the title and author of a book nicely formatted.

Reading level: {self.reading_level}
"""
return f"{title} by {author}"

toolkit = BookRecommendationToolKit(reading_level="beginner")
tools = toolkit.create_tools()
assert len(tools) == 1
tool = tools[0]
assert tool._name() == "format_book"
koxudaxi marked this conversation as resolved.
Show resolved Hide resolved
assert (
tool._description()
== "Returns the title and author of a book nicely formatted.\n\nReading level: beginner"
)
assert (
tool(title="The Name of the Wind", author="Rothfuss, Patrick").call()
== "The Name of the Wind by Rothfuss, Patrick"
)
7 changes: 3 additions & 4 deletions tests/core/base/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pytest
from pydantic import BaseModel

from mirascope.core.base import BaseTool, _utils
from mirascope.core.base import _utils, BaseTool


def test_format_prompt_template() -> None:
Expand All @@ -32,7 +32,7 @@ def test_format_prompt_template() -> None:
"genres": genres,
"authors_and_books": authors_and_books,
}
formatted_prompt_template = _utils.format_prompt_template(prompt_template, attrs)
formatted_prompt_template = _utils.format_template(prompt_template, attrs)
assert (
formatted_prompt_template
== dedent(
Expand Down Expand Up @@ -114,8 +114,7 @@ def fn(model_name: str = "", self=None, cls=None) -> str:
def test_convert_function_to_base_model_errors() -> None:
"""Tests the various `ValueErro` cases in `convert_function_to_base_model`."""

def empty(param) -> str:
... # pragma: no cover
def empty(param) -> str: ... # pragma: no cover

with pytest.raises(ValueError):
_utils.convert_function_to_base_tool(empty, BaseTool)
Expand Down