fix(cohere): various prompt bugs
zhudotexe committed Apr 10, 2024
1 parent 8c1693f commit 524081f
Showing 4 changed files with 26 additions and 13 deletions.
kani/ai_function.py (6 changes: 5 additions & 1 deletion)
@@ -88,7 +88,11 @@ def get_params(self) -> list[AIParamSchema]:

# get aiparam and add it to the list
ai_param = get_aiparam(annotation)
- params.append(AIParamSchema(name=name, t=type_hints[name], default=param.default, aiparam=ai_param))
+ params.append(
+     AIParamSchema(
+         name=name, t=type_hints[name], default=param.default, aiparam=ai_param, inspect_param=param
+     )
+ )
return params

def create_json_schema(self) -> dict:
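For readers skimming the diff: the new `inspect_param=param` argument threads the underlying `inspect.Parameter` object through to the schema, which the kani/json_schema.py changes below use to render parameter signatures. A minimal stdlib-only sketch of what that object provides (`get_weather` is an invented example, not kani code):

    import inspect

    def get_weather(location: str, unit: str = "celsius"):
        """Hypothetical tool function, used only to illustrate inspect.Parameter."""

    sig = inspect.signature(get_weather)
    for name, param in sig.parameters.items():
        # str(inspect.Parameter) already renders "name: annotation = default",
        # e.g. "unit: str = 'celsius'" -- this is the object stored as inspect_param.
        print(name, "->", str(param))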
kani/engines/huggingface/cohere.py (2 changes: 2 additions & 0 deletions)
@@ -185,6 +185,7 @@ def _build_prompt_tools(self, messages: list[ChatMessage], functions: list[AIFun
tool_prompt = DEFAULT_TOOL_PROMPT.format(user_functions=function_text)

# wrap the initial system message, if any
+ messages = messages.copy()
if messages and messages[0].role == ChatRole.SYSTEM:
messages[0] = messages[0].copy_with(content=DEFAULT_PREAMBLE + messages[0].text + tool_prompt)
# otherwise add it in
@@ -195,6 +196,7 @@ def _build_prompt_tools(self, messages: list[ChatMessage], functions: list[AIFun

def _build_prompt_rag(self, messages: list[ChatMessage]):
# wrap the initial system message, if any
+ messages = messages.copy()
if messages and messages[0].role == ChatRole.SYSTEM:
messages[0] = messages[0].copy_with(content=DEFAULT_PREAMBLE + messages[0].text)
# otherwise add it in
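The two added `messages = messages.copy()` lines stop the prompt builders from mutating the caller's message list when they rewrite or insert the leading system message. A minimal sketch of the aliasing bug being fixed, using invented stand-in types rather than kani's real ChatMessage:

    from dataclasses import dataclass

    @dataclass
    class Msg:
        role: str
        content: str

    def build_prompt(messages: list[Msg]) -> list[Msg]:
        messages = messages.copy()  # shallow copy: the index assignment below no longer leaks out
        if messages and messages[0].role == "system":
            # replace (not mutate) the first element with a wrapped copy
            messages[0] = Msg("system", "PREAMBLE\n" + messages[0].content)
        else:
            messages.insert(0, Msg("system", "PREAMBLE"))
        return messages

    history = [Msg("system", "Be terse."), Msg("user", "hi")]
    build_prompt(history)
    assert history[0].content == "Be terse."  # caller's history is unchanged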
kani/json_schema.py (8 changes: 6 additions & 2 deletions)
@@ -1,6 +1,6 @@
import inspect
import typing
- from typing import TYPE_CHECKING, Optional
+ from typing import Optional, TYPE_CHECKING

import pydantic

@@ -18,11 +18,12 @@ class AIParamSchema:
This class is only used internally within kani and generally shouldn't be constructed manually.
"""

- def __init__(self, name: str, t: type, default, aiparam: Optional["AIParam"] = None):
+ def __init__(self, name: str, t: type, default, aiparam: Optional["AIParam"], inspect_param: inspect.Parameter):
self.name = name
self.type = t
self.default = default
self.aiparam = aiparam
+ self.inspect_param = inspect_param

@property
def required(self):
@@ -37,6 +38,9 @@ def origin_type(self):
def description(self):
return self.aiparam.desc if self.aiparam is not None else None

+ def __str__(self):
+     return str(self.inspect_param)


class JSONSchemaBuilder(pydantic.json_schema.GenerateJsonSchema):
"""Subclass of the Pydantic JSON schema builder to provide more fine-grained control over titles and refs."""
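With the `inspect.Parameter` stored on the schema, `__str__` can simply delegate to it, so annotations and defaults come out in valid Python syntax. A stripped-down stand-in for `AIParamSchema` (illustration only, not the real class):

    import inspect

    class ParamSchemaSketch:
        """Simplified illustration of the AIParamSchema change, not kani's real class."""

        def __init__(self, name: str, t: type, default, inspect_param: inspect.Parameter):
            self.name = name
            self.type = t
            self.default = default
            self.inspect_param = inspect_param

        def __str__(self):
            # Delegating to inspect.Parameter keeps defaults and annotations consistent
            # (string defaults come out quoted, annotations use their declared form).
            return str(self.inspect_param)

    def search(query: str, limit: int = 5): ...

    p = inspect.signature(search).parameters["limit"]
    print(ParamSchemaSketch("limit", int, 5, p))  # -> limit: int = 5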
kani/prompts/impl/cohere.py (23 changes: 13 additions & 10 deletions)
@@ -1,3 +1,4 @@
+ import inspect
import json

from kani import AIFunction
@@ -95,11 +96,12 @@ def tool_call_formatter(msg: ChatMessage) -> str:
indent=4,
)
return f"{text}Action: ```json\n{tool_calls}\n```"
- else:
-     return ( # is the EOT/SOT token doing weird stuff here?
-         'Action: ```json\n[\n {\n "tool_name": "directly_answer",\n "parameters": {}\n'
-         f" }}\n]\n```<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>{msg.text}"
-     )
+ # else:
+ #     return ( # is the EOT/SOT token doing weird stuff here?
+ #         'Action: ```json\n[\n {\n "tool_name": "directly_answer",\n "parameters": {}\n'
+ #         f" }}\n]\n```<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>{msg.text}"
+ #     )
+ return msg.content


def build_tool_pipeline(
@@ -117,14 +119,16 @@

steps = []

- # format function calls with an Action: prefix; otherwise do a directly_answer call
+ # format function calls with an Action: prefix
if include_function_calls:

def apply_tc_format(msg):
msg.content = tool_call_formatter(msg)
return msg

steps.append(Apply(apply_tc_format, role=ChatRole.ASSISTANT))
+ else:
+     steps.append(Remove(role=ChatRole.ASSISTANT, predicate=lambda msg: msg.content is None))

# keep function results around as SYSTEM messages
if include_function_results:
@@ -174,6 +178,8 @@ def remover(m, is_last):
.merge_consecutive(role=ChatRole.FUNCTION, joiner=function_result_joiner)
# remove all but the last function message
.apply(remover, role=ChatRole.FUNCTION)
+ # remove asst messages with no content (function calls)
+ .remove(role=ChatRole.ASSISTANT, predicate=lambda msg: msg.content is None)
.conversation_fmt(
prefix="<BOS_TOKEN>",
generation_suffix=(
@@ -202,10 +208,7 @@ def function_prompt(f: AIFunction) -> str:
# build params
param_parts = []
for param in params:
default = ""
if param.default:
default = f" = {param.default}"
param_parts.append(f"{param.name}: {param.type}{default}")
param_parts.append(str(param))
params_str = ", ".join(param_parts)

# build docstring
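Two of the changes in this file work together: assistant messages whose content is None (pure function-call turns) are now dropped from the pipeline instead of being rendered as a directly_answer action, and function_prompt now renders each parameter with str(param). The latter also appears to fix a small formatting bug: the old `if param.default:` branch silently dropped falsy defaults such as 0 or "". A rough before/after sketch under that reading of the diff (the `lookup` function is invented for illustration):

    import inspect

    def lookup(query: str, limit: int = 0, unit: str = "celsius"): ...

    for p in inspect.signature(lookup).parameters.values():
        # Roughly the pre-commit formatting: a falsy default such as 0 is dropped,
        # and a string default renders unquoted ("unit: str = celsius").
        old = f"{p.name}: {p.annotation.__name__}"
        if p.default is not p.empty and p.default:
            old += f" = {p.default}"
        # str(inspect.Parameter) keeps "limit: int = 0" and quotes 'celsius'.
        print(f"{old:<24} vs  {p}")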
