Skip to content

Commit

Permalink
chore: chat_in_terminal_async
Browse files Browse the repository at this point in the history
  • Loading branch information
zhudotexe committed Jul 20, 2023
1 parent ebe4b21 commit 5dc2878
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 7 deletions.
2 changes: 2 additions & 0 deletions docs/api_reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,5 @@ See :doc:`engine_reference`.
Utilities
---------
.. autofunction:: kani.chat_in_terminal

.. autofunction:: kani.chat_in_terminal_async
2 changes: 1 addition & 1 deletion kani/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
from .ai_function import AIParam, ai_function
from .kani import Kani
from .models import ChatMessage, ChatRole
from .utils.cli import chat_in_terminal
from .utils.cli import chat_in_terminal, chat_in_terminal_async
23 changes: 18 additions & 5 deletions kani/utils/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,20 @@ def _function_formatter(message: ChatMessage):
return f"Thinking ({message.function_call.name})..."


async def _chat_in_terminal(kani: Kani, rounds: int = 0):
async def chat_in_terminal_async(kani: Kani, rounds: int = 0, stopword: str = None):
"""Async version of :func:`.chat_in_terminal`.
Use in environments when there is already an asyncio loop running (e.g. Google Colab).
"""
if os.getenv("KANI_DEBUG") is not None:
logging.basicConfig(level=logging.DEBUG)

try:
round_num = 0
while round_num < rounds or not rounds:
round_num += 1
query = input("USER: ")
if stopword and query == stopword:
break
async for msg in kani.full_round_str(query, function_call_formatter=_function_formatter):
print(f"AI: {msg}")
except KeyboardInterrupt:
Expand All @@ -26,7 +34,7 @@ async def _chat_in_terminal(kani: Kani, rounds: int = 0):
await kani.engine.close()


def chat_in_terminal(kani: Kani, rounds: int = 0):
def chat_in_terminal(kani: Kani, rounds: int = 0, stopword: str = None):
"""Chat with a kani right in your terminal.
Useful for playing with kani, quick prompt engineering, or demoing the library.
Expand All @@ -38,7 +46,12 @@ def chat_in_terminal(kani: Kani, rounds: int = 0):
This function is only a development utility and should not be used in production.
:param rounds: The number of chat rounds to play (defaults to 0 for infinite).
:param stopword: Break out of the chat loop if the user sends this message.
"""
if os.getenv("KANI_DEBUG") is not None:
logging.basicConfig(level=logging.DEBUG)
asyncio.run(_chat_in_terminal(kani, rounds))
try:
asyncio.run(chat_in_terminal_async(kani, rounds=rounds, stopword=stopword))
except RuntimeError:
print(
f"WARNING: It looks like you're in an environment with a running asyncio loop (e.g. Google Colab).\nYou"
f" should use `await chat_in_terminal_async(...)` instead."
)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "kani"
version = "0.0.1"
version = "0.0.2"
authors = [
{ name = "Andrew Zhu", email = "[email protected]" },
]
Expand Down

0 comments on commit 5dc2878

Please sign in to comment.