Skip to content

Commit

Permalink
fix(cohere): token counting on tool calls
Browse files Browse the repository at this point in the history
  • Loading branch information
zhudotexe committed Apr 10, 2024
1 parent 32826a5 commit 8c1693f
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions kani/prompts/impl/cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,20 +87,19 @@ def function_result_joiner(msgs):
return "\n\n".join(contents)


def tool_call_formatter(msg: ChatMessage):
def tool_call_formatter(msg: ChatMessage) -> str:
if msg.tool_calls:
text = msg.text + "\n" if msg.text else ""
tool_calls = json.dumps(
[{"tool_name": tc.function.name, "parameters": tc.function.kwargs} for tc in msg.tool_calls],
indent=4,
)
msg.content = f"{text}Action: ```json\n{tool_calls}\n```"
return f"{text}Action: ```json\n{tool_calls}\n```"
else:
msg.content = ( # is the EOT/SOT token doing weird stuff here?
return ( # is the EOT/SOT token doing weird stuff here?
'Action: ```json\n[\n {\n "tool_name": "directly_answer",\n "parameters": {}\n'
f" }}\n]\n```<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>{msg.text}"
)
return msg


def build_tool_pipeline(
Expand All @@ -120,7 +119,12 @@ def build_tool_pipeline(

# format function calls with an Action: prefix; otherwise do a directly_answer call
if include_function_calls:
steps.append(Apply(tool_call_formatter, role=ChatRole.ASSISTANT))

def apply_tc_format(msg):
msg.content = tool_call_formatter(msg)
return msg

steps.append(Apply(apply_tc_format, role=ChatRole.ASSISTANT))

# keep function results around as SYSTEM messages
if include_function_results:
Expand Down

0 comments on commit 8c1693f

Please sign in to comment.