Skip to content

Commit

Permalink
Merge pull request #340 from koxudaxi/add_toolkit_for_new_interface
Browse files Browse the repository at this point in the history
Add ToolKit for new interface
  • Loading branch information
willbakst committed Jun 21, 2024
2 parents f036fcc + dfa318e commit 8b51c53
Show file tree
Hide file tree
Showing 8 changed files with 421 additions and 30 deletions.
1 change: 1 addition & 0 deletions examples/logging/logging_to_csv.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Logging your LLM responses to a CSV file"""

import os

import pandas as pd
Expand Down
1 change: 1 addition & 0 deletions examples/tool_calls/tool_calls_with_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
You can add examples to your tool definitions to help the model better use the tool.
Examples can be added for individual fields as well as for the entire model.
"""

import os

from pydantic import ConfigDict, Field
Expand Down
6 changes: 5 additions & 1 deletion mirascope/core/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from .structured_stream import BaseStructuredStream
from .structured_stream_async import BaseAsyncStructuredStream
from .tool import BaseTool
from .toolkit import BaseToolKit, toolkit_tool


__all__ = [
"BaseAsyncStream",
Expand All @@ -25,7 +27,9 @@
"BaseStream",
"BaseStructuredStream",
"BaseTool",
"_partial",
"BaseToolKit",
"tags",
"toolkit_tool",
"_partial",
"_utils",
]
95 changes: 71 additions & 24 deletions mirascope/core/base/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from abc import update_abstractmethods
from enum import Enum
from string import Formatter
from textwrap import dedent
from typing import (
Annotated,
Any,
Expand Down Expand Up @@ -34,17 +33,34 @@
"""


def format_prompt_template(template: str, attrs: dict[str, Any]) -> str:
"""Formats the given prompt `template`"""
dedented_template = dedent(template).strip()
template_vars = [
var for _, var, _, _ in Formatter().parse(dedented_template) if var is not None
]
def get_template_variables(template: str) -> list[str]:
"""Returns the variables in the given template string.
Args:
template: The template string to parse.
Returns:
The variables in the template string.
"""
return [var for _, var, _, _ in Formatter().parse(template) if var is not None]


def get_template_values(
template_variables: list[str], attrs: dict[str, Any]
) -> dict[str, Any]:
"""Returns the values of the given `template_variables` from the provided `attrs`.
Args:
template_variables: The variables to extract from the `attrs`.
attrs: The attributes to extract the variables from.
Returns:
The values of the template variables.
"""
values = {}
if "self" in attrs:
values["self"] = attrs.get("self")
for var in template_vars:
for var in template_variables:
if var.startswith("self"):
continue
attr = attrs[var]
Expand All @@ -59,6 +75,24 @@ def format_prompt_template(template: str, attrs: dict[str, Any]) -> str:
values[var] = "\n".join([str(item) for item in attr])
else:
values[var] = str(attr)
return values


def format_template(template: str, attrs: dict[str, Any]) -> str:
"""Formats the given prompt `template`
Args:
template: The template to format.
attrs: The attributes to use for formatting.
Returns:
The formatted template.
"""
dedented_template = inspect.cleandoc(template).strip()
template_vars = get_template_variables(dedented_template)

values = get_template_values(template_vars, attrs)

return dedented_template.format(**values)

Expand Down Expand Up @@ -96,14 +130,14 @@ def parse_prompt_messages(
)
messages += attr
else:
content = format_prompt_template(match.group(2), attrs)
content = format_template(match.group(2), attrs)
if content:
messages.append({"role": role, "content": content})
if len(messages) == 0:
messages.append(
{
"role": "user",
"content": format_prompt_template(template, attrs),
"content": format_template(template, attrs),
}
)
return messages
Expand All @@ -115,7 +149,10 @@ def parse_prompt_messages(


def convert_function_to_base_tool(
fn: Callable, base: type[BaseToolT]
fn: Callable,
base: type[BaseToolT],
__doc__: str | None = None,
__namespace__: str | None = None,
) -> type[BaseToolT]:
"""Constructst a `BaseToolT` type from the given function.
Expand All @@ -126,6 +163,8 @@ def convert_function_to_base_tool(
Args:
fn: The function to convert.
base: The `BaseToolT` type to which the function is converted.
__doc__: The docstring to use for the constructed `BaseToolT` type.
__namespace__: The namespace to use for the constructed `BaseToolT` type.
Returns:
The constructed `BaseToolT` type.
Expand All @@ -138,13 +177,18 @@ def convert_function_to_base_tool(
doesn't have a docstring description.
"""
docstring = None
if fn.__doc__:
docstring = parse(fn.__doc__)
func_doc = __doc__ or fn.__doc__
if func_doc:
docstring = parse(func_doc)

field_definitions = {}
hints = get_type_hints(fn)
has_self = False
for i, parameter in enumerate(inspect.signature(fn).parameters.values()):
if parameter.name == "self" or parameter.name == "cls":
if parameter.name == "self":
has_self = True
continue
if parameter.name == "cls":
continue
if parameter.annotation == inspect.Parameter.empty:
raise ValueError("All parameters must have a type annotation.")
Expand Down Expand Up @@ -181,22 +225,25 @@ def convert_function_to_base_tool(
)

model = create_model(
fn.__name__,
f"{__namespace__}.{fn.__name__}" if __namespace__ else fn.__name__,
__base__=base,
__doc__=inspect.cleandoc(fn.__doc__) if fn.__doc__ else DEFAULT_TOOL_DOCSTRING,
__doc__=inspect.cleandoc(func_doc) if func_doc else DEFAULT_TOOL_DOCSTRING,
**cast(dict[str, Any], field_definitions),
)

def call(self: base):
return fn(
**{
str(
self.model_fields[field_name].alias
if self.model_fields[field_name].alias
else field_name
): getattr(self, field_name)
for field_name in self.model_dump(exclude={"tool_call"})
}
**(
({"self": self} if has_self else {})
| {
str(
self.model_fields[field_name].alias
if self.model_fields[field_name].alias
else field_name
): getattr(self, field_name)
for field_name in self.model_dump(exclude={"tool_call"})
}
)
)

setattr(model, "call", call)
Expand Down
2 changes: 1 addition & 1 deletion mirascope/core/base/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class BookRecommendationPrompt(BasePrompt):

def __str__(self) -> str:
"""Returns the formatted template."""
return _utils.format_prompt_template(self.prompt_template, self.model_dump())
return _utils.format_template(self.prompt_template, self.model_dump())

def message_params(self) -> list[BaseMessageParam]:
"""Returns the template as a formatted list of `Message` instances."""
Expand Down
133 changes: 133 additions & 0 deletions mirascope/core/base/toolkit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
"""The module for defining the toolkit class for LLM call tools."""

from __future__ import annotations

import inspect
from abc import ABC
from typing import Callable, ClassVar, NamedTuple

from pydantic import BaseModel, ConfigDict
from typing_extensions import ParamSpec, Concatenate

from . import BaseTool
from ._utils import convert_function_to_base_tool, get_template_variables

_TOOLKIT_TOOL_METHOD_MARKER: str = "__toolkit_tool_method__"

_namespaces: set[str] = set()

P = ParamSpec("P")


def toolkit_tool(
method: Callable[Concatenate[BaseToolKit, P], str],
) -> Callable[Concatenate[BaseToolKit, P], str]:
# Mark the method as a toolkit tool
setattr(method, _TOOLKIT_TOOL_METHOD_MARKER, True)

return method


class ToolKitToolMethod(NamedTuple):
method: Callable[..., str]
template_vars: list[str]
template: str


class BaseToolKit(BaseModel, ABC):
"""A class for defining tools for LLM call tools.
The class should have methods decorated with `@toolkit_tool` to create tools.
Example:
```python
from mirascope.core.base import BaseToolKit, toolkit_tool
from mirascope.core.openai import openai_call
class BookRecommendationToolKit(BaseToolKit):
'''A toolkit for recommending books.'''
__namespace__: ClassVar[str | None] = 'book_tools'
reading_level: Literal["beginner", "advanced"]
@toolkit_tool
def format_book(self, title: str, author: str) -> str:
'''Returns the title and author of a book nicely formatted.
Reading level: {self.reading_level}
'''
return f"{title} by {author}"
toolkit = BookRecommendationToolKit(reading_level="beginner")
tools = toolkit.create_tools()
@openai_call(model="gpt-4o")
def recommend_book(genre: str, reading_level: Literal["beginner", "advanced"]):
'''Recommend a {genre} book.'''
toolkit = BookRecommendationToolKit(reading_level=reading_level)
return {"tools": [toolkit.create_tools()]}
response = recommend_book("fantasy", "beginner")
if tool := response.tool:
output = tool.call()
print(output)
#> The Name of the Wind by Patrick Rothfuss
else:
print(response.content)
#> Sure! I would recommend...
```
"""

model_config = ConfigDict(arbitrary_types_allowed=True)
_toolkit_tool_methods: ClassVar[list[ToolKitToolMethod]]
__namespace__: ClassVar[str | None] = None

def create_tools(self) -> list[type[BaseTool]]:
"""The method to create the tools."""
return [
convert_function_to_base_tool(
method, BaseTool, template.format(self=self), self.__namespace__
)
for method, template_vars, template in self._toolkit_tool_methods
]

@classmethod
def __pydantic_init_subclass__(cls, **kwargs):
# validate the namespace
if cls.__namespace__:
if cls.__namespace__ in _namespaces:
raise ValueError(f"The namespace {cls.__namespace__} is already used")
_namespaces.add(cls.__namespace__)

cls._toolkit_tool_methods = []
for attr in cls.__dict__.values():
if not getattr(attr, _TOOLKIT_TOOL_METHOD_MARKER, False):
continue
# Validate the toolkit_tool_method
if (template := attr.__doc__) is None:
raise ValueError("The toolkit_tool method must have a docstring")

dedented_template = inspect.cleandoc(template)
template_vars = get_template_variables(dedented_template)

for var in template_vars:
if not var.startswith("self."):
raise ValueError(
"The toolkit_tool method must use self. prefix in template variables "
"when creating tools dynamically"
)

self_var = var[5:]

# Expecting pydantic model fields or class attribute and property
if self_var in cls.model_fields or hasattr(cls, self_var):
continue
raise ValueError(
f"The toolkit_tool method template variable {var} is not found in the class"
)

cls._toolkit_tool_methods.append(
ToolKitToolMethod(attr, template_vars, dedented_template)
)
if not cls._toolkit_tool_methods:
raise ValueError("No toolkit_tool method found")
Loading

0 comments on commit 8b51c53

Please sign in to comment.