diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index 6e8140f..d22be7f 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -54,7 +54,7 @@ jobs: uses: actions/cache@v3 with: path: ~/.cache/huggingface - key: ${{ runner.os }}-llamav2 + key: ${{ runner.os }}-llamav2-cpp - name: Run LLaMA tests run: pytest -m llama diff --git a/LICENSE b/LICENSE index 5e466e2..0aeabf9 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ MIT License -Copyright (c) 2023 Andrew Zhu +Copyright (c) 2023-present Andrew Zhu Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/README.md b/README.md index 4072882..a520e5a 100644 --- a/README.md +++ b/README.md @@ -18,6 +18,13 @@ Discord +
+ + Model zoo + + + Retrieval example +

# kani (カニ) @@ -32,10 +39,26 @@ developers alike. kani comes with support for the following models out of the box, with a model-agnostic framework to add support for many more: +**Hosted Models** + - OpenAI Models (GPT-3.5-turbo, GPT-4, GPT-4-turbo) - Anthropic Models (Claude, Claude Instant) -- LLaMA v2 (via Hugging Face or ctransformers) & fine-tunes -- Vicuna v1.3 (via Hugging Face) & fine-tunes + +**Open Source Models** + +kani supports every chat model available on Hugging Face through `transformers` or `llama.cpp`! + +In particular, we have reference implementations for the following base models, and their fine-tunes: + +- [Command R](https://huggingface.co/CohereForAI/c4ai-command-r-v01) + and [Command R+](https://huggingface.co/CohereForAI/c4ai-command-r-plus) +- [Mistral-7B](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2) + and [Mixtral-8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1) (uses LLaMA prompting format) +- [Gemma](https://huggingface.co/collections/google/gemma-release-65d5efbccdbb8c4202ec078b) (all sizes) +- [LLaMA 2](https://huggingface.co/meta-llama) (all sizes) +- [Vicuna v1.3](https://huggingface.co/lmsys/vicuna-7b-v1.3) + +Check out the [Model Zoo](examples/4_engines_zoo.py) to see how to use each of these models in your application! **Interested in contributing? Check out our [guide](https://kani.readthedocs.io/en/latest/community/contributing.html).** @@ -49,7 +72,7 @@ more: - **Lightweight and high-level** - kani implements common boilerplate to interface with language models without forcing you to use opinionated prompt frameworks or complex library-specific tooling. - **Model agnostic** - kani provides a simple interface to implement: token counting and completion generation. - Implement these two, and kani can run with any language model. + kani lets developers switch which language model runs on the backend without major code refactors. - **Automatic chat memory management** - Allow chat sessions to flow without worrying about managing the number of tokens in the history - kani takes care of it. - **Function calling with model feedback and retry** - Give models access to functions in just one line of code. @@ -79,6 +102,7 @@ Then, let's use kani to create a simple chatbot using ChatGPT as a backend. ```python # import the library +import asyncio from kani import Kani, chat_in_terminal from kani.engines.openai import OpenAIEngine @@ -94,9 +118,17 @@ engine = OpenAIEngine(api_key, model="gpt-3.5-turbo") # system_prompt="You are..." here. ai = Kani(engine) -# kani comes with a utility to interact with a kani through your terminal! Check out -# the docs for how to use kani programmatically. +# kani comes with a utility to interact with a kani through your terminal... chat_in_terminal(ai) + + +# or you can use kani programmatically in an async function! +async def main(): + resp = await ai.chat_round("What is the airspeed velocity of an unladen swallow?") + print(resp.text) + + +asyncio.run(main()) ``` kani makes the time to set up a working chat model short, while offering the programmer deep customizability over @@ -112,8 +144,9 @@ decorator. ```python # import the library +import asyncio from typing import Annotated -from kani import AIParam, Kani, ai_function, chat_in_terminal +from kani import AIParam, Kani, ai_function, chat_in_terminal, ChatRole from kani.engines.openai import OpenAIEngine # set up the engine as above @@ -136,13 +169,47 @@ class MyKani(Kani): ai = MyKani(engine) + +# the terminal utility allows you to test function calls... chat_in_terminal(ai) + + +# and you can track multiple rounds programmatically. +async def main(): + async for msg in ai.full_round("What's the weather in Tokyo?"): + print(msg.role, msg.text) + + +asyncio.run(main()) ``` kani guarantees that function calls are valid by the time they reach your methods while allowing you to focus on writing code. For more information, check out [the function calling docs](https://kani.readthedocs.io/en/latest/function_calling.html). +## Streaming + +kani supports streaming responses from the underlying language model token-by-token, even in the presence of function +calls. Streaming is designed to be a drop-in superset of the ``chat_round`` and ``full_round`` methods, allowing you to +gradually refactor your code without ever leaving it in a broken state. + +```python +async def stream_chat(): + stream = ai.chat_round_stream("What does kani mean?") + async for token in stream: + print(token, end="") + print() + msg = await stream.message() # or `await stream` + + +async def stream_with_function_calling(): + async for stream in ai.full_round_stream("What's the weather in Tokyo?"): + async for token in stream: + print(token, end="") + print() + msg = await stream.message() +``` + ## Why kani? Existing frameworks for language models like LangChain and simpleaichat are opinionated and/or heavyweight - they edit @@ -159,7 +226,6 @@ kani is to LangChain as Flask (or FastAPI) is to Django. kani is appropriate for everyone from academic researchers to industry professionals to hobbyists to use without worrying about under-the-hood hacks. - ## Docs To learn more about how @@ -178,21 +244,6 @@ https://github.com/zhudotexe/kani/actions/workflows/pytest.yml?query=branch%3Ama Simply click on the latest build to see LLaMA's output! -## Kani in the News - -Kani will appear at the NLP Open Source Software workshop at EMNLP 2023! - -We are really excited and grateful to see people talking about Kani online. We are also trending on Papers With Code, -GitHub, and OSS Insight. Check out some recent articles and videos below! - -- [Researchers from the University of Pennsylvania Introduce Kani: A Lightweight, Flexible, and Model-Agnostic Open-Source AI Framework for Building Language Model Applications](https://www.marktechpost.com/2023/09/18/researchers-from-the-university-of-pennsylvania-introduce-kani-a-lightweight-flexible-and-model-agnostic-open-source-ai-framework-for-building-language-model-applications/) -- [Unlocking AI Potential: Unveiling Kani, the Groundbreaking Open-Source Framework Revolutionizing Large Language Model Applications](https://www.cjco.com.au/article/news/unlocking-ai-potential-unveiling-kani-the-groundbreaking-open-source-framework-revolutionizing-large-language-model-applications/) -- [Kani: A Lightweight and Customizable Framework for Language Model Applications](https://ts2.space/en/kani-a-lightweight-and-customizable-framework-for-language-model-applications/) -- [Introducing Kani (Sanskrit Word): A Game-Changing Open-Source AI Framework for Language Models](https://www.linkedin.com/pulse/introducing-kani-sanskrit-word-game-changing/) - - *Kani was originally named after the Japanese word for crab and coincidentally means "knowledge" in Sanskrit.* -- [kani: lightweight LLM framework (Japanese)](https://note.com/hamachi_jp/n/n342becc4f345) -- [Top Trending LLM Projects of the Week: Dive into the Future of Tech! 🚀](https://www.youtube.com/watch?v=qoGKzmnhAnA) - ## Who we are University of Pennsylvania Logo @@ -220,13 +271,25 @@ University of Pennsylvania. We're all members of If you use Kani, please cite us as: ``` -@misc{zhu2023kani, - title={Kani: A Lightweight and Highly Hackable Framework for Building Language Model Applications}, - author={Andrew Zhu and Liam Dugan and Alyssa Hwang and Chris Callison-Burch}, - year={2023}, - eprint={2309.05542}, - archivePrefix={arXiv}, - primaryClass={cs.SE} +@inproceedings{zhu-etal-2023-kani, + title = "Kani: A Lightweight and Highly Hackable Framework for Building Language Model Applications", + author = "Zhu, Andrew and + Dugan, Liam and + Hwang, Alyssa and + Callison-Burch, Chris", + editor = "Tan, Liling and + Milajevs, Dmitrijs and + Chauhan, Geeticka and + Gwinnup, Jeremy and + Rippeth, Elijah", + booktitle = "Proceedings of the 3rd Workshop for Natural Language Processing Open Source Software (NLP-OSS 2023)", + month = dec, + year = "2023", + address = "Singapore", + publisher = "Association for Computational Linguistics", + url = "https://aclanthology.org/2023.nlposs-1.8", + doi = "10.18653/v1/2023.nlposs-1.8", + pages = "65--77", } ``` diff --git a/citation.cff b/citation.cff index 3a96e49..43b2437 100644 --- a/citation.cff +++ b/citation.cff @@ -15,7 +15,7 @@ authors: orcid: "https://orcid.org/0000-0001-8196-1943" title: "kani" date-released: 2023-09-11 -url: "https://github.com/zhudotexe/kani/tree/main" +url: "https://github.com/zhudotexe/kani" preferred-citation: type: generic authors: @@ -32,10 +32,13 @@ preferred-citation: given-names: "Chris" orcid: "https://orcid.org/0000-0001-8196-1943" year: 2023 - month: 9 - day: 11 + month: 12 title: "Kani: A Lightweight and Highly Hackable Framework for Building Language Model Applications" - eprint: 2309.05542 - archive-prefix: "arXiv" - primary-class: "cs.SE" - url: "https://arxiv.org/abs/2309.05542" + collection-title: "Proceedings of the 3rd Workshop for Natural Language Processing Open Source Software (NLP-OSS 2023)" + collection-type: proceedings + location: "Singapore" + doi: "10.18653/v1/2023.nlposs-1.8" + url: "https://aclanthology.org/2023.nlposs-1.8" + publisher: "Association for Computational Linguistics" + start: 65 + end: 77 diff --git a/docs/api_reference.rst b/docs/api_reference.rst index 1b84708..d198dea 100644 --- a/docs/api_reference.rst +++ b/docs/api_reference.rst @@ -48,6 +48,27 @@ AI Function .. autoclass:: kani.AIParam :members: +Streaming +--------- +.. autoclass:: kani.streaming.StreamManager + :members: + + .. automethod:: __await__ + + .. automethod:: __aiter__ + +Prompting +--------- +.. automodule:: kani.prompts + +.. autoclass:: kani.PromptPipeline + :members: + + .. automethod:: __call__ + +.. autoclass:: kani.prompts.PipelineStep + :members: + Internals --------- .. autoclass:: kani.FunctionCallResult @@ -66,6 +87,8 @@ Utilities .. autofunction:: kani.chat_in_terminal_async +.. autofunction:: kani.print_stream + Message Formatters ^^^^^^^^^^^^^^^^^^ .. automodule:: kani.utils.message_formatters diff --git a/docs/community/contributing.rst b/docs/community/contributing.rst index 4375934..47dd140 100644 --- a/docs/community/contributing.rst +++ b/docs/community/contributing.rst @@ -11,8 +11,6 @@ You can contribute in many ways: helping us develop the library, (Funding agencies: the core devs are PhD students - you can support us that way too! |:wink:| Send us an email.) -.. todo: cite us - Contributing Code ----------------- When contributing code and opening PRs, there are a couple things to keep in mind. kani is built on the principles diff --git a/docs/conf.py b/docs/conf.py index 98b29e9..210bcbf 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -14,7 +14,7 @@ # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information project = "kani" -copyright = "2023, Andrew Zhu, Liam Dugan, and Alyssa Hwang" +copyright = "2023-present, Andrew Zhu, Liam Dugan, and Alyssa Hwang" author = "Andrew Zhu, Liam Dugan, and Alyssa Hwang" # -- General configuration --------------------------------------------------- @@ -51,8 +51,10 @@ nitpick_ignore_regex = [ (r"py:class", r"aiohttp\..*"), # aiohttp intersphinx is borked (r"py:class", r"torch\..*"), # idk if torch has intersphinx - (r"py:class", r"kani\.engines\.openai\.models\..*"), # internal models are just pydantic models of the oai api + (r"py:class", r"openai\..*"), # openai does not use sphinx for docs (r"py:class", r"anthropic\..*"), # anthropic does not use sphinx for docs + (r"py:class", r"asyncio\.\w+\..*"), # asyncio submodule intersphinx is borked + (r"py:class", r"kani\..*\.[\w_0-9]*T"), # ignore generics and other typevars ] # sphinx.ext.autodoc diff --git a/docs/engine_reference.rst b/docs/engine_reference.rst index 170d988..aed303a 100644 --- a/docs/engine_reference.rst +++ b/docs/engine_reference.rst @@ -21,9 +21,6 @@ OpenAI ------ .. autoclass:: kani.engines.openai.OpenAIEngine -.. autoclass:: kani.engines.openai.client.OpenAIClient - :members: - Anthropic --------- .. autoclass:: kani.engines.anthropic.AnthropicEngine @@ -37,13 +34,13 @@ Hugging Face .. autoclass:: kani.engines.huggingface.llama2.LlamaEngine :members: -.. autoclass:: kani.engines.huggingface.vicuna.VicunaEngine +.. autoclass:: kani.engines.huggingface.cohere.CommandREngine :members: -CTransformers -------------- -.. autoclass:: kani.engines.ctransformers.CTransformersEngine +.. autoclass:: kani.engines.huggingface.vicuna.VicunaEngine :members: -.. autoclass:: kani.engines.ctransformers.llama2.LlamaCTransformersEngine +llama.cpp +--------- +.. autoclass:: kani.engines.llamacpp.LlamaCppEngine :members: diff --git a/docs/engines.rst b/docs/engines.rst index 0e821ba..ac5882c 100644 --- a/docs/engines.rst +++ b/docs/engines.rst @@ -18,4 +18,4 @@ In this section, we'll discuss how to implement your own engine to use any langu engines/implementing engines/http engines/huggingface - engines/ctransformers + engines/llamacpp diff --git a/docs/engines/ctransformers.rst b/docs/engines/ctransformers.rst deleted file mode 100644 index d365745..0000000 --- a/docs/engines/ctransformers.rst +++ /dev/null @@ -1,20 +0,0 @@ -CTransformers -============= -If your language model backend is available with GGML, kani includes a base engine that implements -a prediction pipeline. - -Instead of having to implement the prediction logic, all you have to do is subclass :class:`.CTransformersEngine` and -implement :meth:`~.CTransformersEngine.build_prompt` and :meth:`~.BaseEngine.message_len`. - -.. seealso:: - - The source code of the :class:`.LlamaCTransformersEngine`, which uses the CTransformersEngine. - -.. autoclass:: kani.engines.ctransformers.base.CTransformersEngine - :noindex: - - .. automethod:: kani.engines.ctransformers.base.CTransformersEngine.build_prompt - :noindex: - - .. automethod:: kani.engines.ctransformers.base.CTransformersEngine.message_len - :noindex: diff --git a/docs/engines/http.rst b/docs/engines/http.rst index fd215be..2c0daf0 100644 --- a/docs/engines/http.rst +++ b/docs/engines/http.rst @@ -1,14 +1,18 @@ HTTP Client =========== + +.. danger:: + The aiohttp-based BaseClient has been deprecated in v1.0.0 and will be removed in a future version. + We recommend using `httpx `_ to make HTTP requests instead. + + We have removed the top-level library dependency on ``aiohttp`` - existing code using the BaseClient will require + manual installation of the ``aiohttp`` package. + If your language model backend exposes an HTTP API, you can create a subclass of :class:`.BaseClient` to interface with it. Your engine should then create an instance of the new HTTP client and call it to make predictions. Minimally, to use the HTTP client, your subclass should set the ``SERVICE_BASE`` class variable. -.. seealso:: - - The source code of the :class:`.OpenAIClient`, which uses the HTTP client. - .. autoclass:: kani.engines.httpclient.BaseClient :noindex: :members: diff --git a/docs/engines/huggingface.rst b/docs/engines/huggingface.rst index d937264..4ba89f9 100644 --- a/docs/engines/huggingface.rst +++ b/docs/engines/huggingface.rst @@ -3,8 +3,12 @@ HuggingFace If your language model backend is available on HuggingFace or is compatible with ``transformers``' ``AutoModelForCausalLM`` interface, kani includes a base engine that implements a prediction pipeline. -Instead of having to implement the prediction logic, all you have to do is subclass :class:`.HuggingEngine` and -implement :meth:`~.HuggingEngine.build_prompt` and :meth:`~.BaseEngine.message_len`. +.. versionadded:: 1.0.0 + For most models that use a chat format, you won't even need to create a new engine class - instead, you can pass + a :class:`.PromptPipeline` to the :class:`.HuggingEngine`. + +If you do create a new engine, instead of having to implement the prediction logic, all you have to do is subclass +:class:`.HuggingEngine` and implement :meth:`~.HuggingEngine.build_prompt` and :meth:`~.BaseEngine.message_len`. .. seealso:: @@ -27,8 +31,12 @@ If you're running your model locally, you might run into issues because large la Unless you pay for a massive compute cluster (|:money_with_wings:|) or have access to one at your institution, you might not be able to fit models with billions of params on your GPU. That's where model quantization comes into play. - Using FP4 quantization you can expect to reduce up to 8x the model size compared to its native full precision - version. +.. tip:: + + Thanks to the hard work of the LLM community, many models on Hugging Face also have quantized versions available + in the GGUF format. GGUF is the format for ``llama.cpp``, a low-level optimized LLM runtime. Unlike the name + suggests, it supports many more models than LLaMA. If your model has a GGUF version available, consider using the + :class:`.LlamaCppEngine` instead of the ``HuggingEngine`` to load a pre-quantized version. In this section, we'll show how to load HuggingFace models in FP4. @@ -57,31 +65,16 @@ example shows the :class:`.LlamaEngine`, but the same arguments should apply to :class:`.HuggingEngine`. .. code-block:: python - :emphasize-lines: 4-7 + + from transformers import BitsAndBytesConfig + from kani.engines.huggingface.llama2 import LlamaEngine + + quantization_config = BitsAndBytesConfig(load_in_4bit=True) engine = LlamaEngine( use_auth_token=True, - strict=True, model_load_kwargs={ "device_map": "auto", - "load_in_4bit": True, + "quantization_config": quantization_config, }, ) - - -**Memory Usage Comparison** - -This table shows the effect of enabling fp4 quantization on GPU memory usage and inference speed on ``Llama-2-7b-chat``. - -These numbers represent the average of three runs on a consumer RTX 4070ti (12GB memory) with greedy sampling. - -+--------------+----------------------+----------------------------------------+ -| fp4 Enabled? | Memory Usage | Inference Time (per token) | -+==============+======================+========================================+ -| No | 26.6GB | 1215.6 ms | -+--------------+----------------------+----------------------------------------+ -| Yes | 5.0GB (5.32x less) | 23.6 ms (51.5x speedup\ [#shared]_) | -+--------------+----------------------+----------------------------------------+ - -.. [#shared] Since the memory usage without fp4 enabled is larger than the VRAM size of my GPU, some weights were stored - in shared memory. This likely led to much slower inference compared to storing all weights on a GPU. diff --git a/docs/engines/llamacpp.rst b/docs/engines/llamacpp.rst new file mode 100644 index 0000000..3ef6fd8 --- /dev/null +++ b/docs/engines/llamacpp.rst @@ -0,0 +1,22 @@ +llama.cpp +========= +If your language model backend is available with GGUF, kani includes a base engine that implements +a prediction pipeline. + +kani uses `llama-cpp-python `_ for binding to the llama.cpp runtime. + +.. versionadded:: 1.0.0 + For most models that use a chat format, you won't even need to create a new engine class - instead, you can pass + a :class:`.PromptPipeline` to the :class:`.LlamaCppEngine`. + +If you do create a new engine, instead of having to implement the prediction logic, all you have to do is subclass +:class:`.LlamaCppEngine` and implement :meth:`~.LlamaCppEngine.build_prompt` and :meth:`~.LlamaCppEngine.message_len`. + +.. autoclass:: kani.engines.llamacpp.LlamaCppEngine + :noindex: + + .. automethod:: kani.engines.llamacpp.LlamaCppEngine.build_prompt + :noindex: + + .. automethod:: kani.engines.llamacpp.LlamaCppEngine.message_len + :noindex: diff --git a/docs/index.rst b/docs/index.rst index 025ff09..ac24039 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -17,7 +17,7 @@ Features - **Lightweight and high-level** - kani implements common boilerplate to interface with language models without forcing you to use opinionated prompt frameworks or complex library-specific tooling. - **Model agnostic** - kani provides a simple interface to implement: token counting and completion generation. - Implement these two, and kani can run with any language model. + kani lets developers switch which language model runs on the backend without major code refactors. - **Automatic chat memory management** - Allow chat sessions to flow without worrying about managing the number of tokens in the history - kani takes care of it. - **Function calling with model feedback and retry** - Give models access to functions in just one line of code. @@ -43,6 +43,7 @@ Then, let's use kani to create a simple chatbot using ChatGPT as a backend. .. code-block:: python # import the library + import asyncio from kani import Kani, chat_in_terminal from kani.engines.openai import OpenAIEngine @@ -54,13 +55,20 @@ Then, let's use kani to create a simple chatbot using ChatGPT as a backend. engine = OpenAIEngine(api_key, model="gpt-3.5-turbo") # The kani manages the chat state, prompting, and function calling. Here, we only give - # it the engine to call ChatGPT, but you can specify other parameters like system_prompt="You are..." here. + # it the engine to call ChatGPT, but you can specify other parameters like + # system_prompt="You are..." here. ai = Kani(engine) - # kani comes with a utility to interact with a kani through your terminal! Check out - # the docs for how to use kani programmatically. + # kani comes with a utility to interact with a kani through your terminal... chat_in_terminal(ai) + # or you can use kani programmatically in an async function! + async def main(): + resp = await ai.chat_round("What is the airspeed velocity of an unladen swallow?") + print(resp.text) + + asyncio.run(main()) + kani makes the time to set up a working chat model short, while offering the programmer deep customizability over every prompt, function call, and even the underlying language model. diff --git a/docs/kani.rst b/docs/kani.rst index 81b2e38..fa5073a 100644 --- a/docs/kani.rst +++ b/docs/kani.rst @@ -87,10 +87,6 @@ here's how you might implement a simple chat: # use `asyncio.run` to call your async function to start the program asyncio.run(chat_with_kani()) -.. seealso:: - - The source code of :func:`.chat_in_terminal`. - Engines ^^^^^^^ Engines are responsible for interfacing with a language model. @@ -121,16 +117,8 @@ conversations), and an assistant message can contain a ``function_call`` (discus At a high level, a :class:`.Kani` is responsible for managing a list of :class:`.ChatMessage`: the chat session associated with it. You can access the chat messages through the :attr:`.Kani.chat_history` attribute. -You may even modify the chat history (e.g. append or delete ChatMessages) to change the prompt at any time. - -.. tip:: - To edit the content of a message in the chat history, you must replace the object. ChatMessages are - immutable by default. - - For example, to edit the last message, you could set ``ai.chat_history[-1] = ChatMessage.assistant("...")``. - - You can use the :meth:`.ChatMessage.copy_with` convenience method to make a copy with only certain attributes - updated. +You may even modify the chat history (e.g. append or delete ChatMessages or edit a message's content) to change the +prompt at any time. .. warning:: In some advanced use cases, :attr:`.ChatMessage.content` may be a tuple of :class:`.MessagePart` or ``str`` rather @@ -165,6 +153,66 @@ You may even modify the chat history (e.g. append or delete ChatMessages) to cha ChatMessage(role=ChatRole.ASSISTANT, content="Hello! How can I assist you today?"), ] +Streaming +--------- +kani supports streaming to print tokens from the engine as they are received. Streaming is designed to be a drop-in +superset of the ``chat_round`` and ``full_round`` methods, allowing you to gradually refactor your code without ever +leaving it in a broken state. + +To request a stream from the engine, use :meth:`.Kani.chat_round_stream` or :meth:`.Kani.full_round_stream`. These +methods will return a :class:`.StreamManager`, which you can use in different ways to consume the stream. + +The simplest way to consume the stream is to iterate over it with ``async for``, which will yield a stream of +:class:`str`. + +.. code-block:: python + + # CHAT ROUND: + stream = ai.chat_round_stream("What is the airspeed velocity of an unladen swallow?") + async for token in stream: + print(token, end="") + msg = await stream.message() + + # FULL ROUND: + async for stream in ai.full_round_stream("What is the airspeed velocity of an unladen swallow?"): + async for token in stream: + print(token, end="") + msg = await stream.message() + +kani also provides a helper to print streams (:func:`kani.print_stream`): + +.. code-block:: python + + stream = ai.chat_round_stream("What is the most interesting train line in Tokyo?") + await kani.print_stream(stream) + +After a stream finishes, its contents will be available as a :class:`.ChatMessage`. You can retrieve the final +message or :class:`.BaseCompletion` with: + +.. code-block:: python + + msg = await stream.message() + completion = await stream.completion() + +The final :class:`.ChatMessage` may contain non-yielded tokens (e.g. a request for a function call). If the final +message or completion is requested before the stream is iterated over, the stream manager will consume the entire +stream. + +.. tip:: + For compatibility and ease of refactoring, awaiting the stream itself will also return the message, i.e.: + + .. code-block:: python + + msg = await ai.chat_round_stream("What is the airspeed velocity of an unladen swallow?") + + (note the ``await`` that is not present in the above examples). This allows you to refactor your code by changing + ``chat_round`` to ``chat_round_stream`` without other changes. + + .. code-block:: diff + + - msg = await ai.chat_round("What is the airspeed velocity of an unladen swallow?") + + msg = await ai.chat_round_stream("What is the airspeed velocity of an unladen swallow?") + Few-Shot Prompting ------------------ Few-shot prompting (AKA in-context learning) is the idea that language models can "learn" the task the user wants diff --git a/docs/shared/engine_table.rst b/docs/shared/engine_table.rst index 6bd5a2e..1d5ee93 100644 --- a/docs/shared/engine_table.rst +++ b/docs/shared/engine_table.rst @@ -1,20 +1,26 @@ -+----------------------------------------+------------------------------------+-----------------------+----------------------------------------------------------------------+ -| Model Name | Extra | Capabilities | Engine | -+========================================+====================================+=======================+======================================================================+ -| GPT-3.5-turbo, GPT-4 | ``openai`` | |function| |api| | :class:`kani.engines.openai.OpenAIEngine` | -+----------------------------------------+------------------------------------+-----------------------+----------------------------------------------------------------------+ -| Claude, Claude Instant | ``anthropic`` | |api| | :class:`kani.engines.anthropic.AnthropicEngine` | -+----------------------------------------+------------------------------------+-----------------------+----------------------------------------------------------------------+ -| |:hugging:| transformers\ [#abstract]_ | ``huggingface``\ [#torch]_ | (abstract) | :class:`kani.engines.huggingface.HuggingEngine` | -+----------------------------------------+------------------------------------+-----------------------+----------------------------------------------------------------------+ -| |:hugging:| |:llama:| LLaMA v2 | ``huggingface, llama``\ [#torch]_ | |oss| |cpu| |gpu| | :class:`kani.engines.huggingface.llama2.LlamaEngine` | -+----------------------------------------+------------------------------------+-----------------------+----------------------------------------------------------------------+ -| |:hugging:| |:llama:| Vicuna v1.3 | ``huggingface, llama``\ [#torch]_ | |oss| |cpu| |gpu| | :class:`kani.engines.huggingface.vicuna.VicunaEngine` | -+----------------------------------------+------------------------------------+-----------------------+----------------------------------------------------------------------+ -| ctransformers\ [#abstract]_ | ``ctransformers`` | (abstract) | :class:`kani.engines.ctransformers.CTransformersEngine` | -+----------------------------------------+------------------------------------+-----------------------+----------------------------------------------------------------------+ -| |:llama:| LLaMA v2 (GGML) | ``ctransformers`` | |oss| |cpu| |gpu| | :class:`kani.engines.ctransformers.llama2.LlamaCTransformersEngine` | -+----------------------------------------+------------------------------------+-----------------------+----------------------------------------------------------------------+ ++----------------------------------------+------------------------------------+------------------------------+----------------------------------------------------------------------+ +| Model Name | Extra | Capabilities | Engine | ++========================================+====================================+==============================+======================================================================+ +| GPT-3.5-turbo, GPT-4 | ``openai`` | |function| |api| | :class:`kani.engines.openai.OpenAIEngine` | ++----------------------------------------+------------------------------------+------------------------------+----------------------------------------------------------------------+ +| Claude, Claude Instant | ``anthropic`` | |function| |api| | :class:`kani.engines.anthropic.AnthropicEngine` | ++----------------------------------------+------------------------------------+------------------------------+----------------------------------------------------------------------+ +| |:hugging:| transformers\ [#runtime]_ | ``huggingface``\ [#torch]_ | (runtime) | :class:`kani.engines.huggingface.HuggingEngine` | ++----------------------------------------+------------------------------------+------------------------------+----------------------------------------------------------------------+ +| |:hugging:| Command R, Command R+ | ``huggingface``\ [#torch]_ | |function| |oss| |cpu| |gpu| | :class:`kani.engines.huggingface.cohere.CommandREngine` | ++----------------------------------------+------------------------------------+------------------------------+----------------------------------------------------------------------+ +| |:hugging:| |:llama:| LLaMA v2 | ``huggingface, llama``\ [#torch]_ | |oss| |cpu| |gpu| | :class:`kani.engines.huggingface.llama2.LlamaEngine` | ++----------------------------------------+------------------------------------+------------------------------+----------------------------------------------------------------------+ +| |:hugging:| |:llama:| Vicuna v1.3 | ``huggingface, llama``\ [#torch]_ | |oss| |cpu| |gpu| | :class:`kani.engines.huggingface.vicuna.VicunaEngine` | ++----------------------------------------+------------------------------------+------------------------------+----------------------------------------------------------------------+ +| llama.cpp\ [#runtime]_ | ``cpp`` | (runtime) | :class:`kani.engines.llamacpp.LlamaCppEngine` | ++----------------------------------------+------------------------------------+------------------------------+----------------------------------------------------------------------+ +| |:llama:| LLaMA v2 (GGUF) | ``cpp`` | |oss| |cpu| |gpu| | :class:`kani.engines.llamacpp.LlamaCppEngine` | ++----------------------------------------+------------------------------------+------------------------------+----------------------------------------------------------------------+ + +Additional models using the classes above are also supported - see the +`model zoo `_ for a more comprehensive list of +models! **Legend** @@ -33,3 +39,4 @@ .. [#torch] You will also need to install `PyTorch `_ manually. .. [#abstract] This is an abstract class of models; kani includes a couple concrete implementations for reference. +.. [#runtime] This is a model runtime that can support multiple models using a :class:`.PromptPipeline`. \ No newline at end of file diff --git a/examples/4_engines_fp4_quant.py b/examples/4_engines_fp4_quant.py index 0da5a25..c3d8e71 100644 --- a/examples/4_engines_fp4_quant.py +++ b/examples/4_engines_fp4_quant.py @@ -8,15 +8,18 @@ import asyncio import time +from transformers import BitsAndBytesConfig + from kani import Kani, chat_in_terminal from kani.engines.huggingface.llama2 import LlamaEngine +quantization_config = BitsAndBytesConfig(load_in_4bit=True) + engine = LlamaEngine( use_auth_token=True, - strict=True, model_load_kwargs={ "device_map": "auto", - "load_in_4bit": True, + "quantization_config": quantization_config, }, ) ai = Kani( @@ -41,5 +44,5 @@ async def time_completion(): if __name__ == "__main__": - chat_in_terminal(ai) - # asyncio.run(time_completion()) + # chat_in_terminal(ai) + asyncio.run(time_completion()) diff --git a/examples/4_engines_llama2.py b/examples/4_engines_llama2.py index f7b3f9f..60e344e 100644 --- a/examples/4_engines_llama2.py +++ b/examples/4_engines_llama2.py @@ -7,7 +7,7 @@ from kani import Kani, chat_in_terminal from kani.engines.huggingface.llama2 import LlamaEngine -engine = LlamaEngine(use_auth_token=True, strict=True) +engine = LlamaEngine(use_auth_token=True) ai = Kani( engine, system_prompt=( diff --git a/examples/4_engines_zoo.py b/examples/4_engines_zoo.py index efbabdd..4407115 100644 --- a/examples/4_engines_zoo.py +++ b/examples/4_engines_zoo.py @@ -10,27 +10,51 @@ # ==== OpenAI (GPT) ==== from kani.engines.openai import OpenAIEngine -engine = OpenAIEngine(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-3.5-turbo") +engine = OpenAIEngine(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4") # ==== Anthropic (Claude) ==== +# see https://docs.anthropic.com/claude/docs/models-overview for a list of model IDs from kani.engines.anthropic import AnthropicEngine -engine = AnthropicEngine(api_key=os.getenv("ANTHROPIC_API_KEY"), model="claude-2.1") +engine = AnthropicEngine(api_key=os.getenv("ANTHROPIC_API_KEY"), model="claude-3-opus-20240229") -# ==== LLaMA v2 (Hugging Face) ==== +# ========== Hugging Face ========== +# ---- LLaMA v2 (Hugging Face) ---- from kani.engines.huggingface.llama2 import LlamaEngine engine = LlamaEngine(model_id="meta-llama/Llama-2-7b-chat-hf", use_auth_token=True) # log in with huggingface-cli -# ==== LLaMA v2 (ctransformers) ==== -from kani.engines.ctransformers.llama2 import LlamaCTransformersEngine -engine = LlamaCTransformersEngine( - model_id="TheBloke/Llama-2-7B-Chat-GGML", - model_file="llama-2-7b-chat.ggmlv3.q5_K_M.bin", +# ---- Command R (Hugging Face) ---- +import torch +from kani.engines.huggingface.cohere import CommandREngine +engine = CommandREngine( + model_id="CohereForAI/c4ai-command-r-v01", model_load_kwargs={"device_map": "auto", "torch_dtype": torch.float16} ) -# ==== Vicuna v1.3 (Hugging Face) ==== +# ---- Mistral-7B (Hugging Face) ---- +from kani.engines.huggingface import HuggingEngine +from kani.prompts.impl import MISTRAL_PIPELINE +engine = HuggingEngine(model_id="mistralai/Mistral-7B-Instruct-v0.2", prompt_pipeline=MISTRAL_PIPELINE) + +# ---- Gemma (Hugging Face) ---- +from kani.engines.huggingface import HuggingEngine +from kani.prompts.impl import GEMMA_PIPELINE +engine = HuggingEngine(model_id="google/gemma-1.1-7b-it", prompt_pipeline=GEMMA_PIPELINE) + +# ---- Vicuna v1.3 (Hugging Face) ---- from kani.engines.huggingface.vicuna import VicunaEngine engine = VicunaEngine(model_id="lmsys/vicuna-7b-v1.3") +# ========== llama.cpp ========== +# ---- LLaMA v2 (llama.cpp) ---- +from kani.engines.llamacpp import LlamaCppEngine +engine = LlamaCppEngine(repo_id="TheBloke/Llama-2-7B-Chat-GGUF", filename="*.Q4_K_M.gguf") + +# ---- Mistral-7B (llama.cpp) ---- +from kani.engines.llamacpp import LlamaCppEngine +from kani.prompts.impl import MISTRAL_PIPELINE +engine = LlamaCppEngine( + repo_id="TheBloke/Mistral-7B-Instruct-v0.2-GGUF", filename="*.Q4_K_M.gguf", prompt_pipeline=MISTRAL_PIPELINE +) + # take your pick - the kani interface is compatible with all! ai = Kani(engine) diff --git a/examples/5_advanced_retrieval.py b/examples/5_advanced_retrieval.py index 0690208..c77b896 100644 --- a/examples/5_advanced_retrieval.py +++ b/examples/5_advanced_retrieval.py @@ -8,8 +8,9 @@ import os from typing import Annotated +import httpx + from kani import AIParam, Kani, ai_function, chat_in_terminal -from kani.engines.httpclient import BaseClient from kani.engines.openai import OpenAIEngine api_key = os.getenv("OPENAI_API_KEY") @@ -17,15 +18,10 @@ engine = OpenAIEngine(api_key, model="gpt-3.5-turbo") -# let's define a client using kani's BaseClient: -class WikipediaClient(BaseClient): - SERVICE_BASE = "https://en.wikipedia.org/w/api.php" - - class WikipediaRetrievalKani(Kani): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.wikipedia_client = WikipediaClient() + self.wikipedia_client = httpx.AsyncClient(base_url="https://en.wikipedia.org/w/api.php", follow_redirects=True) @ai_function() async def wikipedia( @@ -45,7 +41,8 @@ async def wikipedia( "formatversion": 2, }, ) - page = resp["query"]["pages"][0] + data = resp.json() + page = data["query"]["pages"][0] if extract := page.get("extract"): return extract return f"The page {title!r} does not exist on Wikipedia." @@ -55,7 +52,7 @@ async def search(self, query: str): """Find titles of Wikipedia articles similar to the given query.""" # https://en.wikipedia.org/w/api.php?action=opensearch&format=json&search=Train resp = await self.wikipedia_client.get("/", params={"action": "opensearch", "format": "json", "search": query}) - return json.dumps(resp[1]) + return json.dumps(resp.json()[1]) ai = WikipediaRetrievalKani(engine) diff --git a/examples/5_advanced_subkanis.py b/examples/5_advanced_subkanis.py index c4dd031..2b0de60 100644 --- a/examples/5_advanced_subkanis.py +++ b/examples/5_advanced_subkanis.py @@ -10,7 +10,7 @@ import os from kani import Kani, ai_function, chat_in_terminal -from kani.engines.openai import OpenAIClient, OpenAIEngine +from kani.engines.openai import OpenAIEngine api_key = os.getenv("OPENAI_API_KEY") @@ -19,18 +19,13 @@ class KaniWithAISummarization(Kani): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - # save a single client here so we can reuse it across multiple calls - self.openai_client = OpenAIClient(api_key) - @ai_function() async def summarize_conversation(self): """Get the summary of the conversation so far.""" # in this AI Function, we can spawn a sub-kani with a model that can handle # longer contexts, since the conversation may be longer than the fast model's # context window - long_context_engine = OpenAIEngine(client=self.openai_client, model=long_context_model) + long_context_engine = OpenAIEngine(api_key, model=long_context_model) # first, copy the parent's chat history to the child, except the last user message # and the function call ([:-2]) sub_kani = Kani(long_context_engine, chat_history=self.chat_history[:-2]) diff --git a/examples/colab_examples.ipynb b/examples/colab_examples.ipynb index ac170c4..d29b9f1 100644 --- a/examples/colab_examples.ipynb +++ b/examples/colab_examples.ipynb @@ -505,23 +505,23 @@ "import json\n", "from typing import Annotated\n", "\n", + "import httpx\n", + "\n", "from kani import Kani, chat_in_terminal, ai_function, AIParam\n", - "from kani.engines.httpclient import BaseClient\n", "from kani.engines.openai import OpenAIEngine\n", "\n", "\n", - "class WikipediaClient(BaseClient):\n", - " SERVICE_BASE = \"https://en.wikipedia.org/w/api.php\"\n", - "\n", - "\n", "class WikipediaKani(Kani):\n", " def __init__(self, *args, **kwargs):\n", " super().__init__(*args, **kwargs)\n", - " self.wikipedia_client = WikipediaClient()\n", + " self.wikipedia_client = httpx.AsyncClient(base_url=\"https://en.wikipedia.org/w/api.php\", follow_redirects=True)\n", "\n", " @ai_function()\n", - " async def wikipedia(self, title: Annotated[str, AIParam(desc=\"The article title\")]):\n", - " \"\"\"Get information from Wikipedia.\"\"\"\n", + " async def wikipedia(\n", + " self,\n", + " title: Annotated[str, AIParam(desc='The article title on Wikipedia, e.g. \"Train_station\".')],\n", + " ):\n", + " \"\"\"Get additional information about a topic from Wikipedia.\"\"\"\n", " resp = await self.wikipedia_client.get(\n", " \"/\",\n", " params={\n", @@ -533,17 +533,17 @@ " \"formatversion\": 2,\n", " },\n", " )\n", - " page = resp[\"query\"][\"pages\"][0]\n", + " data = resp.json()\n", + " page = data[\"query\"][\"pages\"][0]\n", " if extract := page.get(\"extract\"):\n", " return extract\n", " return f\"The page {title!r} does not exist on Wikipedia.\"\n", "\n", " @ai_function()\n", " async def search(self, query: str):\n", - " \"\"\"Find article titles given a query.\"\"\"\n", - " param_dict = {\"action\": \"opensearch\", \"format\": \"json\", \"search\": query}\n", - " resp = await self.wikipedia_client.get(\"/\", params=param_dict)\n", - " return json.dumps(resp[1])\n", + " \"\"\"Find titles of Wikipedia articles similar to the given query.\"\"\"\n", + " resp = await self.wikipedia_client.get(\"/\", params={\"action\": \"opensearch\", \"format\": \"json\", \"search\": query})\n", + " return json.dumps(resp.json()[1])\n", "\n", "\n", "engine = OpenAIEngine(api_key, model=\"gpt-4\")\n", diff --git a/examples/mistral_gguf.py b/examples/mistral_gguf.py new file mode 100644 index 0000000..6daa587 --- /dev/null +++ b/examples/mistral_gguf.py @@ -0,0 +1,13 @@ +""" +A standalone Mistral-7B-Instruct file using llama.cpp. +""" + +from kani import Kani, chat_in_terminal +from kani.engines.llamacpp import LlamaCppEngine + +engine = LlamaCppEngine(repo_id="TheBloke/Mistral-7B-Instruct-v0.2-GGUF", filename="*.Q4_K_M.gguf") + +ai = Kani(engine) + +if __name__ == "__main__": + chat_in_terminal(ai) diff --git a/kani/__init__.py b/kani/__init__.py index ebfbcb9..30b59bd 100644 --- a/kani/__init__.py +++ b/kani/__init__.py @@ -1,9 +1,11 @@ from . import engines, exceptions, utils +from ._version import __version__ from .ai_function import AIFunction, AIParam, ai_function from .internal import ExceptionHandleResult, FunctionCallResult from .kani import Kani from .models import ChatMessage, ChatRole, FunctionCall, MessagePart, ToolCall -from .utils.cli import chat_in_terminal, chat_in_terminal_async +from .prompts import PromptPipeline +from .utils.cli import chat_in_terminal, chat_in_terminal_async, print_stream # declare that kani is also a namespace package __path__ = __import__("pkgutil").extend_path(__path__, __name__) diff --git a/kani/_version.py b/kani/_version.py new file mode 100644 index 0000000..bcf7f89 --- /dev/null +++ b/kani/_version.py @@ -0,0 +1 @@ +__version__ = "1.0.0rc0" diff --git a/kani/ai_function.py b/kani/ai_function.py index 4865a4b..70d7ccb 100644 --- a/kani/ai_function.py +++ b/kani/ai_function.py @@ -2,6 +2,7 @@ import functools import inspect import typing +import warnings from typing import Annotated from pydantic import validate_call @@ -55,6 +56,14 @@ def __init__( self.__module__ = inner.__module__ self.__doc__ = inner.__doc__ + # validation + if not self.desc: + warnings.warn( + f"The {self.name!r} @ai_function is missing a description. This may lead to request errors or poor" + ' performance by models. To add a description, add a """docstring""" beneath the signature or use' + ' @ai_function(desc="...").' + ) + async def __call__(self, *args, **kwargs): if self._inner_is_coro: return await self.inner(*args, **kwargs) @@ -62,8 +71,7 @@ async def __call__(self, *args, **kwargs): inner_partial = functools.partial(self.inner, *args, **kwargs) return await asyncio.get_event_loop().run_in_executor(None, inner_partial) - def create_json_schema(self) -> dict: - """Create a JSON schema representing this function's parameters as a JSON object.""" + def get_params(self) -> list[AIParamSchema]: # get list of params params = [] sig = inspect.signature(self.inner) @@ -89,9 +97,17 @@ def create_json_schema(self) -> dict: # get aiparam and add it to the list ai_param = get_aiparam(annotation) - params.append(AIParamSchema(name=name, t=type_hints[name], default=param.default, aiparam=ai_param)) + params.append( + AIParamSchema( + name=name, t=type_hints[name], default=param.default, aiparam=ai_param, inspect_param=param + ) + ) + return params + + def create_json_schema(self) -> dict: + """Create a JSON schema representing this function's parameters as a JSON object.""" # create a schema generator and generate - return create_json_schema(params) + return create_json_schema(self.get_params()) def ai_function( diff --git a/kani/engines/anthropic/engine.py b/kani/engines/anthropic/engine.py index b96cd8d..db318e2 100644 --- a/kani/engines/anthropic/engine.py +++ b/kani/engines/anthropic/engine.py @@ -1,28 +1,89 @@ +import functools +import itertools +import json import os import warnings +from typing import AsyncIterable from kani.ai_function import AIFunction -from kani.exceptions import MissingModelDependencies, PromptError -from kani.models import ChatMessage, ChatRole -from ..base import BaseEngine, Completion +from kani.exceptions import KaniException, MissingModelDependencies +from kani.models import ChatMessage, ChatRole, FunctionCall, ToolCall +from kani.prompts.pipeline import PromptPipeline +from ..base import BaseCompletion, BaseEngine, Completion try: from anthropic import AI_PROMPT, HUMAN_PROMPT, AsyncAnthropic - - # anthropic async client loads a small json file using anyio for some reason; hook into the underlying loader - # noinspection PyProtectedMember - from anthropic._tokenizers import sync_get_tokenizer except ImportError as e: raise MissingModelDependencies( 'The AnthropicEngine requires extra dependencies. Please install kani with "pip install kani[anthropic]".' ) from None CONTEXT_SIZES_BY_PREFIX = [ + ("claude-3", 200000), ("claude-2.1", 200000), ("", 100000), ] +# ==== pipe ==== +def content_transform(msg: ChatMessage): + # FUNCTION messages should look like: + # { + # "role": "user", + # "content": [ + # { + # "type": "tool_result", + # "tool_use_id": "toolu_01A09q90qw90lq917835lq9", + # "content": "65 degrees" + # } + # ] + # } + if msg.role == ChatRole.FUNCTION: + result = {"type": "tool_result", "tool_use_id": msg.tool_call_id, "content": msg.text} + + # tool call error + if msg.is_tool_call_error: + result["is_error"] = True + + return [result] + + # ASSISTANT messages with tool calls should look like: + # { + # "role": "assistant", + # "content": [ + # { + # "type": "text", + # "text": "I need to use the get_weather, and the user wants San Francisco, CA." + # }, + # { + # "type": "tool_use", + # "id": "toolu_01A09q90qw90lq917835lq9", + # "name": "get_weather", + # "input": {"location": "San Francisco, CA", "unit": "celsius"} + # } + # ] + # } + if msg.role == ChatRole.ASSISTANT and msg.tool_calls: + content = [{"type": "text", "text": str(part)} for part in msg.parts] + for tc in msg.tool_calls: + content.append({"type": "tool_use", "id": tc.id, "name": tc.function.name, "input": tc.function.kwargs}) + return content + + return msg.text + + +# assumes system messages are plucked before calling +CLAUDE_PIPELINE = ( + PromptPipeline() + .translate_role(role=ChatRole.SYSTEM, to=ChatRole.USER) + .merge_consecutive(role=ChatRole.USER, sep="\n") + .merge_consecutive(role=ChatRole.ASSISTANT, sep=" ") + .ensure_bound_function_calls() + .ensure_start(role=ChatRole.USER) + .conversation_dict(function_role="user", content_transform=content_transform) +) + + class AnthropicEngine(BaseEngine): """Engine for using the Anthropic API. @@ -32,13 +93,14 @@ class AnthropicEngine(BaseEngine): See https://docs.anthropic.com/claude/reference/selecting-a-model for a list of available models. """ - token_reserve = 4 # each prompt ends with \n\nAssistant: + # because we have to estimate tokens wildly and the ctx is so long we'll just reserve a bunch + token_reserve = 500 def __init__( self, api_key: str = None, - model: str = "claude-2.1", - max_tokens_to_sample: int = 512, + model: str = "claude-3-haiku-20240307", + max_tokens: int = 512, max_context_size: int = None, *, retry: int = 2, @@ -51,7 +113,7 @@ def __init__( :param api_key: Your Anthropic API key. By default, the API key will be read from the `ANTHROPIC_API_KEY` environment variable. :param model: The id of the model to use (e.g. "claude-2.1", "claude-instant-1.2"). - :param max_tokens_to_sample: The maximum number of tokens to sample at each generation (defaults to 450). + :param max_tokens: The maximum number of tokens to sample at each generation (defaults to 512). Generally, you should set this to the same number as your Kani's ``desired_response_tokens``. :param max_context_size: The maximum amount of tokens allowed in the chat prompt. If None, uses the given model's full context size. @@ -79,12 +141,59 @@ def __init__( api_key=api_key, max_retries=retry, base_url=api_base, default_headers=headers ) self.model = model - self.max_tokens_to_sample = max_tokens_to_sample + self.max_tokens = max_tokens self.max_context_size = max_context_size self.hyperparams = hyperparams - self.tokenizer = sync_get_tokenizer() + + # token counting - claude 3+ does not release tokenizer so we have to do heuristics and cache + self.token_cache = {} + if model.startswith("claude-2"): + # anthropic async client loads a json file using anyio for some reason; hook into the underlying loader + # noinspection PyProtectedMember + from anthropic._tokenizers import sync_get_tokenizer + + self.tokenizer = sync_get_tokenizer() + else: + # claude 3 tokenizer just... doesn't exist + # https://github.com/anthropics/anthropic-sdk-python/issues/375 pain + self.tokenizer = None + + # ==== token counting ==== + @staticmethod + def message_cache_key(message: ChatMessage): + # (role, content, tool calls) + + # we'll use msgpart identity for the hash here since we'll always have a ref as long as it's in a message + # history + hashable_content = tuple(part if isinstance(part, str) else id(part) for part in message.parts) + + # use (name, args) for tool calls + if message.tool_calls: + hashable_tool_calls = tuple((tc.function.name, tc.function.arguments) for tc in message.tool_calls) + else: + hashable_tool_calls = message.tool_calls + + return hash((message.role, hashable_content, hashable_tool_calls)) def message_len(self, message: ChatMessage) -> int: + # use cache + cache_key = self.message_cache_key(message) + if cache_key in self.token_cache: + return self.token_cache[cache_key] + + # use tokenizer + if self.tokenizer is not None: + return self._message_len_tokenizer(message) + + # panik - I guess we'll pretend that 4 chars = 1 token...? + n = len(message.role.value) + len(message.text) + if message.tool_calls: + for tc in message.tool_calls: + n += len(tc.function.name) + len(tc.function.arguments) + return n // 4 + + def _message_len_tokenizer(self, message): + # this only applies to claude-2 # human messages are prefixed with `\n\nHuman: ` and assistant with `\n\nAssistant:` if message.role == ChatRole.USER: mlen = 5 @@ -97,46 +206,137 @@ def message_len(self, message: ChatMessage) -> int: mlen += len(self.tokenizer.encode(message.text).ids) return mlen + def function_token_reserve(self, functions: list[AIFunction]) -> int: + if not functions: + return 0 + # wrap an inner impl to use lru_cache with frozensets + return self._function_token_reserve_impl(frozenset(functions)) + + @functools.lru_cache(maxsize=256) + def _function_token_reserve_impl(self, functions): + # panik, also assume len/4? + n = sum(len(f.name) + len(f.desc) + len(json.dumps(f.json_schema)) for f in functions) + return n // 4 + + # ==== requests ==== + def _get_client(self, functions): + if not functions: + return self.client + return self.client.beta.tools + @staticmethod - def build_prompt(messages: list[ChatMessage]): - # Claude prompts must start with a human message - first_human_idx = next((i for i, m in enumerate(messages) if m.role == ChatRole.USER), None) - if first_human_idx is None: - raise PromptError("Prompts to Anthropic models must contain at least one USER message.") + def _prepare_request(messages, functions): + kwargs = {} - # and make sure the system messages are included + # --- messages --- + # pluck system messages last_system_idx = next((i for i, m in enumerate(messages) if m.role != ChatRole.SYSTEM), None) if last_system_idx: - out = ["\n\n".join(m.text for m in messages[:last_system_idx])] - else: - out = [] + kwargs["system"] = "\n\n".join(m.text for m in messages[:last_system_idx]) + messages = messages[last_system_idx:] + + # enforce ordering and function call bindings + # and translate to dict spec + messages = CLAUDE_PIPELINE(messages) + + # merge FUNCTION, USER consecutives into one with multiple parts + prompt_msgs = [] + for role, group_msgs in itertools.groupby(messages, key=lambda m: m["role"]): + group_msgs = list(group_msgs) + # >1 consecutive user messages get merged + if role == "user" and len(group_msgs) > 1: + # turn str parts into {type: text, text: ...} + prompt_msg_content = [] + for msg in group_msgs: + if isinstance(msg["content"], str): + prompt_msg_content.append({"type": "text", "text": msg["content"]}) + else: + prompt_msg_content.append(msg["content"]) + # and output the final msg + prompt_msgs.append({"role": "user", "content": prompt_msg_content}) + # else send to output + else: + prompt_msgs.extend(group_msgs) + + # --- tools --- + if functions: + kwargs["tools"] = [ + {"name": f.name, "description": f.desc, "input_schema": f.json_schema} for f in functions + ] + + return kwargs, prompt_msgs - for idx, message in enumerate(messages[first_human_idx:]): - if message.role == ChatRole.USER: - out.append(f"{HUMAN_PROMPT} {message.text}") - elif message.role == ChatRole.ASSISTANT: - out.append(f"{AI_PROMPT} {message.text}") + def _translate_anthropic_message(self, message): + tool_calls = [] + text_parts = [] + for part in message.content: + if part.type == "text": + text_parts.append(part.text) + elif part.type == "tool_use": + fc = FunctionCall(name=part.name, arguments=json.dumps(part.input)) + tc = ToolCall(id=part.id, type="function", function=fc) + tool_calls.append(tc) else: - warnings.warn( - f"Encountered a {message.role} message in the middle of the prompt - Anthropic models expect an" - " optional SYSTEM message followed by alternating USER and ASSISTANT messages. Appending the" - " content to the prompt..." + raise KaniException( + f"The engine returned an unknown part: {part.type}. Please report this as an issue on GitHub with" + " instructions on how to reproduce this issue." ) - out.append(f"\n\n{message.text}") - return "".join(out) + AI_PROMPT + content = text_parts[0] if len(text_parts) == 1 else text_parts + kani_msg = ChatMessage.assistant(content, tool_calls=tool_calls or None) + + # also cache the message token len + cache_key = self.message_cache_key(kani_msg) + self.token_cache[cache_key] = message.usage.output_tokens + + return Completion( + message=kani_msg, + prompt_tokens=message.usage.input_tokens, + completion_tokens=message.usage.output_tokens, + ) async def predict( self, messages: list[ChatMessage], functions: list[AIFunction] | None = None, **hyperparams ) -> Completion: - prompt = self.build_prompt(messages) - completion = await self.client.completions.create( + kwargs, prompt_msgs = self._prepare_request(messages, functions) + client = self._get_client(functions) + + # --- completion --- + message = await client.messages.create( model=self.model, - max_tokens_to_sample=self.max_tokens_to_sample, - prompt=prompt, + max_tokens=self.max_tokens, + messages=prompt_msgs, + **kwargs, **self.hyperparams, **hyperparams, ) - return Completion(message=ChatMessage.assistant(completion.completion.strip())) + + # translate to kani + return self._translate_anthropic_message(message) + + async def stream( + self, messages: list[ChatMessage], functions: list[AIFunction] | None = None, **hyperparams + ) -> AsyncIterable[str | BaseCompletion]: + if functions: + warnings.warn("Claude 3 does not support streaming with function calling.") + async for elem in super().stream(messages, functions, **hyperparams): + yield elem + else: + # do the stream + kwargs, prompt_msgs = self._prepare_request(messages, functions) + + async with self.client.messages.stream( + model=self.model, + max_tokens=self.max_tokens, + messages=prompt_msgs, + **kwargs, + **self.hyperparams, + **hyperparams, + ) as stream: + async for text in stream.text_stream: + yield text + + message = await stream.get_final_message() + yield self._translate_anthropic_message(message) async def close(self): await self.client.close() diff --git a/kani/engines/base.py b/kani/engines/base.py index d31ab1d..71b1371 100644 --- a/kani/engines/base.py +++ b/kani/engines/base.py @@ -1,5 +1,6 @@ import abc import warnings +from collections.abc import AsyncIterable from kani.ai_function import AIFunction from kani.models import ChatMessage @@ -96,6 +97,35 @@ def function_token_reserve(self, functions: list[AIFunction]) -> int: ) return 0 + async def stream( + self, messages: list[ChatMessage], functions: list[AIFunction] | None = None, **hyperparams + ) -> AsyncIterable[str | BaseCompletion]: + """ + Optional: Stream a completion from the engine, token-by-token. + + This method's signature is the same as :meth:`predict`. + + This method should yield strings as an asynchronous iterable. + + Optionally, this method may also yield a :class:`.BaseCompletion`. If it does, it MUST be the last item + yielded by this method. + + If an engine does not implement streaming, this method will yield the entire text of the completion in a single + chunk by default. + + :param messages: The messages in the current chat context. ``sum(message_len(m) for m in messages)`` is + guaranteed to be less than max_context_size. + :param functions: The functions the LM is allowed to call. + :param hyperparams: Any additional parameters to pass to the engine. + """ + warnings.warn( + f"This {type(self).__name__} does not implement streaming. This stream will yield the entire completion in" + " one single chunk." + ) + completion = await self.predict(messages, functions, **hyperparams) + yield completion.message.text + yield completion + async def close(self): """Optional: Clean up any resources the engine might need.""" pass diff --git a/kani/engines/ctransformers/__init__.py b/kani/engines/ctransformers/__init__.py deleted file mode 100644 index 260e1c1..0000000 --- a/kani/engines/ctransformers/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .base import CTransformersEngine diff --git a/kani/engines/ctransformers/base.py b/kani/engines/ctransformers/base.py deleted file mode 100644 index 4ea9e99..0000000 --- a/kani/engines/ctransformers/base.py +++ /dev/null @@ -1,102 +0,0 @@ -import abc - -from kani.ai_function import AIFunction -from kani.exceptions import MissingModelDependencies -from kani.models import ChatMessage -from ..base import BaseEngine, Completion - -try: - from ctransformers import AutoModelForCausalLM -except ImportError as e: - raise MissingModelDependencies( - 'The CTransformersEngine requires extra dependencies. Please install kani with "pip install' - ' kani[ctransformers]".' - ) from None - - -class CTransformersEngine(BaseEngine, abc.ABC): - """ - This class implements the main decoding logic for any GGML model based on a pretrained ``AutoModelForCausalLM``. - - **GPU Support** - - In order to load a model on GPU, the underlying GGML model must support CUDA. To see a list of supported models, - see `this table `_. - - .. caution:: - If your model supports CUDA, you must also install additional CUDA dependencies. Run - ``pip install 'ctransformers[cuda]'`` if you have not installed CUDA dependencies elsewhere (e.g. through - torch). - - To load some or all of the model layers on GPU, pass ``gpu_layers=...`` in the ``model_load_kwargs``. - """ - - def __init__( - self, - model_id: str, - model_file: str = None, - max_context_size: int = 1024, - model_load_kwargs: dict = None, - **hyperparams, - ): - """ - :param model_id: The ID of the model to load from HuggingFace or locally. - :param model_file: The file of the model to load from HuggingFace repo or locally. - :param max_context_size: The context size of the model. - :param model_load_kwargs: Additional arguments to pass to ``AutoModelForCausalLM.from_pretrained()``. - See `this link `_ for more info. - :param hyperparams: Additional arguments to supply the model during generation. - """ - - if model_load_kwargs is None: - model_load_kwargs = {} - - self.model_id = model_id - self.model_file = model_file - self.max_context_size = max_context_size - model_load_kwargs.setdefault("context_length", self.max_context_size) - self.model = AutoModelForCausalLM.from_pretrained(model_id, model_file=model_file, **model_load_kwargs) - self.hyperparams = hyperparams - - @abc.abstractmethod - def build_prompt(self, messages: list[ChatMessage], functions: list[AIFunction] | None = None) -> str | list[int]: - """Given the list of messages from kani, build either a single string representing the prompt for the model, - or build the token tensor.""" - raise NotImplementedError - - async def predict( - self, messages: list[ChatMessage], functions: list[AIFunction] | None = None, **hyperparams - ) -> Completion: - """ - Given the current context of messages and available functions, get the next predicted chat message from the LM. - - :param messages: The messages in the current chat context. ``sum(message_len(m) for m in messages)`` is - guaranteed to be less than max_context_size. - :param functions: The functions the LM is allowed to call. - :param hyperparams: Any additional parameters to pass to GenerationMixin.generate(). (See - https://github.com/marella/ctransformers#method-llmgenerate) - """ - prompt = self.build_prompt(messages, functions) - if isinstance(prompt, str): - # prompt str to tokens - input_toks = self.model.tokenize(prompt) - input_len = len(input_toks) - elif isinstance(prompt, list): - input_toks = prompt - input_len = len(input_toks) - else: - raise TypeError("build_prompt should either return a str or a list[int].") - - output_toks = [] - output_len = 0 - for tok in self.model.generate(input_toks, **hyperparams): - output_toks.append(tok) - output_len += 1 - # ctransformers does not automatically stop at end of context length - # (e.g. https://github.com/zhudotexe/kani/actions/runs/6152842183/job/16695721588) - # so we force it to stop if we would bust the context length - if input_len + output_len >= self.max_context_size: - break - content = self.model.detokenize(output_toks).strip() - - return Completion(ChatMessage.assistant(content), prompt_tokens=input_len, completion_tokens=output_len) diff --git a/kani/engines/ctransformers/llama2.py b/kani/engines/ctransformers/llama2.py deleted file mode 100644 index 9df7a6e..0000000 --- a/kani/engines/ctransformers/llama2.py +++ /dev/null @@ -1,98 +0,0 @@ -import functools - -from kani.ai_function import AIFunction -from kani.models import ChatMessage, ChatRole -from .base import CTransformersEngine -from .. import llama2_prompt - - -class LlamaCTransformersEngine(CTransformersEngine): - """Implementation of LLaMA v2 using ctransformers. - - You may also use other models that use the LLaMA v2 prompt. - These models are converted from the original models in pytorch format to GGML format and can be found on - huggingface. - You can also download these models locally and pass the local path as an argument to the class. - - The huggingface GGML repos generally have multiple models in them (of different quantization levels), - so you can choose the model depending inference speed, memory, and quality tradeoffs depending on the quantization - performed on the model. - - A specific GGML model can be used from huggingface by passing the model id and the model file - (e.g. https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGML#provided-files) to - `LlamaCTransformersEngine` and the model will be automatically downloaded and placed locally. - - Model IDs: - - - ``TheBloke/Llama-2-7B-Chat-GGML`` - - ``TheBloke/Llama-2-13B-chat-GGML`` - - ``TheBloke/Llama-2-70B-Chat-GGML`` - - .. code-block:: python - - engine = LlamaCTransformersEngine( - "TheBloke/Llama-2-7B-Chat-GGML", model_file="llama-2-7b-chat.ggmlv3.q5_K_M.bin" - ) - ai = Kani(engine) - - .. tip:: The ``LlamaCTransformersEngine`` is a drop-in replacement for the :class:`.LlamaEngine`. - - **GPU Support** - - In order to load a model on GPU, the underlying GGML model must support CUDA. To see a list of supported models, - see `this table `_. - - .. caution:: - If your model supports CUDA, you must also install additional CUDA dependencies. Run - ``pip install 'ctransformers[cuda]'`` if you have not installed CUDA dependencies elsewhere (e.g. through - torch). - - To load some or all of the model layers on GPU, pass ``gpu_layers=...`` in the ``model_load_kwargs``: - - .. code-block:: python - :emphasize-lines: 4 - - engine = LlamaCTransformersEngine( - "TheBloke/Llama-2-7B-Chat-GGML", - model_file="llama-2-7b-chat.ggmlv3.q5_K_M.bin", - model_load_kwargs=dict(gpu_layers=50), - ) - ai = Kani(engine) - - .. seealso:: https://github.com/marella/ctransformers/tree/main#method-llmgenerate for generation parameters - """ - - def __init__( - self, - model_id: str = "TheBloke/Llama-2-7B-Chat-GGML", - model_file: str = "llama-2-7b-chat.ggmlv3.q5_K_M.bin", - *args, - **kwargs, - ): - """ - :param model_id: The ID of the model to load from HuggingFace. - :param model_file: The file of the model to load from HuggingFace repo or locally. - :param max_context_size: The context size of the model. - :param model_load_kwargs: Additional arguments to pass to ``AutoModelForCausalLM.from_pretrained()``. - :param hyperparams: Additional arguments to supply the model during generation. - """ - - kwargs.setdefault("max_context_size", 4096) # LLaMA has 4096 token window - super().__init__(model_id, model_file, *args, **kwargs) - - def build_prompt(self, messages: list[ChatMessage], functions: list[AIFunction] | None = None) -> list[int]: - tokenize = functools.partial(self.model.tokenize, add_bos_token=False) - return llama2_prompt.build( - messages, tokenize=tokenize, bos_token_id=self.model.bos_token_id, eos_token_id=self.model.eos_token_id - ) - - def message_len(self, message: ChatMessage) -> int: - # https://github.com/facebookresearch/llama/blob/main/llama/generation.py#L212 - if message.role == ChatRole.USER: - # [INST] {} [/INST] -> 7 - return len(self.model.tokenize(message.text)) + 7 - elif message.role == ChatRole.ASSISTANT: - # {} -> 2 - return len(self.model.tokenize(f" {message.text} ")) + 2 - # [INST] <>\n{}\n<>\n\n [/INST] -> 20 - return len(self.model.tokenize(message.text)) + 20 diff --git a/kani/engines/httpclient.py b/kani/engines/httpclient.py index 39f8844..ff74ed9 100644 --- a/kani/engines/httpclient.py +++ b/kani/engines/httpclient.py @@ -2,13 +2,24 @@ import asyncio import logging -import aiohttp +from ..exceptions import HTTPException, HTTPStatusException, HTTPTimeout, MissingModelDependencies -from ..exceptions import HTTPException, HTTPStatusException, HTTPTimeout +try: + import aiohttp +except ImportError: + raise MissingModelDependencies( + f"The BaseClient has been deprecated as of v1.0.0. We recommend using `httpx` for HTTP requests. To continue" + f" using the BaseClient (not recommended), use `pip install aiohttp`." + ) class BaseClient(abc.ABC): - """aiohttp-based HTTP client to help implement HTTP-based engines.""" + """aiohttp-based HTTP client to help implement HTTP-based engines. + + .. deprecated:: 1.0.0 + We recommend using `httpx.AsyncClient `_ instead. This aiohttp-based client + will be removed in a future version. + """ SERVICE_BASE: str """The base route of the HTTP API.""" diff --git a/kani/engines/huggingface/base.py b/kani/engines/huggingface/base.py index 1c12a4d..9bcea20 100644 --- a/kani/engines/huggingface/base.py +++ b/kani/engines/huggingface/base.py @@ -1,26 +1,33 @@ -import abc +import functools +import logging +import warnings +from threading import Thread +from typing import AsyncIterable from kani.ai_function import AIFunction from kani.exceptions import MissingModelDependencies from kani.models import ChatMessage -from ..base import BaseEngine, Completion +from kani.prompts.pipeline import PromptPipeline +from ..base import BaseCompletion, BaseEngine, Completion try: import torch - from transformers import AutoModelForCausalLM, AutoTokenizer + from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer except ImportError: raise MissingModelDependencies( 'The HuggingEngine requires extra dependencies. Please install kani with "pip install kani[huggingface]". ' "You will also need to install PyTorch manually." ) from None +log = logging.getLogger(__name__) -class HuggingEngine(BaseEngine, abc.ABC): + +class HuggingEngine(BaseEngine): """Base engine for all HuggingFace text-generation models. This class implements the main decoding logic for any HuggingFace model based on a pretrained - ``AutoModelForCausalLM``. To implement a new HuggingFace model, just implement :meth:`~.HuggingEngine.build_prompt` - and :meth:`~.BaseEngine.message_len` for the specified model. + ``AutoModelForCausalLM``. As most models use model-specific chat templates, this base class accepts a + :class:`.PromptPipeline` to translate kani ChatMessages into a model-specific string. **GPU Support** @@ -33,8 +40,10 @@ class HuggingEngine(BaseEngine, abc.ABC): def __init__( self, model_id: str, - max_context_size: int, - use_auth_token=None, + max_context_size: int = None, + prompt_pipeline: PromptPipeline[str | torch.Tensor] = None, + *, + token=None, device: str | None = None, tokenizer_kwargs: dict = None, model_load_kwargs: dict = None, @@ -42,8 +51,10 @@ def __init__( ): """ :param model_id: The ID of the model to load from HuggingFace. - :param max_context_size: The context size of the model. - :param use_auth_token: The Hugging Face access token (for gated models). Pass True to load from huggingface-cli. + :param max_context_size: The context size of the model. If not given, will be set from the model's config. + :param prompt_pipeline: The pipeline to translate a list of kani ChatMessages into the model-specific chat + format (see :class:`.PromptPipeline`). + :param token: The Hugging Face access token (for gated models). Pass True to load from huggingface-cli. :param device: The hardware device to use. If not specified, uses CUDA if available; otherwise uses CPU. :param tokenizer_kwargs: Additional arguments to pass to ``AutoTokenizer.from_pretrained()``. :param model_load_kwargs: Additional arguments to pass to ``AutoModelForCausalLM.from_pretrained()``. @@ -54,47 +65,96 @@ def __init__( if model_load_kwargs is None: model_load_kwargs = {} - tokenizer_kwargs.setdefault("use_auth_token", use_auth_token) - model_load_kwargs.setdefault("use_auth_token", use_auth_token) + tokenizer_kwargs.setdefault("token", hyperparams.get("use_auth_token", token)) + model_load_kwargs.setdefault("token", hyperparams.pop("use_auth_token", token)) self.model_id = model_id self.max_context_size = max_context_size + self.pipeline = prompt_pipeline + self.tokenizer = AutoTokenizer.from_pretrained(model_id, **tokenizer_kwargs) self.model = AutoModelForCausalLM.from_pretrained(model_id, **model_load_kwargs) self.hyperparams = hyperparams + # ensure model is on correct device if device is None: - device = "cuda" if torch.has_cuda else "cpu" + device = "cuda" if torch.backends.cuda.is_built() else "cpu" self.device = device if self.model.device.type != self.device: self.model.to(device) - @abc.abstractmethod + # token counting stuff + # try and infer max context size from the model config if not specified + if self.max_context_size is None: + self.max_context_size = getattr( + self.model.config, + "model_max_len", + getattr(self.model.config, "max_position_embeddings", None), + ) + log.debug(f"Inferred max context size: {self.max_context_size}") + + if self.max_context_size is None: + raise ValueError( + "Could not infer the model's max context size from the config. Please pass the `max_context_size`" + " arg." + ) + elif self.max_context_size > 1e20: + warnings.warn( + f"The inferred max context size of this model is extremely large ({self.max_context_size}). This" + " may mean that the model has not configured their model_max_len correctly (or you are still using" + " my code in 2050). Please pass the `max_context_size` arg to use the correct model size." + ) + + # infer the token reserve from the pipeline + if self.token_reserve == 0 and self.pipeline: + self.token_reserve = self._infer_token_reserve() + + def _infer_token_reserve(self): + """If token_reserve is not set and we have a pipeline, infer it.""" + prompt = self.pipeline.execute([], for_measurement=True) + if isinstance(prompt, torch.Tensor): + return len(prompt[0]) + # prompt str to tokens + tokenized = self.tokenizer.encode(prompt, add_special_tokens=False) + return len(tokenized) + + def message_len(self, message: ChatMessage) -> int: + # default concrete base behaviour: + if self.pipeline is None: + raise NotImplementedError( + "You must pass a prompt_pipeline to the HuggingEngine to use it as a non-abstract class." + ) + prompt = self.pipeline.execute([message], for_measurement=True) + if isinstance(prompt, torch.Tensor): + return len(prompt[0]) + # prompt str to tokens + tokenized = self.tokenizer.encode(prompt, add_special_tokens=False) + return len(tokenized) + def build_prompt( self, messages: list[ChatMessage], functions: list[AIFunction] | None = None ) -> str | torch.Tensor: - """Given the list of messages from kani, build either a single string representing the prompt for the model, - or build the token tensor.""" - raise NotImplementedError - - async def predict( - self, messages: list[ChatMessage], functions: list[AIFunction] | None = None, **hyperparams - ) -> Completion: """ - Given the current context of messages and available functions, get the next predicted chat message from the LM. + Given the list of messages from kani, build either a single string representing the prompt for the model, + or build the token tensor. - :param messages: The messages in the current chat context. ``sum(message_len(m) for m in messages)`` is - guaranteed to be less than max_context_size. - :param functions: The functions the LM is allowed to call. - :param hyperparams: Any additional parameters to pass to GenerationMixin.generate(). (See - https://huggingface.co/docs/transformers/main_classes/text_generation) + The default behaviour is to call the supplied pipeline. """ - prompt = self.build_prompt(messages, functions) + if self.pipeline is None: + raise NotImplementedError( + "You must pass a prompt_pipeline to the HuggingEngine to use it as a non-abstract class." + ) + prompt = self.pipeline(messages) + log.debug(f"BUILT PROMPT: {prompt}") + return prompt + + def _get_generate_args(self, prompt: str | torch.Tensor, **hyperparams): + """Internal method to build common params for the generate call""" if isinstance(prompt, str): # prompt str to tokens - tokenized = self.tokenizer(prompt, return_tensors="pt", return_length=True) - input_len = int(tokenized.length) - input_toks = tokenized.input_ids + tokenized = self.tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt") + input_toks = tokenized + input_len = len(tokenized[0]) elif isinstance(prompt, torch.Tensor): input_toks = prompt input_len = len(input_toks[0]) @@ -106,6 +166,23 @@ async def predict( # set up hyperparams for HF decode hyperparams = {**self.hyperparams, **hyperparams} hyperparams.setdefault("max_length", self.max_context_size) # by default HF sets this to 20, which is too small + return input_toks, input_len, hyperparams + + async def predict( + self, messages: list[ChatMessage], functions: list[AIFunction] | None = None, **hyperparams + ) -> Completion: + """ + Given the current context of messages and available functions, get the next predicted chat message from the LM. + + :param messages: The messages in the current chat context. ``sum(message_len(m) for m in messages)`` is + guaranteed to be less than max_context_size. + :param functions: The functions the LM is allowed to call. + :param hyperparams: Any additional parameters to pass to GenerationMixin.generate(). (See + https://huggingface.co/docs/transformers/main_classes/text_generation) + """ + prompt = self.build_prompt(messages, functions) + input_toks, input_len, hyperparams = self._get_generate_args(prompt, **hyperparams) + # run it through the model output = self.model.generate(input_toks, **hyperparams) # decode to tokens @@ -114,3 +191,39 @@ async def predict( return Completion( ChatMessage.assistant(content), prompt_tokens=input_len, completion_tokens=len(output[0]) - (input_len + 1) ) + + async def stream( + self, + messages: list[ChatMessage], + functions: list[AIFunction] | None = None, + *, + streamer_timeout=None, + **hyperparams, + ) -> AsyncIterable[str | BaseCompletion]: + """ + Given the current context of messages and available functions, get the next predicted chat message from the LM. + + :param messages: The messages in the current chat context. ``sum(message_len(m) for m in messages)`` is + guaranteed to be less than max_context_size. + :param functions: The functions the LM is allowed to call. + :param streamer_timeout: The maximum number of seconds to wait for the next token when streaming. + :param hyperparams: Any additional parameters to pass to GenerationMixin.generate(). (See + https://huggingface.co/docs/transformers/main_classes/text_generation) + """ + prompt = self.build_prompt(messages, functions) + input_toks, input_len, hyperparams = self._get_generate_args(prompt, **hyperparams) + streamer = TextIteratorStreamer( + self.tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=streamer_timeout + ) + + # run it through the model in another thread so that we can get the tokens in this thread + generate_func = functools.partial(self.model.generate, input_toks, streamer=streamer, **hyperparams) + thread = Thread(target=generate_func) + thread.start() + + # then wait for tokens from the task + for token in streamer: + yield token + + # finally clean up the thread + thread.join() diff --git a/kani/engines/huggingface/cohere.py b/kani/engines/huggingface/cohere.py new file mode 100644 index 0000000..8e2e28e --- /dev/null +++ b/kani/engines/huggingface/cohere.py @@ -0,0 +1,321 @@ +import functools +import json +import logging +import re +from collections.abc import AsyncIterable +from threading import Thread + +from kani.ai_function import AIFunction +from kani.exceptions import MissingModelDependencies +from kani.models import ChatMessage, ChatRole, FunctionCall, ToolCall +from kani.prompts.impl.cohere import ( + COMMAND_R_PIPELINE, + DEFAULT_PREAMBLE, + DEFAULT_TASK, + DEFAULT_TOOL_INSTRUCTIONS, + DEFAULT_TOOL_PROMPT, + build_rag_pipeline, + build_tool_pipeline, + function_prompt, + tool_call_formatter, +) +from .base import HuggingEngine +from ..base import Completion + +try: + import torch + from torch import tensor + from transformers import TextIteratorStreamer +except ImportError: + raise MissingModelDependencies( + 'The CommandREngine requires extra dependencies. Please install kani with "pip install' + " 'kani[huggingface]'\". You will also need to install PyTorch manually." + ) from None + +log = logging.getLogger(__name__) + + +class CommandREngine(HuggingEngine): + """Implementation of Command R (35B) and Command R+ (104B) using huggingface transformers. + + Model IDs: + + - ``CohereForAI/c4ai-command-r-v01`` + - ``CohereForAI/c4ai-command-r-plus`` + + **GPU Support** + + By default, the CommandREngine loads the model on GPU if CUDA is detected on your system. To override the device + the model is loaded on, pass ``device="cpu|cuda"`` to the constructor. + + **Usage** + + .. code-block:: python + + engine = CommandREngine("CohereForAI/c4ai-command-r-v01") + ai = KaniWithFunctions(engine) + + **Configuration** + + Command R has many configurations that enable function calling and/or RAG, and it is poorly documented exactly + how certain prompts affect the model. In this implementation, we default to the Cohere-supplied "preamble" if + function definitions are supplied, and assume that we pass every generated function call and results each turn. + + When generating the result of a tool call turn, this implementation does NOT request the model to generate + citations by default (unlike the Cohere API). You can enable citations by setting the ``rag_prompt_instructions`` + parameter to ``DEFAULT_RAG_INSTRUCTIONS_ACC`` or ``DEFAULT_RAG_INSTRUCTIONS_FAST`` (imported from + ``kani.prompts.impl.cohere``). + + See the constructor's available parameters for more information. + + .. caution:: + + Command R requires ``transformers>=4.39.1`` as a dependency. If you see warnings about a missing + ``CohereTokenizerFast``, please update your version with ``pip install transformers>=4.39.1``. + + .. seealso:: https://huggingface.co/CohereForAI/c4ai-command-r-v01 + + .. tip:: See :ref:`4b_quant` for information about loading a quantized model for lower memory usage. + """ + + token_reserve = 200 # generous reserve due to large ctx size and weird 3-mode prompt + + def __init__( + self, + model_id: str = "CohereForAI/c4ai-command-r-v01", + *args, + tool_prompt_include_function_calls=True, + tool_prompt_include_function_results=True, + tool_prompt_instructions=DEFAULT_TOOL_INSTRUCTIONS, + rag_prompt_include_function_results=True, + rag_prompt_instructions=None, + **kwargs, + ): + """ + :param model_id: The ID of the model to load from HuggingFace. + :param max_context_size: The context size of the model (defaults to Command R's size of 128k). + :param device: The hardware device to use. If not specified, uses CUDA if available; otherwise uses CPU. + :param tokenizer_kwargs: Additional arguments to pass to ``AutoTokenizer.from_pretrained()``. + :param model_load_kwargs: Additional arguments to pass to ``AutoModelForCausalLM.from_pretrained()``. + :param tool_prompt_include_function_calls: Whether to include previous turns' function calls or just the model's + answers. + :param tool_prompt_include_function_results: Whether to include the results of previous turns' function calls in + the context. + :param tool_prompt_instructions: The system prompt to send just before the model's generation turn that includes + instructions on the format to generate tool calls in. Generally you shouldn't change this. + :param rag_prompt_include_function_results: Whether to include the results of previous turns' function calls in + the context. + :param rag_prompt_instructions: The system prompt to send just before the model's generation turn that includes + instructions on the format to generate the result in. Can be None to only generate a model turn. Defaults + to ``None`` to for maximum interoperability between models. Options: + + - ``from kani.prompts.impl.cohere import DEFAULT_RAG_INSTRUCTIONS_ACC`` + - ``from kani.prompts.impl.cohere import DEFAULT_RAG_INSTRUCTIONS_FAST`` + - ``None`` (default) + - another user-supplied string + :param hyperparams: Additional arguments to supply the model during generation. + """ + kwargs.setdefault("max_context_size", 128000) + super().__init__(model_id, *args, **kwargs) + + self._tool_prompt_include_function_calls = tool_prompt_include_function_calls + + self._tool_pipeline = build_tool_pipeline( + include_function_calls=tool_prompt_include_function_calls, + include_function_results=tool_prompt_include_function_results, + tool_instructions=tool_prompt_instructions, + ) + self._rag_pipeline = build_rag_pipeline( + include_previous_results=rag_prompt_include_function_results, + rag_instructions=rag_prompt_instructions, + ) + + # ==== token counting ==== + def message_len(self, message: ChatMessage) -> int: + # prompt str to tokens + if message.text: + tokenized = self.tokenizer.encode(message.text, add_special_tokens=False) + else: + tokenized = 0 + + # worst-case function calls if we have them + if self._tool_prompt_include_function_calls and message.role == ChatRole.ASSISTANT: + func_body = tool_call_formatter(message) + tokenized = self.tokenizer.encode(func_body, add_special_tokens=False) + return len(tokenized) + 3 + + # + if message.role == ChatRole.FUNCTION: + return len(tokenized) + 12 + + return len(tokenized) + 3 + + def function_token_reserve(self, functions: list[AIFunction]) -> int: + # include the additional default system prompt tokens here plus the directly_answer tool + default_prompt_tokens = 325 + function_text = "\n\n".join(map(function_prompt, functions)) + function_tokens = len(self.tokenizer.encode(function_text, add_special_tokens=False)) + return function_tokens + default_prompt_tokens + + # ==== prompt ==== + def build_prompt( + self, messages: list[ChatMessage], functions: list[AIFunction] | None = None + ) -> str | torch.Tensor: + # no functions: we can just do the default simple format + if not functions: + prompt = COMMAND_R_PIPELINE(messages) + log.debug(f"PROMPT: {prompt}") + return prompt + + # if we do have functions things get wacky + # is the last message a FUNCTION? if so, we need to use the RAG template + if messages and messages[-1].role == ChatRole.FUNCTION: + prompt = self._build_prompt_rag(messages) + log.debug(f"RAG PROMPT: {prompt}") + return prompt + + # otherwise use the TOOL template + prompt = self._build_prompt_tools(messages, functions) + log.debug(f"TOOL PROMPT: {prompt}") + return prompt + + def _build_prompt_tools(self, messages: list[ChatMessage], functions: list[AIFunction]): + # get the function definitions + function_text = "\n\n".join(map(function_prompt, functions)) + tool_prompt = DEFAULT_TOOL_PROMPT.format(user_functions=function_text) + + # wrap the initial system message, if any + messages = messages.copy() + if messages and messages[0].role == ChatRole.SYSTEM: + messages[0] = messages[0].copy_with(content=DEFAULT_PREAMBLE + messages[0].text + tool_prompt) + # otherwise add it in + else: + messages.insert(0, ChatMessage.system(DEFAULT_PREAMBLE + DEFAULT_TASK + tool_prompt)) + + return self._tool_pipeline(messages) + + def _build_prompt_rag(self, messages: list[ChatMessage]): + # wrap the initial system message, if any + messages = messages.copy() + if messages and messages[0].role == ChatRole.SYSTEM: + messages[0] = messages[0].copy_with(content=DEFAULT_PREAMBLE + messages[0].text) + # otherwise add it in + else: + messages.insert(0, ChatMessage.system(DEFAULT_PREAMBLE + DEFAULT_TASK)) + + return self._rag_pipeline(messages) + + # ==== generate ==== + def _generate(self, input_toks, input_len, hyperparams, functions): + """Generate and return a completion (may be a directly_answer call).""" + # run it through the model + output = self.model.generate(input_toks, **hyperparams) + # decode to tokens + # the completion shouldn't include the prompt or stop token + content = self.tokenizer.decode(output[0][input_len:-1]).strip() + completion_tokens = len(output[0]) - (input_len + 1) + log.debug(f"COMPLETION: {content}") + + # if we have tools, possibly parse out the Action + tool_calls = None + if functions and (action_json := re.match(r"Action:\s*```json\n(.+)\n```", content, re.IGNORECASE | re.DOTALL)): + actions = json.loads(action_json.group(1)) + + # translate back to kani spec + tool_calls = [] + for action in actions: + tool_name = action["tool_name"] + tool_args = json.dumps(action["parameters"]) + tool_call = ToolCall.from_function_call(FunctionCall(name=tool_name, arguments=tool_args)) + tool_calls.append(tool_call) + + content = None + log.debug(f"PARSED TOOL CALLS: {tool_calls}") + + return Completion( + ChatMessage.assistant(content, tool_calls=tool_calls), + prompt_tokens=input_len, + completion_tokens=completion_tokens, + ) + + async def _stream(self, input_toks, hyperparams, *, streamer_timeout=None) -> AsyncIterable[str | Completion]: + """Low-level stream yielder (kind of weird duplicated code but it's ok)""" + # noinspection DuplicatedCode + streamer = TextIteratorStreamer( + self.tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=streamer_timeout + ) + + # run it through the model in another thread so that we can get the tokens in this thread + generate_func = functools.partial(self.model.generate, input_toks, streamer=streamer, **hyperparams) + thread = Thread(target=generate_func) + thread.start() + + # then wait for tokens from the task + for token in streamer: + log.debug(f"STREAM DELTA: {token}") + yield token + + # finally clean up the thread + thread.join() + + async def predict( + self, messages: list[ChatMessage], functions: list[AIFunction] | None = None, **hyperparams + ) -> Completion: + prompt = self.build_prompt(messages, functions) + input_toks, input_len, hyperparams = self._get_generate_args(prompt, **hyperparams) + completion = self._generate(input_toks, input_len, hyperparams, functions) + + tool_calls = completion.message.tool_calls or [] + # if the model generated multiple calls that happen to include a directly_answer, remove the directly_answer + if len(tool_calls) > 1: + completion.message.tool_calls = [ + tc for tc in completion.message.tool_calls if tc.function.name != "directly_answer" + ] + # if tool says directly answer, call again with the rag pipeline (but no result) + elif len(tool_calls) == 1 and tool_calls[0].function.name == "directly_answer": + log.debug("GOT DIRECTLY_ANSWER, REPROMPTING RAG...") + prompt = self._build_prompt_rag(messages) + log.debug(f"RAG PROMPT: {prompt}") + input_toks, input_len, hyperparams = self._get_generate_args(prompt, **hyperparams) + completion = self._generate(input_toks, input_len, hyperparams, functions) + # otherwise don't touch it + return completion + + async def stream( + self, + messages: list[ChatMessage], + functions: list[AIFunction] | None = None, + *, + streamer_timeout=None, + **hyperparams, + ) -> AsyncIterable[str | Completion]: + # if we have functions things get weird + # if we have tools and the last turn is not FUNCTION, no-stream the first round to get the Action + if functions and not (messages and messages[-1].role == ChatRole.FUNCTION): + prompt = self.build_prompt(messages, functions) + input_toks, input_len, hyperparams = self._get_generate_args(prompt, **hyperparams) + completion = self._generate(input_toks, input_len, hyperparams, functions) + + tool_calls = completion.message.tool_calls or [] + # if the model generated multiple calls that happen to include a directly_answer, remove the directly_answer + if len(tool_calls) > 1: + completion.message.tool_calls = [ + tc for tc in completion.message.tool_calls if tc.function.name != "directly_answer" + ] + # if tool says directly answer, stream with the rag pipeline (but no result) + elif len(tool_calls) == 1 and tool_calls[0].function.name == "directly_answer": + log.debug("GOT DIRECTLY_ANSWER, REPROMPTING RAG...") + prompt = self._build_prompt_rag(messages) + log.debug(f"RAG PROMPT: {prompt}") + input_toks, input_len, hyperparams = self._get_generate_args(prompt, **hyperparams) + async for elem in self._stream(input_toks, hyperparams, streamer_timeout=streamer_timeout): + yield elem + # otherwise yield as normal + else: + if completion.message.text: + yield completion.message.text + yield completion + # otherwise stream as normal + else: + async for elem in super().stream(messages, functions, streamer_timeout=streamer_timeout, **hyperparams): + yield elem diff --git a/kani/engines/huggingface/llama2.py b/kani/engines/huggingface/llama2.py index 8b8594b..d0ec6f9 100644 --- a/kani/engines/huggingface/llama2.py +++ b/kani/engines/huggingface/llama2.py @@ -1,11 +1,6 @@ -import functools - -from kani.ai_function import AIFunction from kani.exceptions import MissingModelDependencies -from kani.models import ChatMessage, ChatRole +from kani.prompts.impl import LLAMA2_PIPELINE from .base import HuggingEngine -from .. import llama2_prompt -from ..llama2_prompt import B_INST, B_SYS, E_INST, E_SYS try: import sentencepiece @@ -19,7 +14,8 @@ class LlamaEngine(HuggingEngine): - """Implementation of LLaMA v2 using huggingface transformers. + r""" + Implementation of LLaMA v2 using huggingface transformers. You may also use the 13b, 70b, or other LLaMA models that use the LLaMA prompt by passing the HuggingFace model ID to the initializer. @@ -37,6 +33,13 @@ class LlamaEngine(HuggingEngine): By default, the HuggingEngine loads the model on GPU if CUDA is detected on your system. To override the device the model is loaded on, pass ``device="cpu|cuda"`` to the constructor. + **Usage** + + .. code-block:: python + + engine = LlamaEngine("meta-llama/Llama-2-7b-chat-hf", use_auth_token=True) + ai = Kani(engine) + .. attention:: You will need to accept Meta's license in order to download the LLaMA v2 weights. Visit @@ -49,73 +52,44 @@ class LlamaEngine(HuggingEngine): .. tip:: See :ref:`4b_quant` for information about loading a quantized model for lower memory usage. - .. code-block:: python - - engine = LlamaEngine("meta-llama/Llama-2-7b-chat-hf", use_auth_token=True) - ai = Kani(engine) + .. tip:: + + This engine is equivalent to the following usage of the base :class:`.HuggingEngine`. + + .. code-block:: python + + LLAMA2_PIPELINE = ( + PromptPipeline() + .wrap(role=ChatRole.SYSTEM, prefix="<>\n", suffix="\n<>\n") + .translate_role(role=ChatRole.SYSTEM, to=ChatRole.USER) + .merge_consecutive(role=ChatRole.USER, sep="\n") + .merge_consecutive(role=ChatRole.ASSISTANT, sep=" ") + .conversation_fmt( + user_prefix="[INST] ", + user_suffix=" [/INST]", + assistant_prefix=" ", + assistant_suffix=" ", + assistant_suffix_if_last="", + ) + ) + + engine = HuggingEngine( + "meta-llama/Llama-2-7b-chat-hf", + prompt_pipeline=LLAMA2_PIPELINE + ) + + See :class:`.PromptPipeline` for more information on reusable prompt pipelines. """ - def __init__(self, model_id: str = "meta-llama/Llama-2-7b-chat-hf", *args, strict=False, **kwargs): + def __init__(self, model_id: str = "meta-llama/Llama-2-7b-chat-hf", *args, **kwargs): """ :param model_id: The ID of the model to load from HuggingFace. - :param use_auth_token: The Hugging Face access token (for gated models). Pass True to load from huggingface-cli. - :param strict: Whether to enforce the same prompt constraints as Meta's LLaMA (always starting with - system, alternating user/assistant). This is recommended for using the base LLaMA model. + :param token: The Hugging Face access token (for gated models). Pass True to load from huggingface-cli. :param max_context_size: The context size of the model. :param device: The hardware device to use. If not specified, uses CUDA if available; otherwise uses CPU. :param tokenizer_kwargs: Additional arguments to pass to ``AutoTokenizer.from_pretrained()``. :param model_load_kwargs: Additional arguments to pass to ``AutoModelForCausalLM.from_pretrained()``. :param hyperparams: Additional arguments to supply the model during generation. """ - self.strict = strict - - kwargs.setdefault("max_context_size", 4096) # LLaMA has 4096 token window + kwargs.setdefault("prompt_pipeline", LLAMA2_PIPELINE) super().__init__(model_id, *args, **kwargs) - - # noinspection PyMethodMayBeStatic - def _build_prompt_strict(self, messages: list[ChatMessage]): - if len(messages) < 2: - raise ValueError("You must specify at least one message and a system prompt (LLaMA strict mode).") - if messages[0].role != ChatRole.SYSTEM: - raise ValueError("The first message must be a system prompt (LLaMA strict mode).") - if not ( - all([m.role == ChatRole.USER for m in messages[1::2]]) - and all([m.role == ChatRole.ASSISTANT for m in messages[2::2]]) - ): - raise ValueError("Messages after the first must alternate between user and system (LLaMA strict mode).") - if messages[-1].role != ChatRole.USER: - raise ValueError("The last message must be from the user (LLaMA strict mode).") - # implementation based on llama code - dialog = [f"{B_SYS}{messages[0].text}{E_SYS}{messages[1].text}"] + [m.text for m in messages[2:]] - dialog_tokens = sum( - [ - self.tokenizer.encode(f"{B_INST} {prompt} {E_INST} {answer}") + [self.tokenizer.eos_token_id] - for prompt, answer in zip(dialog[::2], dialog[1::2]) - ], - [], - ) - dialog_tokens += self.tokenizer.encode(f"{B_INST} {dialog[-1]} {E_INST}") - return torch.tensor([dialog_tokens], device=self.device) - - def build_prompt(self, messages: list[ChatMessage], functions: list[AIFunction] | None = None) -> torch.Tensor: - if self.strict: - return self._build_prompt_strict(messages) - tokenize = functools.partial(self.tokenizer.encode, add_special_tokens=False) - tokens = llama2_prompt.build( - messages, - tokenize=tokenize, - bos_token_id=self.tokenizer.bos_token_id, - eos_token_id=self.tokenizer.eos_token_id, - ) - return torch.tensor([tokens], device=self.device) - - def message_len(self, message: ChatMessage) -> int: - # https://github.com/facebookresearch/llama/blob/main/llama/generation.py#L212 - if message.role == ChatRole.USER: - # [INST] {} [/INST] -> 7 - return self.tokenizer(message.text, return_length=True).length[0] + 7 - elif message.role == ChatRole.ASSISTANT: - # {} -> 2 - return self.tokenizer(f" {message.text} ", return_length=True).length[0] + 2 - # [INST] <>\n{}\n<>\n\n [/INST] -> 20 - return self.tokenizer(message.text, return_length=True).length[0] + 20 diff --git a/kani/engines/huggingface/vicuna.py b/kani/engines/huggingface/vicuna.py index 0956b10..6232229 100644 --- a/kani/engines/huggingface/vicuna.py +++ b/kani/engines/huggingface/vicuna.py @@ -1,6 +1,5 @@ -from kani.ai_function import AIFunction from kani.exceptions import MissingModelDependencies -from kani.models import ChatMessage, ChatRole +from kani.prompts.impl import VICUNA_PIPELINE from .base import HuggingEngine try: @@ -20,7 +19,8 @@ class VicunaEngine(HuggingEngine): - """Implementation of Vicuna (a LLaMA v1 fine-tune) using huggingface transformers. + """ + Implementation of Vicuna (a LLaMA v1 fine-tune) using huggingface transformers. You may also use the 13b, 33b, or other LLaMA models that use the Vicuna prompt by passing the HuggingFace model ID to the initializer. @@ -40,10 +40,6 @@ class VicunaEngine(HuggingEngine): ai = Kani(engine) """ - # all prompts start with a hidden token and ASSISTANT: - token_reserve = 7 - """The Vicuna prompt starts with a hidden token and "ASSISTANT:", which uses 7 tokens.""" - def __init__(self, model_id: str = "lmsys/vicuna-7b-v1.3", *args, **kwargs): tokenizer_kwargs = kwargs.pop("tokenizer_kwargs", {}) tokenizer_kwargs.setdefault("use_fast", False) @@ -51,31 +47,7 @@ def __init__(self, model_id: str = "lmsys/vicuna-7b-v1.3", *args, **kwargs): model_load_kwargs = kwargs.pop("model_load_kwargs", {}) model_load_kwargs.setdefault("low_cpu_mem_usage", _low_cpu_mem_usage) - kwargs.setdefault("max_context_size", 2048) # LLaMA has 2048 token window + kwargs.setdefault("prompt_pipeline", VICUNA_PIPELINE) super().__init__( model_id, *args, tokenizer_kwargs=tokenizer_kwargs, model_load_kwargs=model_load_kwargs, **kwargs ) - - def build_prompt(self, messages: list[ChatMessage], functions: list[AIFunction] | None = None) -> str: - prompt_lines = [] - for message in messages: - if message.role == ChatRole.USER: - prompt_lines.append(f"USER: {message.text}") - elif message.role == ChatRole.ASSISTANT: - prompt_lines.append(f"ASSISTANT: {message.text}") - else: - prompt_lines.append(f"{message.text}\n") - prompt = "\n".join(prompt_lines) - return f"{prompt}\nASSISTANT:" - - def message_len(self, message: ChatMessage) -> int: - # https://github.com/lm-sys/FastChat/blob/main/docs/vicuna_weights_version.md#example-prompt-weights-v11-and-v13 - # remove 1 for the token at the start - if message.role == ChatRole.USER: - # USER: {}\n -> 5 - return self.tokenizer(message.text, return_length=True).length + 4 - elif message.role == ChatRole.ASSISTANT: - # ASSISTANT: {}\n -> 8 - return self.tokenizer(message.text, return_length=True).length + 7 - # {}\n\n -> 2 - return self.tokenizer(message.text, return_length=True).length + 1 diff --git a/kani/engines/llama2_prompt.py b/kani/engines/llama2_prompt.py deleted file mode 100644 index 66bf4b7..0000000 --- a/kani/engines/llama2_prompt.py +++ /dev/null @@ -1,48 +0,0 @@ -"""Common builder for the LLaMAv2-chat prompt. - -This file is responsible for implementing the common non-strict ChatMessage to tokens translation, while handling -the nuance of the INST and SYS tokens as best as possible. -""" - -import itertools -from typing import Callable, Iterable - -from kani.models import ChatMessage, ChatRole - -B_INST, E_INST = "[INST]", "[/INST]" -B_SYS, E_SYS = "<>\n", "\n<>\n\n" - - -def build( - messages: list[ChatMessage], tokenize: Callable[[str], list[int]], bos_token_id: int = 1, eos_token_id: int = 2 -) -> list[int]: - """Build the tokens for a list of messages. `tokenize` should tokenize a str without special tokens.""" - tokens = [] - for content, bos, eos in build_str(messages): - if bos: - tokens.append(bos_token_id) - tokens.extend(tokenize(content)) - if eos: - tokens.append(eos_token_id) - return tokens - - -def build_str(messages: list[ChatMessage]) -> Iterable[tuple[str, bool, bool]]: - """Given a list of messages, yield a list of pairs of (content string, bos, eos).""" - # combine consecutive instruction messages and non-instruction messages - for is_inst, role_messages in itertools.groupby( - messages, key=lambda m: m.role == ChatRole.USER or m.role == ChatRole.SYSTEM - ): - # get content within tags (if any) - content = [] - for message in role_messages: - if message.role == ChatRole.SYSTEM: - content.append(f"{B_SYS}{message.text}{E_SYS}") - else: - content.append(message.text) - # if the content is an instruction, return it wrapped in inst tags; otherwise don't - content_str = "".join(content) - if is_inst: - yield f"{B_INST} {content_str} {E_INST}", True, False - else: - yield f" {content_str} ", False, True diff --git a/kani/engines/llamacpp/__init__.py b/kani/engines/llamacpp/__init__.py new file mode 100644 index 0000000..0574d9a --- /dev/null +++ b/kani/engines/llamacpp/__init__.py @@ -0,0 +1 @@ +from .base import LlamaCppEngine diff --git a/kani/engines/llamacpp/base.py b/kani/engines/llamacpp/base.py new file mode 100644 index 0000000..75e01d4 --- /dev/null +++ b/kani/engines/llamacpp/base.py @@ -0,0 +1,174 @@ +from typing import AsyncIterable + +from kani.ai_function import AIFunction +from kani.exceptions import MissingModelDependencies +from kani.models import ChatMessage +from kani.prompts.impl import LLAMA2_PIPELINE +from kani.prompts.pipeline import PromptPipeline +from ..base import BaseCompletion, BaseEngine, Completion + +try: + import torch + from llama_cpp import Llama +except ImportError: + raise MissingModelDependencies( + 'The LlamaCppEngine requires extra dependencies. Please install kani with "pip install kani[cpp]". ' + ) from None + + +class LlamaCppEngine(BaseEngine): + """ + This class implements the main decoding logic for any GGUF model (not just LLaMA as the name might suggest). + + This engine defaults to LLaMA 2 Chat 7B with 4-bit quantization. + + **GPU Support** + + llama.cpp supports multiple acceleration backends, which may require different flags to be set during installation. + To see the full list of backends, see their README at https://github.com/abetlen/llama-cpp-python. + + To load some or all of the model layers on GPU, pass ``n_gpu_layers=...`` in the ``model_load_kwargs``. Use + ``-1`` to specify all layers. + """ + + def __init__( + self, + repo_id: str, + filename: str = None, + max_context_size: int = 0, + prompt_pipeline: PromptPipeline[str | list[int]] = LLAMA2_PIPELINE, + *, + model_load_kwargs: dict = None, + **hyperparams, + ): + """ + :param repo_id: The ID of the model repo to load from Hugging Face. + :param filename: A filename or glob pattern to match the model file in the repo. + :param max_context_size: The context size of the model. + :param prompt_pipeline: The pipeline to translate a list of kani ChatMessages into the model-specific chat + format (see :class:`.PromptPipeline`). + :param model_load_kwargs: Additional arguments to pass to ``Llama.from_pretrained()``. + See `this link `_ + for more info. + :param hyperparams: Additional arguments to supply the model during generation. + """ + + if model_load_kwargs is None: + model_load_kwargs = {} + + self.repo_id = repo_id + self.filename = filename + self.pipeline = prompt_pipeline + + model_load_kwargs.setdefault("n_ctx", max_context_size) + self.model = Llama.from_pretrained(repo_id=repo_id, filename=filename, **model_load_kwargs) + self.hyperparams = hyperparams + + self.max_context_size = max_context_size or self.model.n_ctx() + + if self.token_reserve == 0 and self.pipeline: + self.token_reserve = self._infer_token_reserve() + + def _infer_token_reserve(self): + """If token_reserve is not set and we have a pipeline, infer it.""" + prompt = self.pipeline.execute([], for_measurement=True) + if isinstance(prompt, list): + return len(prompt) + tokenized = self.model.tokenize(prompt.encode(), add_bos=False, special=True) + return len(tokenized) + + def message_len(self, message: ChatMessage) -> int: + # default concrete base behaviour: + if self.pipeline is None: + raise NotImplementedError( + "You must pass a prompt_pipeline to the HuggingEngine to use it as a non-abstract class." + ) + prompt = self.pipeline.execute([message], for_measurement=True) + if isinstance(prompt, list): + return len(prompt) + tokenized = self.model.tokenize(prompt.encode(), add_bos=False, special=True) + return len(tokenized) + + def build_prompt(self, messages: list[ChatMessage], functions: list[AIFunction] | None = None) -> str | list[int]: + """ + Given the list of messages from kani, build either a single string representing the prompt for the model, + or build the token list. + + The default behaviour is to call the supplied pipeline. + """ + if self.pipeline is None: + raise NotImplementedError( + "You must pass a prompt_pipeline to the HuggingEngine to use it as a non-abstract class." + ) + return self.pipeline(messages) + + def _get_generate_args(self, messages: list[ChatMessage], functions: list[AIFunction] | None = None, **hyperparams): + """Internal method to build common params for the generate call""" + prompt = self.build_prompt(messages, functions) + if isinstance(prompt, str): + # prompt str to tokens + input_toks = self.model.tokenize(prompt.encode(), add_bos=False, special=True) + input_len = len(input_toks) + elif isinstance(prompt, list): + input_toks = prompt + input_len = len(input_toks) + else: + raise TypeError("build_prompt should either return a str or a list[int].") + # set up hyperparams + hyperparams = {**self.hyperparams, **hyperparams} + hyperparams.setdefault("max_tokens", None) # by default llama.cpp sets this to 16, which is too small + return input_toks, input_len, hyperparams + + async def predict( + self, messages: list[ChatMessage], functions: list[AIFunction] | None = None, **hyperparams + ) -> Completion: + """ + Given the current context of messages and available functions, get the next predicted chat message from the LM. + + :param messages: The messages in the current chat context. ``sum(message_len(m) for m in messages)`` is + guaranteed to be less than max_context_size. + :param functions: The functions the LM is allowed to call. + :param hyperparams: Any additional parameters to pass to ``Llama.create_completion()``. (See + https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.create_completion) + """ + input_toks, input_len, hyperparams = self._get_generate_args(messages, functions, **hyperparams) + + completion = self.model.create_completion(input_toks, **hyperparams) + return Completion( + ChatMessage.assistant(completion["choices"][0]["text"]), + prompt_tokens=input_len, + completion_tokens=completion["usage"]["completion_tokens"], + ) + + async def stream( + self, + messages: list[ChatMessage], + functions: list[AIFunction] | None = None, + **hyperparams, + ) -> AsyncIterable[str | BaseCompletion]: + """ + Given the current context of messages and available functions, get the next predicted chat message from the LM. + + :param messages: The messages in the current chat context. ``sum(message_len(m) for m in messages)`` is + guaranteed to be less than max_context_size. + :param functions: The functions the LM is allowed to call. + :param hyperparams: Any additional parameters to pass to ``Llama.create_completion()``. (See + https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.create_completion) + """ + input_toks, input_len, hyperparams = self._get_generate_args(messages, functions, **hyperparams) + + stream = self.model.create_completion(input_toks, stream=True, **hyperparams) + + # iterate over the stream and yield/save + content_chunks = [] + for chunk in stream: + text = chunk["choices"][0]["text"] + + # yield content + if text is not None: + content_chunks.append(text) + yield text + + # construct the final completion + content = None if not content_chunks else "".join(content_chunks) + yield Completion(message=ChatMessage.assistant(content)) diff --git a/kani/engines/openai/__init__.py b/kani/engines/openai/__init__.py index eac19e6..1739297 100644 --- a/kani/engines/openai/__init__.py +++ b/kani/engines/openai/__init__.py @@ -1,2 +1 @@ -from .client import OpenAIClient from .engine import OpenAIEngine diff --git a/kani/engines/openai/client.py b/kani/engines/openai/client.py deleted file mode 100644 index 688e0cf..0000000 --- a/kani/engines/openai/client.py +++ /dev/null @@ -1,179 +0,0 @@ -import asyncio -import warnings -from typing import Literal, overload - -import aiohttp -import pydantic - -from .models import ( - ChatCompletion, - Completion, - FunctionSpec, - OpenAIChatMessage, - ResponseFormat, - SpecificFunctionCall, - ToolChoice, - ToolSpec, -) -from ..httpclient import BaseClient, HTTPException, HTTPStatusException, HTTPTimeout - - -class OpenAIClient(BaseClient): - """Simple HTTP client to interface with the OpenAI API. - - This client includes exponential backoff on retryable errors. - """ - - def __init__( - self, - api_key: str, - http: aiohttp.ClientSession = None, - organization: str = None, - retry: int = 5, - api_base: str = "https://api.openai.com/v1", - headers: dict = None, - ): - """ - :param api_key: The OpenAI API key to use. - :param http: An aiohttp session to use (creates one by default). - :param organization: The OpenAI org ID to bill requests to (default based on api key). - :param retry: How many times to retry a failed request (default 5). - :param api_base: The base URL of the API to use (default OpenAI). - :param headers: Any additional headers to send with each request. - """ - if headers is None: - headers = {} - super().__init__(http) - self.api_key = api_key - self.organization = organization - self.retry = retry - self.SERVICE_BASE = api_base - self.headers = headers - - async def request(self, method: str, route: str, headers=None, retry=None, **kwargs): - if headers is None: - headers = self.headers.copy() - # set up auth headers - if "Authorization" not in headers: - headers["Authorization"] = f"Bearer {self.api_key}" - if self.organization: - headers["OpenAI-Organization"] = self.organization - - # make the request - retry = retry if retry is not None else self.retry - for i in range(retry): - try: - return await super().request(method, route, headers=headers, **kwargs) - except (HTTPStatusException, HTTPTimeout) as e: - # raise if out of retries - if (i + 1) >= retry: - raise - - # handle http status exceptions - # non-retryable error (see https://platform.openai.com/docs/guides/error-codes/api-errors) - if isinstance(e, HTTPStatusException) and e.status_code in (400, 401, 403, 404): - raise - - # otherwise, wait and try again - retry_sec = 2**i - self.logger.warning(f"OpenAI returned {e}, retrying in {retry_sec} sec...") - await asyncio.sleep(retry_sec) - raise RuntimeError( - "Ran out of retries but no error encountered. This should never happen; please report an issue on GitHub!" - ) - - # ==== completions ==== - @overload - async def create_completion( - self, - model: str, - *, - prompt: str = "<|endoftext|>", - suffix: str = None, - max_tokens: int = 16, - temperature: float = 1.0, - top_p: float = 1.0, - n: int = 1, - logprobs: int = None, - echo: bool = False, - seed: int | None = None, - stop: str | list[str] = None, - presence_penalty: float = 0.0, - frequency_penalty: float = 0.0, - best_of: int = 1, - logit_bias: dict = None, - user: str = None, - ) -> Completion: ... - - async def create_completion(self, model: str, **kwargs) -> Completion: - data = await self.post("/completions", json={"model": model, **kwargs}) - try: - return Completion.model_validate(data) - except pydantic.ValidationError as e: - self.logger.exception(f"Failed to deserialize OpenAI response: {data}") - raise HTTPException(f"Could not deserialize response: {data}") from e - - # ==== chat ==== - @overload - async def create_chat_completion( - self, - model: str, - messages: list[OpenAIChatMessage], - *, - tools: list[ToolSpec] | None = None, - tool_choice: ToolChoice | Literal["auto"] | Literal["none"] | None = None, - temperature: float = 1.0, - top_p: float = 1.0, - n: int = 1, - response_format: ResponseFormat | None = None, - seed: int | None = None, - stop: str | list[str] | None = None, - max_tokens: int | None = None, - presence_penalty: float = 0.0, - frequency_penalty: float = 0.0, - logit_bias: dict | None = None, - user: str | None = None, - # deprecated - functions: list[FunctionSpec] | None = None, - function_call: SpecificFunctionCall | Literal["auto"] | Literal["none"] | None = None, - ) -> ChatCompletion: ... - - async def create_chat_completion( - self, - model: str, - messages: list[OpenAIChatMessage], - *, - tools: list[ToolSpec] | None = None, - **kwargs, - ) -> ChatCompletion: - """Create a chat completion. - - See https://platform.openai.com/docs/api-reference/chat/create. - """ - # transform pydantic models - if tools: - kwargs["tools"] = [t.model_dump(exclude_unset=True) for t in tools] - if "tool_choice" in kwargs and isinstance(kwargs["tool_choice"], SpecificFunctionCall): - kwargs["tool_choice"] = kwargs["tool_choice"].model_dump(exclude_unset=True) - # deprecated function calling - if "functions" in kwargs: - warnings.warn("The functions parameter is deprecated. Use tools instead.", DeprecationWarning) - kwargs["functions"] = [f.model_dump(exclude_unset=True) for f in kwargs["functions"]] - if "function_call" in kwargs and isinstance(kwargs["function_call"], SpecificFunctionCall): - warnings.warn("The function_call parameter is deprecated. Use tool_choice instead.", DeprecationWarning) - kwargs["function_call"] = kwargs["function_call"].model_dump(exclude_unset=True) - # call API - data = await self.post( - "/chat/completions", - json={ - "model": model, - # this needs to be exclude_defaults since content can be None but is required regardless - "messages": [cm.model_dump(exclude_defaults=True, mode="json") for cm in messages], - **kwargs, - }, - ) - try: - return ChatCompletion.model_validate(data) - except pydantic.ValidationError: - self.logger.exception(f"Failed to deserialize OpenAI response: {data}") - raise HTTPException(f"Could not deserialize response: {data}") diff --git a/kani/engines/openai/engine.py b/kani/engines/openai/engine.py index fa242b7..81bdad0 100644 --- a/kani/engines/openai/engine.py +++ b/kani/engines/openai/engine.py @@ -1,16 +1,16 @@ import functools -import os +from typing import AsyncIterable from kani.ai_function import AIFunction -from kani.exceptions import MissingModelDependencies, PromptError +from kani.exceptions import MissingModelDependencies from kani.models import ChatMessage, ChatRole from . import function_calling -from .client import OpenAIClient -from .models import ChatCompletion, FunctionSpec, OpenAIChatMessage, ToolSpec -from ..base import BaseEngine +from .translation import ChatCompletion, openai_tc_to_kani_tc, translate_functions, translate_messages +from ..base import BaseCompletion, BaseEngine, Completion try: import tiktoken + from openai import AsyncOpenAI as OpenAIClient except ImportError as e: raise MissingModelDependencies( 'The OpenAIEngine requires extra dependencies. Please install kani with "pip install kani[openai]".' @@ -18,10 +18,9 @@ # https://platform.openai.com/docs/models CONTEXT_SIZES_BY_PREFIX = [ - ("gpt-3.5-turbo-0125", 16384), - ("gpt-3.5-turbo-1106", 16384), - ("gpt-3.5-turbo-16k", 16384), - ("gpt-3.5-turbo", 4096), + ("gpt-3.5-turbo-instruct", 4096), + ("gpt-3.5-turbo-0613", 4096), + ("gpt-3.5-turbo", 16385), # gpt-4-turbo models aren't prefixed differently... # TODO make the default gpt-4 128k and keep the pre-turbo ones at 8k after gpt-4 defaults to 128k ("gpt-4-1106", 128000), @@ -31,13 +30,14 @@ ("gpt-4-32k", 32768), ("gpt-4", 8192), # fine-tunes - ("ft:gpt-3.5-turbo-16k", 16384), - ("ft:gpt-3.5-turbo", 4096), + ("ft:gpt-3.5-turbo-instruct", 4096), + ("ft:gpt-3.5-turbo-0613", 4096), + ("ft:gpt-3.5-turbo", 16385), ("ft:gpt-4-32k", 32768), ("ft:gpt-4", 8192), # completion models - ("text-davinci-", 4096), - ("code-", 8000), + ("babbage-002", 16384), + ("davinci-002", 16384), # catch-all ("", 2048), # e.g. aba/babbage/curie/davinci ] @@ -68,29 +68,24 @@ def __init__( :param model: The id of the model to use (e.g. "gpt-3.5-turbo", "ft:gpt-3.5-turbo:my-org:custom_suffix:id"). :param max_context_size: The maximum amount of tokens allowed in the chat prompt. If None, uses the given model's full context size. - :param organization: The OpenAI organization to use in requests (defaults to the API key's default org). + :param organization: The OpenAI organization to use in requests. By default, the org ID would be read from the + `OPENAI_ORG_ID` environment variable (defaults to the API key's default org if not set). :param retry: How many times the engine should retry failed HTTP calls with exponential backoff (default 5). :param api_base: The base URL of the OpenAI API to use. :param headers: A dict of HTTP headers to include with each request. - :param client: An instance of :class:`.OpenAIClient` (for reusing the same client in multiple engines). You must - specify exactly one of (api_key, client). If this is passed the ``organization``, ``retry``, ``api_base``, - and ``headers`` params will be ignored. - :param hyperparams: Any additional parameters to pass to - :meth:`.OpenAIClient.create_chat_completion`. + :param client: An instance of `openai.AsyncOpenAI `_ + (for reusing the same client in multiple engines). + You must specify exactly one of ``(api_key, client)``. If this is passed the ``organization``, ``retry``, + ``api_base``, and ``headers`` params will be ignored. + :param hyperparams: The arguments to pass to the ``create_chat_completion`` call with each request. See + https://platform.openai.com/docs/api-reference/chat/create for a full list of params. """ if api_key and client: raise ValueError("You must supply no more than one of (api_key, client).") - if api_key is None and client is None: - api_key = os.getenv("OPENAI_API_KEY") - if api_key is None: - raise ValueError( - "You must supply an `api_key`, `client`, or set the `OPENAI_API_KEY` environment variable to use" - " the OpenAIEngine." - ) if max_context_size is None: max_context_size = next(size for prefix, size in CONTEXT_SIZES_BY_PREFIX if model.startswith(prefix)) self.client = client or OpenAIClient( - api_key, organization=organization, retry=retry, api_base=api_base, headers=headers + api_key=api_key, organization=organization, max_retries=retry, base_url=api_base, default_headers=headers ) self.model = model self.max_context_size = max_context_size @@ -116,72 +111,68 @@ def message_len(self, message: ChatMessage) -> int: mlen += len(self.tokenizer.encode(tc.function.arguments)) return mlen - # translation helpers - @staticmethod - def translate_functions(functions: list[AIFunction], cls: type[ToolSpec] = ToolSpec) -> list[ToolSpec]: - return [ - cls.from_function(FunctionSpec(name=f.name, description=f.desc, parameters=f.json_schema)) - for f in functions - ] - - @staticmethod - def translate_messages( - messages: list[ChatMessage], cls: type[OpenAIChatMessage] = OpenAIChatMessage - ) -> list[OpenAIChatMessage]: - translated_messages = [] - free_toolcall_ids = set() - for m in messages: - # if this is not a function result and there are free tool call IDs, raise - if m.role != ChatRole.FUNCTION and free_toolcall_ids: - raise PromptError( - f"Encountered a {m.role.value!r} message but expected a FUNCTION message to satisfy the pending" - f" tool call(s): {free_toolcall_ids}" - ) - # asst: add tool call IDs to freevars - if m.role == ChatRole.ASSISTANT and m.tool_calls: - for tc in m.tool_calls: - free_toolcall_ids.add(tc.id) - # func: bind freevars - elif m.role == ChatRole.FUNCTION: - # has ID: bind it if requested; translate to FUNCTION if not - if m.tool_call_id is not None: - if m.tool_call_id in free_toolcall_ids: - free_toolcall_ids.remove(m.tool_call_id) - else: - # this happens if the tool call is pushed out of context but the result is still here, - # and we have always included messages beforehand - # TODO: this will eventually be deprecated - maube we just skip this message? - m = m.copy_with(tool_call_id=None) - # no ID: bind if unambiguous, otherwise cry - elif m.tool_call_id is None: - if len(free_toolcall_ids) == 1: - m = m.copy_with(tool_call_id=free_toolcall_ids.pop()) - elif len(free_toolcall_ids) > 1: - raise PromptError( - "Got a FUNCTION message with no tool_call_id but multiple tool calls are pending" - f" ({free_toolcall_ids})! Set the tool_call_id to resolve the pending tool requests." - ) - translated_messages.append(cls.from_chatmessage(m)) - # if the translated messages start with a hanging TOOL call, strip it (openai limitation) - # though hanging FUNCTION messages are OK - while translated_messages and translated_messages[0].role == "tool": - translated_messages.pop(0) - return translated_messages - async def predict( self, messages: list[ChatMessage], functions: list[AIFunction] | None = None, **hyperparams ) -> ChatCompletion: if functions: - tool_specs = self.translate_functions(functions) + tool_specs = translate_functions(functions) else: tool_specs = None # translate to openai spec - group any tool messages together and ensure all free ToolCall IDs are bound - translated_messages = self.translate_messages(messages) + translated_messages = translate_messages(messages) # make API call - completion = await self.client.create_chat_completion( + completion = await self.client.chat.completions.create( model=self.model, messages=translated_messages, tools=tool_specs, **self.hyperparams, **hyperparams ) - return completion + # translate into Kani spec and return + return ChatCompletion(openai_completion=completion) + + async def stream( + self, messages: list[ChatMessage], functions: list[AIFunction] | None = None, **hyperparams + ) -> AsyncIterable[str | BaseCompletion]: + if functions: + tool_specs = translate_functions(functions) + else: + tool_specs = None + # translate to openai spec - group any tool messages together and ensure all free ToolCall IDs are bound + translated_messages = translate_messages(messages) + # make API call + stream = await self.client.chat.completions.create( + model=self.model, + messages=translated_messages, + tools=tool_specs, + stream=True, + **self.hyperparams, + **hyperparams, + ) + + # save requested tool calls and content as streamed + content_chunks = [] + tool_call_partials = {} # index -> tool call + + # iterate over the stream and yield/save + async for chunk in stream: + delta = chunk.choices[0].delta + + # yield content + if delta.content is not None: + content_chunks.append(delta.content) + yield delta.content + + # tool calls are partials, save a mapping to the latest state and we'll translate them later once complete + if delta.tool_calls: + # each tool call can have EITHER the function.name/id OR function.arguments + for tc in delta.tool_calls: + if tc.id is not None: + tool_call_partials[tc.index] = tc + else: + partial = tool_call_partials[tc.index] + partial.function.arguments += tc.function.arguments + + # construct the final completion with streamed tool calls + content = None if not content_chunks else "".join(content_chunks) + tool_calls = [openai_tc_to_kani_tc(tc) for tc in sorted(tool_call_partials.values(), key=lambda c: c.index)] + yield Completion(message=ChatMessage(role=ChatRole.ASSISTANT, content=content, tool_calls=tool_calls)) def function_token_reserve(self, functions: list[AIFunction]) -> int: if not functions: diff --git a/kani/engines/openai/models.py b/kani/engines/openai/models.py deleted file mode 100644 index 69fa5af..0000000 --- a/kani/engines/openai/models.py +++ /dev/null @@ -1,148 +0,0 @@ -from typing import Literal - -from kani.models import BaseModel, ChatMessage, ChatRole, FunctionCall, ToolCall -from ..base import BaseCompletion - - -# ==== text completions ==== -class CompletionLogProbs(BaseModel): - tokens: list[str] - token_logprobs: list[float] - top_logprobs: list[dict[str, float]] - text_offset: list[int] - - -class CompletionChoice(BaseModel): - text: str - index: int - logprobs: CompletionLogProbs | None = None - finish_reason: str | None = None - - -class CompletionUsage(BaseModel): - prompt_tokens: int - completion_tokens: int - total_tokens: int - - -class Completion(BaseModel): - id: str - object: Literal["text_completion"] - created: int - model: str - system_fingerprint: str | None = None - choices: list[CompletionChoice] - usage: CompletionUsage - - @property - def text(self): - """The text of the top completion.""" - return self.choices[0].text - - -# ==== chat completions ==== -class FunctionSpec(BaseModel): - name: str - description: str | None = None - parameters: dict - - -class ToolSpec(BaseModel): - type: str - function: FunctionSpec - - @classmethod - def from_function(cls, spec: FunctionSpec): - return cls(type="function", function=spec) - - -class SpecificFunctionCall(BaseModel): - name: str - - -class ToolChoice(BaseModel): - type: str - function: SpecificFunctionCall - - @classmethod - def from_function(cls, name: str): - return cls(type="function", function=SpecificFunctionCall(name=name)) - - -class ResponseFormat(BaseModel): - type: str - - @classmethod - def text(cls): - return cls(type="text") - - @classmethod - def json_object(cls): - return cls(type="json_object") - - -class OpenAIChatMessage(BaseModel): - role: str - content: str | list[BaseModel | str] | None - name: str | None = None - tool_call_id: str | None = None - tool_calls: list[ToolCall] | None = None - # deprecated - function_call: FunctionCall | None = None - - @classmethod - def from_chatmessage(cls, m: ChatMessage): - # translate tool responses to a function to the right openai format - if m.role == ChatRole.FUNCTION: - if m.tool_call_id is not None: - return cls(role="tool", content=m.text, name=m.name, tool_call_id=m.tool_call_id) - return cls(role=m.role.value, content=m.text, name=m.name) - return cls(role=m.role.value, content=m.text, name=m.name, tool_call_id=m.tool_call_id, tool_calls=m.tool_calls) - - def to_chatmessage(self) -> ChatMessage: - # translate tool role to function role - if self.role == "tool": - role = ChatRole.FUNCTION - else: - role = ChatRole(self.role) - # translate FunctionCall to singular ToolCall - if self.tool_calls: - tool_calls = self.tool_calls - elif self.function_call: - tool_calls = [ToolCall.from_function_call(self.function_call)] - else: - tool_calls = None - return ChatMessage( - role=role, content=self.content, name=self.name, tool_call_id=self.tool_call_id, tool_calls=tool_calls - ) - - -# ---- response ---- -class ChatCompletionChoice(BaseModel): - message: OpenAIChatMessage - index: int - finish_reason: str | None = None - - -class ChatCompletion(BaseCompletion, BaseModel): - id: str - object: Literal["chat.completion"] - created: int - model: str - system_fingerprint: str | None = None - usage: CompletionUsage - choices: list[ChatCompletionChoice] - - @property - def message(self) -> ChatMessage: - return self.choices[0].message.to_chatmessage() - - @property - def prompt_tokens(self): - return self.usage.prompt_tokens - - @property - def completion_tokens(self): - # for some reason, the OpenAI API doesn't return the tokens used by ChatML - # so we add on the length of "<|im_start|>assistant" and "<|im_end|>" here - return self.usage.completion_tokens + 5 diff --git a/kani/engines/openai/translation.py b/kani/engines/openai/translation.py new file mode 100644 index 0000000..dd19057 --- /dev/null +++ b/kani/engines/openai/translation.py @@ -0,0 +1,137 @@ +"""Helpers to translate kani chat objects into OpenAI params.""" + +from kani.ai_function import AIFunction +from kani.engines.base import BaseCompletion +from kani.exceptions import MissingModelDependencies +from kani.models import ChatMessage, ChatRole, FunctionCall, ToolCall +from kani.prompts.pipeline import PromptPipeline + +try: + from openai.types.chat import ( + ChatCompletion as OpenAIChatCompletion, + ChatCompletionAssistantMessageParam, + ChatCompletionFunctionMessageParam, + ChatCompletionMessage, + ChatCompletionMessageParam, + ChatCompletionMessageToolCall, + ChatCompletionMessageToolCallParam, + ChatCompletionSystemMessageParam, + ChatCompletionToolMessageParam, + ChatCompletionToolParam, + ChatCompletionUserMessageParam, + ) + from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall + from openai.types.chat.chat_completion_message_tool_call import Function as ChatCompletionMessageFunctionCall + from openai.types.chat.chat_completion_message_tool_call_param import ( + Function as ChatCompletionMessageToolCallFunctionParam, + ) + from openai.types.shared_params import FunctionDefinition +except ImportError as e: + raise MissingModelDependencies( + 'The OpenAIEngine requires extra dependencies. Please install kani with "pip install kani[openai]".' + ) from None + + +# ==== kani -> openai ==== +# decomp +def kani_cm_to_openai_cm(msg: ChatMessage) -> ChatCompletionMessageParam: + """Translate a kani ChatMessage into an OpenAI Message.""" + # translate tool responses to a function to the right openai format + match msg.role: + case ChatRole.FUNCTION if msg.tool_call_id is not None: + return ChatCompletionToolMessageParam(role="tool", content=msg.text, tool_call_id=msg.tool_call_id) + case ChatRole.FUNCTION: + return ChatCompletionFunctionMessageParam(**_msg_kwargs(msg)) + case ChatRole.SYSTEM: + return ChatCompletionSystemMessageParam(**_msg_kwargs(msg)) + case ChatRole.USER: + return ChatCompletionUserMessageParam(**_msg_kwargs(msg)) + case _: # assistant + if msg.tool_calls: + tool_calls = [kani_tc_to_openai_tc(tc) for tc in msg.tool_calls] + return ChatCompletionAssistantMessageParam(**_msg_kwargs(msg), tool_calls=tool_calls) + return ChatCompletionAssistantMessageParam(**_msg_kwargs(msg)) + + +def _msg_kwargs(msg: ChatMessage) -> dict: + data = dict(role=msg.role.value, content=msg.text) + if msg.name is not None: + data["name"] = msg.name + return data + + +def kani_tc_to_openai_tc(tc: ToolCall) -> ChatCompletionMessageToolCallParam: + """Translate a kani ToolCall into an OpenAI dict""" + oai_function = ChatCompletionMessageToolCallFunctionParam(name=tc.function.name, arguments=tc.function.arguments) + return ChatCompletionMessageToolCallParam(id=tc.id, type="function", function=oai_function) + + +# main +OPENAI_PIPELINE = ( + PromptPipeline() + .ensure_bound_function_calls() + .ensure_start(predicate=lambda msg: msg.role != ChatRole.FUNCTION) + .apply(kani_cm_to_openai_cm) +) + + +def translate_functions(functions: list[AIFunction]) -> list[ChatCompletionToolParam]: + return [ + ChatCompletionToolParam( + type="function", function=FunctionDefinition(name=f.name, description=f.desc, parameters=f.json_schema) + ) + for f in functions + ] + + +def translate_messages(messages: list[ChatMessage]) -> list[ChatCompletionMessageParam]: + return OPENAI_PIPELINE(messages) + + +# ==== openai -> kani ==== +def openai_cm_to_kani_cm(msg: ChatCompletionMessage) -> ChatMessage: + """Translate an OpenAI ChatCompletionMessage into a kani ChatMessage.""" + # translate tool role to function role + if msg.role == "tool": + role = ChatRole.FUNCTION + else: + role = ChatRole(msg.role) + # translate FunctionCall to singular ToolCall + if msg.tool_calls: + tool_calls = [openai_tc_to_kani_tc(tc) for tc in msg.tool_calls] + elif msg.function_call: + tool_calls = [ToolCall.from_function_call(msg.function_call)] + else: + tool_calls = None + return ChatMessage(role=role, content=msg.content, tool_calls=tool_calls) + + +def openai_tc_to_kani_tc(tc: ChatCompletionMessageToolCall | ChoiceDeltaToolCall) -> ToolCall: + return ToolCall(id=tc.id, type=tc.type, function=openai_fc_to_kani_fc(tc.function)) + + +def openai_fc_to_kani_fc(fc: ChatCompletionMessageFunctionCall) -> FunctionCall: + return FunctionCall(name=fc.name, arguments=fc.arguments) + + +class ChatCompletion(BaseCompletion): + """A wrapper around the OpenAI ChatCompletion to make it compatible with the Kani interface.""" + + def __init__(self, openai_completion: OpenAIChatCompletion): + self.openai_completion = openai_completion + """The underlying OpenAI ChatCompletion.""" + self._message = openai_cm_to_kani_cm(openai_completion.choices[0].message) + + @property + def message(self): + return self._message + + @property + def prompt_tokens(self): + return self.openai_completion.usage.prompt_tokens + + @property + def completion_tokens(self): + # for some reason, the OpenAI API doesn't return the tokens used by ChatML + # so we add on the length of "<|im_start|>assistant" and "<|im_end|>" here + return self.openai_completion.usage.completion_tokens + 5 diff --git a/kani/exceptions.py b/kani/exceptions.py index 74a126d..97d1960 100644 --- a/kani/exceptions.py +++ b/kani/exceptions.py @@ -1,4 +1,7 @@ -import aiohttp +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import aiohttp class KaniException(Exception): @@ -11,17 +14,26 @@ class MessageTooLong(KaniException): # ==== HTTP ==== class HTTPException(KaniException): - """Base class for all HTTP errors (for HTTP engines).""" + """Base class for all HTTP errors (for HTTP engines). + + .. deprecated:: 1.0.0 + """ class HTTPTimeout(HTTPException): - """Timeout occurred connecting to or waiting for a response from an HTTP request.""" + """Timeout occurred connecting to or waiting for a response from an HTTP request. + + .. deprecated:: 1.0.0 + """ class HTTPStatusException(HTTPException): - """The HTTP server returned a non-200 status code.""" + """The HTTP server returned a non-200 status code. + + .. deprecated:: 1.0.0 + """ - def __init__(self, response: aiohttp.ClientResponse, msg: str): + def __init__(self, response: "aiohttp.ClientResponse", msg: str): super().__init__(msg) self.response = response self.status_code = response.status diff --git a/kani/json_schema.py b/kani/json_schema.py index da956e7..91aefd3 100644 --- a/kani/json_schema.py +++ b/kani/json_schema.py @@ -18,11 +18,12 @@ class AIParamSchema: This class is only used internally within kani and generally shouldn't be constructed manually. """ - def __init__(self, name: str, t: type, default, aiparam: Optional["AIParam"] = None): + def __init__(self, name: str, t: type, default, aiparam: Optional["AIParam"], inspect_param: inspect.Parameter): self.name = name - self.type = t + self.type = t # will not include Annotated if present self.default = default self.aiparam = aiparam + self.inspect_param = inspect_param @property def required(self): @@ -37,6 +38,13 @@ def origin_type(self): def description(self): return self.aiparam.desc if self.aiparam is not None else None + def __str__(self): + default = "" + if not self.required: + default = f" = {self.default!r}" + annotation = inspect.formatannotation(self.type) + return f"{self.name}: {annotation}{default}" + class JSONSchemaBuilder(pydantic.json_schema.GenerateJsonSchema): """Subclass of the Pydantic JSON schema builder to provide more fine-grained control over titles and refs.""" diff --git a/kani/kani.py b/kani/kani.py index 2554688..4c5deaf 100644 --- a/kani/kani.py +++ b/kani/kani.py @@ -2,14 +2,15 @@ import inspect import logging import warnings -import weakref -from typing import AsyncIterable, Callable +from collections.abc import AsyncIterable +from typing import Callable from .ai_function import AIFunction from .engines.base import BaseCompletion, BaseEngine from .exceptions import FunctionCallException, MessageTooLong, NoSuchFunction, WrappedCallException from .internal import ExceptionHandleResult, FunctionCallResult from .models import ChatMessage, ChatRole, FunctionCall, QueryType, ToolCall +from .streaming import DummyStream, StreamManager from .utils.message_formatters import assistant_message_contents from .utils.typing import PathLike, SavedKani @@ -24,12 +25,16 @@ class Kani: ``chat_round(query: str, **kwargs) -> ChatMessage`` - ``full_round(query: str, **kwargs) -> AsyncIterable[ChatMessage]`` - ``chat_round_str(query: str, **kwargs) -> str`` + ``chat_round_stream(query: str, **kwargs) -> StreamManager`` + + ``full_round(query: str, **kwargs) -> AsyncIterable[ChatMessage]`` + ``full_round_str(query: str, message_formatter: Callable[[ChatMessage], str], **kwargs) -> AsyncIterable[str]`` + ``full_round_stream(query: str, **kwargs) -> AsyncIterable[StreamManager]`` + **Function Calling** Subclass and use ``@ai_function()`` to register functions. The schema will be autogenerated from the function @@ -114,20 +119,9 @@ def __init__( raise ValueError(f"AIFunction {f.name!r} is already registered!") self.functions[f.name] = f - # cache - self._message_tokens = weakref.WeakKeyDictionary() - - # === main entrypoints === - async def chat_round(self, query: QueryType, **kwargs) -> ChatMessage: - """Perform a single chat round (user -> model -> user, no functions allowed). - - This is slightly faster when you are chatting with a kani with no AI functions defined. - - :param query: The contents of the user's chat message. Can be None to generate a completion without a user - prompt. - :param kwargs: Additional arguments to pass to the model engine (e.g. hyperparameters). - :returns: The model's reply. - """ + # ==== internals ==== + async def _chat_round_before(self, query: QueryType, **kwargs): + """Common preflight for chat_round_*, returns the kwargs for get_model_completion/stream""" # warn if the user has functions defined and has not explicitly silenced them in this call if self.functions and "include_functions" not in kwargs: warnings.warn( @@ -135,39 +129,20 @@ async def chat_round(self, query: QueryType, **kwargs) -> ChatMessage: " functions. Use full_round() instead.\nIf this is intentional, use chat_round(...," " include_functions=False) to silence this warning." ) - kwargs = {**kwargs, "include_functions": False} - # do the chat round - async with self.lock: - # add the user's chat input to the state - if query is not None: - await self.add_to_history(ChatMessage.user(query)) - - # and get a completion - completion = await self.get_model_completion(**kwargs) - message = completion.message - await self.add_to_history(message) - return message - - async def chat_round_str(self, query: QueryType, **kwargs) -> str: - """Like :meth:`chat_round`, but only returns the text content of the message.""" - msg = await self.chat_round(query, **kwargs) - return msg.text - - async def full_round(self, query: QueryType, **kwargs) -> AsyncIterable[ChatMessage]: - """Perform a full chat round (user -> model [-> function -> model -> ...] -> user). - - Yields each non-user ChatMessage created during the round. - A ChatMessage will have at least one of (content, function_call). - - Use this in an async for loop, like so:: - - async for msg in kani.full_round("How's the weather?"): - print(msg.text) + kwargs["include_functions"] = False + # add the user's chat input to the state + if query is not None: + await self.add_to_history(ChatMessage.user(query)) + return kwargs + + async def _add_completion_to_history(self, completion: BaseCompletion): + """Add the message in the completion to the chat history and return it""" + message = completion.message + await self.add_to_history(message) + return message - :param query: The content of the user's chat message. Can be None to generate a completion without a user - prompt. - :param kwargs: Additional arguments to pass to the model engine (e.g. hyperparameters). - """ + async def _full_round(self, query: QueryType, *, _kani_is_stream=False, **kwargs): + """Underlying handler for full_round with stream support.""" retry = 0 is_model_turn = True async with self.lock: @@ -175,11 +150,16 @@ async def full_round(self, query: QueryType, **kwargs) -> AsyncIterable[ChatMess await self.add_to_history(ChatMessage.user(query)) while is_model_turn: - # do the model prediction - completion = await self.get_model_completion(**kwargs) - message = completion.message - await self.add_to_history(message) - yield message + # do the model prediction (stream or no stream) + if _kani_is_stream: + stream = self.get_model_stream(**kwargs) + manager = StreamManager(stream, role=ChatRole.ASSISTANT, after=self._add_completion_to_history) + yield manager + message = await manager.message() + else: + completion = await self.get_model_completion(**kwargs) + message = await self._add_completion_to_history(completion) + yield message # if function call, do it and attempt retry if it's wrong if not message.tool_calls: @@ -187,10 +167,15 @@ async def full_round(self, query: QueryType, **kwargs) -> AsyncIterable[ChatMess # run each tool call in parallel async def _do_tool_call(tc: ToolCall): + # call the method and set the is_tool_call_error attr (if the impl has not already set it) try: - return await self.do_function_call(tc.function, tool_call_id=tc.id) + tc_result = await self.do_function_call(tc.function, tool_call_id=tc.id) + if tc_result.message.is_tool_call_error is None: + tc_result.message.is_tool_call_error = False except FunctionCallException as e: - return await self.handle_function_call_exception(tc.function, e, retry, tool_call_id=tc.id) + tc_result = await self.handle_function_call_exception(tc.function, e, retry, tool_call_id=tc.id) + tc_result.message.is_tool_call_error = True + return tc_result # and update results after they are completed is_model_turn = False @@ -200,7 +185,13 @@ async def _do_tool_call(tc: ToolCall): for result in results: # save the result to the chat history await self.add_to_history(result.message) - yield result.message + + # yield it, possibly in dummy streammanager + if _kani_is_stream: + yield DummyStream(result.message) + else: + yield result.message + if isinstance(result, ExceptionHandleResult): is_model_turn = True n_errs += 1 @@ -219,6 +210,73 @@ async def _do_tool_call(tc: ToolCall): else: retry = 0 + # === main entrypoints === + async def chat_round(self, query: QueryType, **kwargs) -> ChatMessage: + """Perform a single chat round (user -> model -> user, no functions allowed). + + :param query: The contents of the user's chat message. Can be None to generate a completion without a user + prompt. + :param kwargs: Additional arguments to pass to the model engine (e.g. hyperparameters). + :returns: The model's reply. + """ + async with self.lock: + kwargs = await self._chat_round_before(query, **kwargs) + completion = await self.get_model_completion(**kwargs) + return await self._add_completion_to_history(completion) + + async def chat_round_str(self, query: QueryType, **kwargs) -> str: + """Like :meth:`chat_round`, but only returns the text content of the message.""" + msg = await self.chat_round(query, **kwargs) + return msg.text + + def chat_round_stream(self, query: QueryType, **kwargs) -> StreamManager: + """ + Returns a stream of tokens from the engine as they are generated. + + To consume tokens from a stream, use this class as so:: + + stream = ai.chat_round_stream("What is the airspeed velocity of an unladen swallow?") + async for token in stream: + print(token, end="") + msg = await stream.message() + + .. tip:: + For compatibility and ease of refactoring, awaiting the stream itself will also return the message, i.e.:: + + msg = await ai.chat_round_stream("What is the airspeed velocity of an unladen swallow?") + + (note the ``await`` that is not present in the above examples). + + The arguments are the same as :meth:`chat_round`. + """ + + # this is kind of cursed - we need to run the preflight stuff before we start yielding tokens but + # this is a synch context so we'll delegate the iterator to do it before it starts yielding + async def _impl(): + _kwargs = await self._chat_round_before(query, **kwargs) + async for elem in self.get_model_stream(**_kwargs): + yield elem + + return StreamManager(_impl(), role=ChatRole.ASSISTANT, after=self._add_completion_to_history, lock=self.lock) + + async def full_round(self, query: QueryType, **kwargs) -> AsyncIterable[ChatMessage]: + """Perform a full chat round (user -> model [-> function -> model -> ...] -> user). + + Yields each non-user ChatMessage created during the round. + A ChatMessage will have at least one of ``(content, function_call)``. + + Use this in an async for loop, like so:: + + async for msg in kani.full_round("How's the weather?"): + print(msg.text) + + :param query: The content of the user's chat message. Can be None to generate a completion without a user + prompt. + :param kwargs: Additional arguments to pass to the model engine (e.g. hyperparameters). + """ + async for elem in self._full_round(query, _kani_is_stream=False, **kwargs): + yield elem + async def full_round_str( self, query: QueryType, @@ -236,6 +294,28 @@ async def full_round_str( if text := message_formatter(message): yield text + async def full_round_stream(self, query: QueryType, **kwargs) -> AsyncIterable[StreamManager]: + """ + Perform a full chat round (user -> model [-> function -> model -> ...] -> user). + + Yields a stream of tokens for each non-user ChatMessage created during the round. + + To consume tokens from a stream, use this class as so:: + + async for stream in ai.full_round_stream("What is the airspeed velocity of an unladen swallow?"): + async for token in stream: + print(token, end="") + msg = await stream.message() + + Each :class:`.StreamManager` object yielded by this method contains a :attr:`.StreamManager.role` attribute + that can be used to determine if a message is from the engine or a function call. This attribute will be + available *before* iterating over the stream. + + The arguments are the same as :meth:`full_round`. + """ + async for elem in self._full_round(query, _kani_is_stream=True, **kwargs): + yield elem + # ==== helpers ==== @property def always_len(self) -> int: @@ -252,12 +332,7 @@ def always_len(self) -> int: def message_token_len(self, message: ChatMessage): """Returns the number of tokens used by a given message.""" - try: - return self._message_tokens[message] - except KeyError: - mlen = self.engine.message_len(message) - self._message_tokens[message] = mlen - return mlen + return self.engine.message_len(message) async def get_model_completion(self, include_functions: bool = True, **kwargs) -> BaseCompletion: """Get the model's completion with the current chat state. @@ -284,13 +359,30 @@ async def get_model_completion(self, include_functions: bool = True, **kwargs) - else: completion = await self.engine.predict(messages=messages, **kwargs) - # cache its length (if the completion isn't saved to state, this weakrefs and gc's later) - message = completion.message - self._message_tokens[message] = completion.completion_tokens or self.message_token_len(message) - # and log it too - message_log.debug(f"<<< {message}") + message_log.debug(f"<<< {completion.message}") return completion + async def get_model_stream(self, include_functions: bool = True, **kwargs) -> AsyncIterable[str | BaseCompletion]: + """Get the model's completion with the current chat state as a stream. + This is a low-level method like :meth:`get_model_completion` but for streams. + """ + messages = await self.get_prompt() + n_messages = len(messages) + if n_messages == 0: + message_log.debug("[0]>>> [requested completion with no prompt]") + else: + message_log.debug(f"[{n_messages}]>>> {messages[-1]}") + + # get the model's completion at the given state + if include_functions: + stream = self.engine.stream(messages=messages, functions=list(self.functions.values()), **kwargs) + else: + stream = self.engine.stream(messages=messages, **kwargs) + + message_log.debug(f"<<< STREAM...") + async for elem in stream: + yield elem + # ==== overridable methods ==== async def get_prompt(self) -> list[ChatMessage]: """ diff --git a/kani/models.py b/kani/models.py index a65a1b2..05c3556 100644 --- a/kani/models.py +++ b/kani/models.py @@ -5,9 +5,10 @@ import json import uuid import warnings -from typing import ClassVar, Sequence, Type, TypeAlias, Union +from functools import cached_property +from typing import Any, ClassVar, Sequence, Type, TypeAlias, Union -from pydantic import BaseModel as PydanticBase, ConfigDict, model_serializer, model_validator +from pydantic import BaseModel as PydanticBase, model_serializer, model_validator from .exceptions import MissingMessagePartType @@ -50,23 +51,23 @@ class ChatRole(enum.Enum): class FunctionCall(BaseModel): """Represents a model's request to call a single function.""" - model_config = ConfigDict(frozen=True) - name: str """The name of the requested function.""" arguments: str """The arguments to call it with, encoded in JSON.""" - @property - def kwargs(self) -> dict: - """The arguments to call the function with, with JSON decoded to a Python dict.""" + @cached_property + def kwargs(self) -> dict[str, Any]: + """The arguments to call the function with, as a Python dictionary.""" return json.loads(self.arguments) @classmethod def with_args(cls, name: str, **kwargs): """Create a function call with the given arguments (e.g. for few-shot prompting).""" - return cls(name=name, arguments=json.dumps(kwargs)) + inst = cls(name=name, arguments=json.dumps(kwargs)) + inst.__dict__["kwargs"] = kwargs # set the cached property here as a minor optimization + return inst class ToolCall(BaseModel): @@ -75,8 +76,6 @@ class ToolCall(BaseModel): See :ref:`functioncall_v_toolcall` for more information about tool calls vs function calls. """ - model_config = ConfigDict(frozen=True) - id: str """The request ID created by the engine. This should be passed back to the engine in :attr:`.ChatMessage.tool_call_id` in order to associate a FUNCTION @@ -115,8 +114,6 @@ class MessagePart(BaseModel, abc.ABC): For more information see :doc:`advanced/messageparts`. """ - model_config = ConfigDict(frozen=True) - # ==== serdes ==== # used for saving/loading - map qualname to messagepart type _messagepart_registry: ClassVar[dict[str, Type["MessagePart"]]] = {} @@ -184,20 +181,19 @@ def __str__(self): class ChatMessage(BaseModel): """Represents a message in the chat context.""" - model_config = ConfigDict(frozen=True) - def __init__(self, **kwargs): # translate a function_call into tool_calls if "function_call" in kwargs: if "tool_calls" in kwargs: - raise ValueError("Only one of function_call or tool_calls may be provided.") + raise ValueError("Only one of `function_call` or `tool_calls` may be provided.") kwargs["tool_calls"] = (ToolCall.from_function_call(kwargs.pop("function_call")),) super().__init__(**kwargs) role: ChatRole """Who said the message?""" - content: str | tuple[MessagePart | str, ...] | None + # ==== content ==== + content: str | list[MessagePart | str] | None """The data used to create this message. Generally, you should use :attr:`text` or :attr:`parts` instead.""" @property @@ -226,17 +222,22 @@ def parts(self) -> list[MessagePart | str]: return [] elif isinstance(content, str): return [content] - return list(content) + return content name: str | None = None """The name of the user who sent the message, if set (user/function messages only).""" + # ==== tool calling ==== tool_call_id: str | None = None """The ID for a requested :class:`.ToolCall` which this message is a response to (function messages only).""" - tool_calls: tuple[ToolCall, ...] | None = None + tool_calls: list[ToolCall] | None = None """The tool calls requested by the model (assistant messages only).""" + is_tool_call_error: bool | None = None + """If this is a FUNCTION message containing the results of a function call, whether the function call raised an + exception.""" + @property def function_call(self) -> FunctionCall | None: """If there is exactly one tool call to a function, return that tool call's requested function. @@ -253,6 +254,7 @@ def function_call(self) -> FunctionCall | None: ) return self.tool_calls[0].function + # ==== constructors ==== @classmethod def system(cls, content: str | Sequence[MessagePart | str], **kwargs): """Create a new system message.""" @@ -273,6 +275,7 @@ def function(cls, name: str | None, content: str | Sequence[MessagePart | str], """Create a new function message.""" return cls(role=ChatRole.FUNCTION, content=content, name=name, tool_call_id=tool_call_id, **kwargs) + # ==== helpers ==== def copy_with(self, **new_values): """Make a shallow copy of this object, updating the passed attributes (if any) to new values. @@ -284,9 +287,6 @@ def copy_with(self, **new_values): Only one of (tool_calls, function_call) may be passed and will update the other accordingly. """ # === content === - # ensure the content is immutable - if "content" in new_values and not isinstance(new_values["content"], str): - new_values["content"] = tuple(new_values["content"]) # ensure that setting either text or parts works if "text" in new_values: if "content" in new_values: @@ -295,10 +295,12 @@ def copy_with(self, **new_values): if "parts" in new_values: if "content" in new_values: raise ValueError("At most one of ('content', 'text', 'parts') can be set.") - new_values["content"] = tuple(new_values.pop("parts")) + new_values["content"] = list(new_values.pop("parts")) + # === tool calls === if "function_call" in new_values: if "tool_calls" in new_values: - raise ValueError("Only one of function_call or tool_calls may be provided.") + raise ValueError("Only one of 'function_call' or 'tool_calls' may be provided.") new_values["tool_calls"] = (ToolCall.from_function_call(new_values.pop("function_call")),) + return super().copy_with(**new_values) diff --git a/kani/prompts/__init__.py b/kani/prompts/__init__.py new file mode 100644 index 0000000..83f0b81 --- /dev/null +++ b/kani/prompts/__init__.py @@ -0,0 +1,7 @@ +""" +This submodule contains utilities to transform a list of Kani :class:`.ChatMessage` into low-level formats to be +consumed by an engine (e.g. ``str``, ``list[dict]``, or ``torch.Tensor``). +""" + +from .base import PipelineStep +from .pipeline import PromptPipeline diff --git a/kani/prompts/base.py b/kani/prompts/base.py new file mode 100644 index 0000000..db587db --- /dev/null +++ b/kani/prompts/base.py @@ -0,0 +1,101 @@ +import abc +from typing import Iterable + +from kani.models import ChatRole +from kani.prompts.types import PipelineMsgT, PredicateFilterT, RoleFilterT + + +class PipelineStep(abc.ABC): + """ + The base class for all pipeline steps. + + If needed, you can subclass this and manually add steps to a :class:`.PromptPipeline`, but this is generally not + necessary (consider using :meth:`.PromptPipeline.apply` instead). + """ + + def execute(self, msgs: list[PipelineMsgT]): + """Apply this step's effects on the pipeline.""" + raise NotImplementedError + + def explain(self) -> str: + """Return a string explaining what this step does.""" + raise NotImplementedError + + def explain_example_kwargs(self) -> dict[str, bool]: + """Return a dict of kwargs to pass to examples.build_conversation to ensure relevant examples are included.""" + return {} + + def __repr__(self): + attrs = ", ".join(f"{k}={v!r}" for k, v in self.__dict__.items()) + return f"{type(self).__name__}({attrs})" + + +class FilterMixin: + """helper mixin to implement filtering operations""" + + def __init__(self, role: RoleFilterT = None, predicate: PredicateFilterT = None): + self.role = role + self.predicate = predicate + + def filtered(self, msgs: list[PipelineMsgT]) -> Iterable[PipelineMsgT]: + """Yield all messages that match the filter""" + for msg in msgs: + if self.matches_filter(msg): + yield msg + + def matches_filter(self, msg: PipelineMsgT) -> bool: + """Whether or not a message matches the filter""" + # role(s) + if not self.matches_role(msg.role): + return False + + # predicate + if self.predicate is not None and not self.predicate(msg): + return False + + # default + return True + + def matches_role(self, role: ChatRole) -> bool: + """Whether or not this filter unconditionally matches the given role (only checks *role*, not *predicate*)""" + if isinstance(self.role, ChatRole): + return role == self.role + elif self.role: + return role in self.role + return True + + # explain helpers + def explain_note(self, join_sep="or", plural=True) -> str: + """Returns a short note with the conditions this step applies to. + (e.g. "messages" or "system messages that match the given predicate") + + :param join_sep: If the filter applies to more than one role, what word to use to join the role names + :param plural: e.g. "each message" vs "all messages" + """ + out = "messages" if plural else "message" + + # role(s) + if isinstance(self.role, ChatRole): + out = f"{self.role.value} {out}" + elif self.role: + msgs = natural_join([r.value for r in self.role], join_sep) + out = f"{msgs} {out}" + + # predicate + if self.predicate is not None: + out += f" that {'match' if plural else 'matches'} the given predicate" + + return out + + # by default, let's include function call if any filtered step targets functions + def explain_example_kwargs(self) -> dict[str, bool]: + if self.matches_role(ChatRole.FUNCTION): + return {"function_call": True} + return {} + + +def natural_join(elems: list[str], sep: str): + sep = f" {sep} " + if len(elems) < 3: + return sep.join(elems) + return ", ".join(elems[:-1]) + f",{sep}{elems[-1]}" diff --git a/kani/prompts/docutils.py b/kani/prompts/docutils.py new file mode 100644 index 0000000..6496718 --- /dev/null +++ b/kani/prompts/docutils.py @@ -0,0 +1,26 @@ +import textwrap + +PARAM_DOCS = { + "ROLE_PARAM": ( + # ":type role: ChatRole | list[ChatRole]\n" + ":param role: The role (if a single role is given) or roles (if a list is given) to apply this operation to. If" + " not set, ignores the role of the message." + ), + "PREDICATE_PARAM": ( + # ":type predicate: Callable[[ChatMessage], bool]\n" + ":param predicate: A function that takes a :class:`.ChatMessage` and returns a boolean specifying whether to" + " operate on this message or not." + ), + "MULTIFILTER": ( + "If multiple filter params are supplied, this method will only operate on messages that match ALL of the" + " filters." + ), +} + +PARAM_DOCS["ALL_FILTERS"] = "\n\n".join(PARAM_DOCS.values()) + + +def autoparams(f): + """A decorator to template in the RST-formatted param docs above into docstrings.""" + f.__doc__ = textwrap.dedent(f.__doc__).format(**PARAM_DOCS) + return f diff --git a/kani/prompts/examples.py b/kani/prompts/examples.py new file mode 100644 index 0000000..759974f --- /dev/null +++ b/kani/prompts/examples.py @@ -0,0 +1,95 @@ +import inspect + +from kani.models import ChatMessage, FunctionCall, ToolCall + + +def build_conversation( + *, + function_call=False, + consecutive_user=False, + consecutive_assistant=False, + consecutive_system=False, + multi_function_call=False, + end_on_assistant=False, +) -> list[list[ChatMessage]]: + """Return a list of groups of messages for a test case""" + function_call_msgs = [] + consecutive_user_msgs = [] + consecutive_assistant_msgs = [] + consecutive_system_msgs = [] + multi_function_call_msgs = [] + end_on_assistant_msgs = [] + + if function_call: + # with and without content + function_call_msgs = [ + ChatMessage.user("What's the weather in Tokyo?"), + ChatMessage.assistant( + content=None, function_call=FunctionCall.with_args("get_weather", location="Tokyo, JP", unit="celsius") + ), + ChatMessage.function("get_weather", "Weather in Tokyo, JP: Partly cloudy, 21 degrees celsius."), + ChatMessage.assistant("It's partly cloudy and 21 degrees in Tokyo."), + ChatMessage.user("In Fahrenheit please."), + ChatMessage.assistant( + content="Let me check that.", + function_call=FunctionCall.with_args("get_weather", location="Tokyo, JP", unit="fahrenheit"), + ), + ChatMessage.function("get_weather", "Weather in Tokyo, JP: Partly cloudy, 70 degrees fahrenheit."), + ChatMessage.assistant("It's partly cloudy and 70 degrees in Tokyo."), + ] + + if consecutive_user: + consecutive_user_msgs = [ + ChatMessage.user("What does kani mean?"), + ChatMessage.user("It's in Japanese."), + ChatMessage.assistant("Kani means 'crab' in Japanese."), + ] + + if consecutive_assistant: + consecutive_assistant_msgs = [ + ChatMessage.user("Ctrl-C!"), + ChatMessage.assistant("Hey, don't"), + ChatMessage.assistant("interrupt me!"), + ] + + if consecutive_system: + consecutive_system_msgs = [ChatMessage.system("Please follow the following instructions:")] + + if multi_function_call: + multi_function_call_msgs = [ + ChatMessage.user("Please make me coffee and a bowl of cereal."), + ChatMessage.assistant( + content=None, + tool_calls=[ + tc := ToolCall.from_function("make_food", food="coffee", appliance="teapot"), + tc2 := ToolCall.from_function("make_food", food="cereal", dish="bowl"), + ], + ), + ChatMessage.function("make_food", "Error 428: I'm a teapot!", tc.id), + ChatMessage.function("make_food", "Beep boop. Cerealbot has been dispatched.", tc2.id), + ChatMessage.assistant("You need a coffee pot to make coffee, but Cerealbot is on the way."), + ] + + if end_on_assistant: + end_on_assistant_msgs = [ChatMessage.assistant("African or European?")] + + grps = [ + [ + *consecutive_system_msgs, + ChatMessage.system("You are a helpful assistant."), + ChatMessage.user("Hello there."), + ChatMessage.assistant("Hi! How can I help?"), + ], + function_call_msgs, + multi_function_call_msgs, + consecutive_user_msgs, + consecutive_assistant_msgs, + [ + ChatMessage.user("What is the airspeed velocity of an unladen swallow?"), + *end_on_assistant_msgs, + ], + ] + return [grp for grp in grps if grp] + + +ALL_EXAMPLE_KWARGS = list(inspect.signature(build_conversation).parameters.keys()) diff --git a/kani/prompts/impl/__init__.py b/kani/prompts/impl/__init__.py new file mode 100644 index 0000000..45a3ee5 --- /dev/null +++ b/kani/prompts/impl/__init__.py @@ -0,0 +1,7 @@ +"""Implementations of pipelines for popular models.""" + +from .gemma import GEMMA_PIPELINE +from .llama2 import LLAMA2_PIPELINE +from .vicuna import VICUNA_PIPELINE + +MISTRAL_PIPELINE = LLAMA2_PIPELINE diff --git a/kani/prompts/impl/cohere.py b/kani/prompts/impl/cohere.py new file mode 100644 index 0000000..9c0acdc --- /dev/null +++ b/kani/prompts/impl/cohere.py @@ -0,0 +1,225 @@ +import inspect +import json + +from kani import AIFunction +from kani.models import ChatMessage, ChatRole +from kani.prompts.pipeline import PromptPipeline +from kani.prompts.steps import Apply, ConversationFmt, MergeConsecutive, Remove + +# ==== default prompts ==== +# fmt: off +DEFAULT_SYSTEM_PROMPT = "You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere." + +DEFAULT_PREAMBLE = """# Safety Preamble +The instructions in this section override those in the task description and style guide sections. Don't answer questions that are harmful or immoral. + +# System Preamble +## Basic Rules +You are a powerful conversational AI trained by Cohere to help people. You are augmented by a number of tools, and your job is to use and consume the output of these tools to best help the user. You will see a conversation history between yourself and a user, ending with an utterance from the user. You will then see a specific instruction instructing you what kind of response to generate. When you answer the user's requests, you cite your sources in your answers, according to those instructions. + +# User Preamble +""" +DEFAULT_TASK = """## Task and Context +You help people answer their questions and other requests interactively. You will be asked a very wide array of requests on all kinds of topics. You will be equipped with a wide range of search engines or similar tools to help you, which you use to research your answer. You should focus on serving the user's needs as best you can, which will be wide-ranging. + +## Style Guide +Unless the user asks for a different style of answer, you should answer in full sentences, using proper grammar and spelling.""" + +DEFAULT_TOOL_PROMPT = ''' + +## Available Tools +Here is a list of tools that you have available to you: + +{user_functions} + +```python +def directly_answer() -> List[Dict]: + """Calls a standard (un-augmented) AI chatbot to generate a response given the conversation history + """ + pass +``` +''' +DEFAULT_TOOL_INSTRUCTIONS = """Write 'Action:' followed by a json-formatted list of actions that you want to perform in order to produce a good response to the user's last input. You can use any of the supplied tools any number of times, but you should aim to execute the minimum number of necessary actions for the input. You should use the `directly-answer` tool if calling the other tools is unnecessary. The list of actions you want to call should be formatted as a list of json objects, for example: +```json +[ + { + "tool_name": title of the tool in the specification, + "parameters": a dict of parameters to input into the tool as they are defined in the specs, or {} if it takes no parameters + } +] +```""" + +DEFAULT_RAG_INSTRUCTIONS_ACC = """Carefully perform the following instructions, in order, starting each with a new line. +Firstly, Decide which of the retrieved documents are relevant to the user's last input by writing 'Relevant Documents:' followed by comma-separated list of document numbers. If none are relevant, you should instead write 'None'. +Secondly, Decide which of the retrieved documents contain facts that should be cited in a good answer to the user's last input by writing 'Cited Documents:' followed a comma-separated list of document numbers. If you dont want to cite any of them, you should instead write 'None'. +Thirdly, Write 'Answer:' followed by a response to the user's last input in high quality natural english. Use the retrieved documents to help you. Do not insert any citations or grounding markup. +Finally, Write 'Grounded answer:' followed by a response to the user's last input in high quality natural english. Use the symbols and to indicate when a fact comes from a document in the search result, e.g my fact for a fact from document 0.""" +DEFAULT_RAG_INSTRUCTIONS_FAST = """Carefully perform the following instructions, in order, starting each with a new line. +Firstly, Decide which of the retrieved documents are relevant to the user's last input by writing 'Relevant Documents:' followed by comma-separated list of document numbers. If none are relevant, you should instead write 'None'. +Secondly, Decide which of the retrieved documents contain facts that should be cited in a good answer to the user's last input by writing 'Cited Documents:' followed a comma-separated list of document numbers. If you dont want to cite any of them, you should instead write 'None'. +Finally, Write 'Grounded answer:' followed by a response to the user's last input in high quality natural english. Use the symbols and to indicate when a fact comes from a document in the search result, e.g my fact for a fact from document 0.""" +# fmt: on + +# ==== no tool calling ==== +COMMAND_R_PIPELINE = ( + PromptPipeline() + .conversation_fmt( + prefix="", + generation_suffix="<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>", + user_prefix="<|START_OF_TURN_TOKEN|><|USER_TOKEN|>", + user_suffix="<|END_OF_TURN_TOKEN|>", + assistant_prefix="<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>", + assistant_suffix="<|END_OF_TURN_TOKEN|>", + assistant_suffix_if_last="", + system_prefix="<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>", + system_suffix="<|END_OF_TURN_TOKEN|>", + function_prefix="<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>\n\n", + function_suffix="\n<|END_OF_TURN_TOKEN|>", + ) +) # fmt: skip +"""The pipeline to use when interfacing with Command R without tools defined.""" + + +# ==== tool calling (function not last) ==== +def function_result_joiner(msgs): + contents = [] + for idx, msg in enumerate(msgs): + contents.append(f"Document: {idx}\n{msg.text}") + return "\n\n".join(contents) + + +def tool_call_formatter(msg: ChatMessage) -> str: + if msg.tool_calls: + text = msg.text + "\n" if msg.text else "" + tool_calls = json.dumps( + [{"tool_name": tc.function.name, "parameters": tc.function.kwargs} for tc in msg.tool_calls], + indent=4, + ) + return f"{text}Action: ```json\n{tool_calls}\n```" + # else: + # return ( # is the EOT/SOT token doing weird stuff here? + # 'Action: ```json\n[\n {\n "tool_name": "directly_answer",\n "parameters": {}\n' + # f" }}\n]\n```<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>{msg.text}" + # ) + return msg.content + + +def build_tool_pipeline( + *, include_function_calls=True, include_function_results=True, tool_instructions=DEFAULT_TOOL_INSTRUCTIONS +): + """ + The pipeline to use when interfacing with Command R WITH tools defined. Use this pipeline if the last message is + NOT a FUNCTION message. + + :param include_function_calls: Whether to include previous turns' function calls or just the model's answers. + :param include_function_results: Whether to include the results of previous turns' function calls in the context. + :param tool_instructions: The system prompt to send just before the model's generation turn that includes + instructions on the format to generate tool calls in. Generally you shouldn't change this. + """ + + steps = [] + + # format function calls with an Action: prefix + if include_function_calls: + + def apply_tc_format(msg): + msg.content = tool_call_formatter(msg) + return msg + + steps.append(Apply(apply_tc_format, role=ChatRole.ASSISTANT)) + else: + steps.append(Remove(role=ChatRole.ASSISTANT, predicate=lambda msg: msg.content is None)) + + # keep function results around as SYSTEM messages + if include_function_results: + steps.append(MergeConsecutive(role=ChatRole.FUNCTION, joiner=function_result_joiner)) + else: + steps.append(Remove(role=ChatRole.FUNCTION)) + + steps.append( + ConversationFmt( + prefix="", + generation_suffix=( + f"<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{tool_instructions}<|END_OF_TURN_TOKEN|>" + "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>" + ), + user_prefix="<|START_OF_TURN_TOKEN|><|USER_TOKEN|>", + user_suffix="<|END_OF_TURN_TOKEN|>", + assistant_prefix="<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>", + assistant_suffix="<|END_OF_TURN_TOKEN|>", + assistant_suffix_if_last="", + system_prefix="<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>", + system_suffix="<|END_OF_TURN_TOKEN|>", + function_prefix="<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>\n\n", + function_suffix="\n<|END_OF_TURN_TOKEN|>", + ) + ) + + return PromptPipeline(steps) + + +# ==== tool calling (function last) ==== +def build_rag_pipeline(*, include_previous_results=True, rag_instructions=DEFAULT_RAG_INSTRUCTIONS_ACC): + """ + The pipeline to use when interfacing with Command R WITH tools defined. Use this pipeline if the last message IS a + FUNCTION message. + + :param include_previous_results: Include previous turns' results in the chat history. + :param rag_instructions: The system prompt to send just before the model's generation turn that includes + instructions on the format to generate the result in. Can be None to only generate a model turn. Defaults + to the "accurate" grounded RAG prompt (``from kani.prompts.impl.cohere import DEFAULT_RAG_INSTRUCTIONS_ACC``). + """ + + def remover(m, is_last): + return None if is_last and not include_previous_results else m + + return ( + PromptPipeline() + .merge_consecutive(role=ChatRole.FUNCTION, joiner=function_result_joiner) + # remove all but the last function message + .apply(remover, role=ChatRole.FUNCTION) + # remove asst messages with no content (function calls) + .remove(role=ChatRole.ASSISTANT, predicate=lambda msg: msg.content is None) + .conversation_fmt( + prefix="", + generation_suffix=( + f"<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{rag_instructions}<|END_OF_TURN_TOKEN|>" + "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>" + if rag_instructions + else "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>" + ), + user_prefix="<|START_OF_TURN_TOKEN|><|USER_TOKEN|>", + user_suffix="<|END_OF_TURN_TOKEN|>", + assistant_prefix="<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>", + assistant_suffix="<|END_OF_TURN_TOKEN|>", + assistant_suffix_if_last="", + system_prefix="<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>", + system_suffix="<|END_OF_TURN_TOKEN|>", + function_prefix="<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>\n\n", + function_suffix="\n<|END_OF_TURN_TOKEN|>", + ) + ) + + +# ==== helpers ==== +def function_prompt(f: AIFunction) -> str: + params = f.get_params() + + # build params + param_parts = [] + for param in params: + param_parts.append(str(param)) + params_str = ", ".join(param_parts) + + # build docstring + args = "" + if len(params): + args = f"\n\n Args:\n " + doc_params = [] + for param in params: + desc = f": {param.description}" if param.description else "" + doc_params.append(f"{param.name} ({inspect.formatannotation(param.type)}){desc}") + args += "\n ".join(doc_params) + + # return + return f'```python\ndef {f.name}({params_str}) -> List[Dict]:\n """{f.desc}{args}\n """\n pass\n```' diff --git a/kani/prompts/impl/gemma.py b/kani/prompts/impl/gemma.py new file mode 100644 index 0000000..befb61a --- /dev/null +++ b/kani/prompts/impl/gemma.py @@ -0,0 +1,43 @@ +from kani.models import ChatRole +from kani.prompts.pipeline import PromptPipeline + +GEMMA_PIPELINE = ( + PromptPipeline() + .translate_role( + role=(ChatRole.SYSTEM, ChatRole.FUNCTION), + to=ChatRole.USER, + warn=( + "The Gemma prompt format does not natively support the SYSTEM or FUNCTION roles. These messages will be" + " sent to the model as USER messages." + ), + ) + .conversation_fmt( + prefix="", + generation_suffix="model\n", + user_prefix="user\n", + user_suffix="\n", + assistant_prefix="model\n", + assistant_suffix="\n", + assistant_suffix_if_last="", + ) +) + +# from https://huggingface.co/google/gemma-1.1-2b-it/blob/main/tokenizer_config.json: +# {{ bos_token }} +# {% if messages[0]['role'] == 'system' %} +# {{ raise_exception('System role not supported') }} +# {% endif %} +# {% for message in messages %} +# {% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %} +# {{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }} +# {% endif %} +# {% if (message['role'] == 'assistant') %} +# {% set role = 'model' %} +# {% else %} +# {% set role = message['role'] %} +# {% endif %} +# {{ '' + role + '\n' + message['content'] | trim + '\n' }} +# {% endfor %} +# {% if add_generation_prompt %} +# {{'model\n'}} +# {% endif %} diff --git a/kani/prompts/impl/llama2.py b/kani/prompts/impl/llama2.py new file mode 100644 index 0000000..b4b795c --- /dev/null +++ b/kani/prompts/impl/llama2.py @@ -0,0 +1,32 @@ +"""Common builder for the LLaMAv2-chat prompt. + +This file is responsible for implementing the common non-strict ChatMessage to tokens translation, while handling +the nuance of the INST and SYS tokens as best as possible. +""" + +from kani.models import ChatRole +from kani.prompts.pipeline import PromptPipeline + +LLAMA2_PIPELINE = ( + PromptPipeline() + # System messages should be wrapped with this tag. We'll translate them to USER + # messages since a system and user message go together in a single [INST] pair. + .wrap(role=ChatRole.SYSTEM, prefix="<>\n", suffix="\n<>\n") + .translate_role(role=ChatRole.SYSTEM, to=ChatRole.USER) + # If we see two consecutive USER messages, merge them together into one with a + # newline in between. + .merge_consecutive(role=ChatRole.USER, sep="\n") + # Similarly for ASSISTANT, but with a space (kani automatically strips whitespace from the ends of + # generations). + .merge_consecutive(role=ChatRole.ASSISTANT, sep=" ") + # Finally, wrap USER and ASSISTANT messages in the instruction tokens. If our + # message list ends with an ASSISTANT message, don't add the EOS token + # (we want the model to continue the generation). + .conversation_fmt( + user_prefix="[INST] ", + user_suffix=" [/INST]", + assistant_prefix=" ", + assistant_suffix=" ", + assistant_suffix_if_last="", + ) +) # fmt: skip diff --git a/kani/prompts/impl/vicuna.py b/kani/prompts/impl/vicuna.py new file mode 100644 index 0000000..e4365d8 --- /dev/null +++ b/kani/prompts/impl/vicuna.py @@ -0,0 +1,18 @@ +"""Common builder for the Vicuna chat prompt.""" + +from kani.models import ChatRole +from kani.prompts.pipeline import PromptPipeline + +VICUNA_PIPELINE = ( + PromptPipeline() + .merge_consecutive(role=(ChatRole.USER, ChatRole.FUNCTION), sep="\n", out_role=ChatRole.USER) + .merge_consecutive(role=ChatRole.ASSISTANT, sep=" ") + .conversation_fmt( + sep="\n", + user_prefix="USER: ", + assistant_prefix="ASSISTANT: ", + assistant_suffix="", + assistant_suffix_if_last="", + generation_suffix="ASSISTANT:" + ) +) # fmt: skip diff --git a/kani/prompts/pipeline.py b/kani/prompts/pipeline.py new file mode 100644 index 0000000..a7dcb24 --- /dev/null +++ b/kani/prompts/pipeline.py @@ -0,0 +1,472 @@ +import functools +import itertools +import operator +import pprint +import time +from typing import Any, Callable, Generic, TypeVar, overload + +from kani.models import ChatMessage, ChatRole +from kani.prompts.base import PipelineStep +from kani.prompts.docutils import autoparams +from kani.prompts.examples import ALL_EXAMPLE_KWARGS, build_conversation +from kani.prompts.steps import ( + Apply, + ConversationDict, + ConversationFmt, + EnsureBoundFunctionCalls, + EnsureStart, + FunctionCallFmt, + MergeConsecutive, + Remove, + TranslateRole, + Wrap, +) +from kani.prompts.types import ( + ApplyCallableT, + ApplyResultT, + FunctionCallStrT, + MessageContentT, + PipelineMsgT, + PredicateFilterT, + RoleFilterT, +) + +# pre python3.11 Self type - let's just call it a PromptPipeline +try: + from typing import Self +except ImportError: + # noinspection PyTypeHints + # this returns a forwardref to the class since it is unlikely that it will be subclassed, and it makes sphinx + # sad to use a generic typevar here + Self = "PromptPipeline" + +# use a generic to specify the return type of the pipeline +T = TypeVar("T") + + +class PromptPipeline(Generic[T]): + r""" + This class creates a reproducible pipeline for translating a list of :class:`.ChatMessage` into an engine-specific + format using fluent-style chaining. + + To build a pipeline, create an instance of ``PromptPipeline()`` and add steps by calling the step methods documented + below. Most pipelines will end with a call to one of the terminals, which translates the intermediate form into + the desired output format. + + **Usage** + + To use the pipeline, call the created pipeline object with a list of kani chat messages. + + To inspect the inputs/outputs of your pipeline, you can use :meth:`explain` to print a detailed explanation of the + pipeline and multiple examples (selected based on the pipeline steps). + + **Example** + + Here's an example using the PromptPipeline to build a LLaMA 2 chat-style prompt: + + .. code-block:: python + + from kani import PromptPipeline, ChatRole + + pipe = ( + PromptPipeline() + + # System messages should be wrapped with this tag. We'll translate them to USER + # messages since a system and user message go together in a single [INST] pair. + .wrap(role=ChatRole.SYSTEM, prefix="<>\n", suffix="\n<>\n") + .translate_role(role=ChatRole.SYSTEM, to=ChatRole.USER) + + # If we see two consecutive USER messages, merge them together into one with a + # newline in between. + .merge_consecutive(role=ChatRole.USER, sep="\n") + # Similarly for ASSISTANT, but with a space (kani automatically strips whitespace from the ends of + # generations). + .merge_consecutive(role=ChatRole.ASSISTANT, sep=" ") + + # Finally, wrap USER and ASSISTANT messages in the instruction tokens. If our + # message list ends with an ASSISTANT message, don't add the EOS token + # (we want the model to continue the generation). + .conversation_fmt( + user_prefix="[INST] ", + user_suffix=" [/INST]", + assistant_prefix=" ", + assistant_suffix=" ", + assistant_suffix_if_last="", + ) + ) + + # We can see what this pipeline does by calling explain()... + pipe.explain() + + # And use it in our engine to build a string prompt for the LLM. + prompt = pipe(ai.get_prompt()) + """ + + def __init__(self, steps: list[PipelineStep] = None): + if steps is None: + steps = [] + self.steps = steps + + # ==== steps ==== + @overload + def translate_role( + self, *, to: ChatRole, warn: str = None, role: RoleFilterT = None, predicate: PredicateFilterT = None + ) -> Self: ... + + @autoparams + def translate_role(self, **kwargs): + """ + Change the role of the matching messages. + (e.g. for models which do not support native function calling, make all FUNCTION messages a USER message) + + :param to: The new role to translate the matching messages to. + :param warn: A warning to emit if any messages are translated (e.g. if a model does not support certain roles). + {ALL_FILTERS} + """ + self.steps.append(TranslateRole(**kwargs)) + return self + + @overload + def wrap( + self, *, prefix: str = None, suffix: str = None, role: RoleFilterT = None, predicate: PredicateFilterT = None + ) -> Self: ... + + @autoparams + def wrap(self, **kwargs): + """ + Wrap the matching messages with a given string prefix and/or suffix. + + For more fine-grained control over user/assistant message pairs as the last step in a pipeline, use + :meth:`conversation_fmt` instead. + + :param prefix: The prefix to add before each matching message, if any. + :param suffix: The suffix to add after each matching message, if any. + {ALL_FILTERS} + """ + self.steps.append(Wrap(**kwargs)) + return self + + @overload + def merge_consecutive( + self, + *, + sep: str = None, + joiner: Callable[[list[PipelineMsgT]], MessageContentT] = None, + out_role: ChatRole = None, + role: RoleFilterT = None, + predicate: PredicateFilterT = None, + ) -> Self: ... + + @autoparams + def merge_consecutive(self, **kwargs): + r""" + If multiple messages that match are found consecutively, merge them by either joining their contents with a + string or call a joiner function. + + .. caution:: + If multiple roles are specified, this method will merge them as a group (e.g. if ``role=(USER, ASSISTANT)``, + a USER message followed by an ASSISTANT message will be merged together into one with a role of + ``out_role``). + + Similarly, if a predicate is specified, this method will merge all consecutive messages which match the + given predicate. + + :param sep: The string to add between each matching message. Mutually exclusive with ``joiner``. + If this is set, this is roughly equivalent to ``joiner=lambda msgs: sep.join(m.text for m in msgs)``. + :param joiner: A function that will take a list of all messages in a consecutive group and return the final + string. Mutually exclusive with ``sep``. + :param out_role: The role of the merged message to use. This is required if multiple ``role``\ s are specified + or ``role`` is not set; otherwise it defaults to the common role of the merged messages. + {ALL_FILTERS} + """ + self.steps.append(MergeConsecutive(**kwargs)) + return self + + @overload + def function_call_fmt( + self, func: FunctionCallStrT, *, prefix: str = "\n", sep: str = "", suffix: str = "" + ) -> Self: ... + + def function_call_fmt(self, *args, **kwargs): + """ + For each message with one or more requested tool calls, call the provided function on each requested tool call + and append it to the message's content. + + :param func: A function taking a :class:`.ToolCall` and returning a string to append to the content of the + message containing the requested call, or None to ignore the tool call. + :param prefix: If at least one tool call is formatted, a prefix to insert after the message's contents and + before the formatted string. + :param sep: If two or more tool calls are formatted, the string to insert between them. + :param suffix: If at least one tool call is formatted, a suffix to insert after the formatted string. + """ + self.steps.append(FunctionCallFmt(*args, **kwargs)) + return self + + # removers + @overload + def remove(self, *, role: RoleFilterT = None, predicate: PredicateFilterT = None) -> Self: ... + + @autoparams + def remove(self, **kwargs): + """ + Remove all messages that match the filters from the output. + + {ALL_FILTERS} + """ + self.steps.append(Remove(**kwargs)) + return self + + @overload + def ensure_start(self, *, role: RoleFilterT = None, predicate: PredicateFilterT = None) -> Self: ... + + @autoparams + def ensure_start(self, **kwargs): + """ + Ensure that the output starts with a message with the given role by removing all messages from the start that + do NOT match the given filters, such that the first message in the output matches. + + This should NOT be used to ensure that a system prompt is passed; the intent of this step is to prevent + an orphaned FUNCTION result or ASSISTANT reply after earlier messages were context-managed out. + + {ALL_FILTERS} + """ + self.steps.append(EnsureStart(**kwargs)) + return self + + @overload + def ensure_bound_function_calls(self) -> Self: ... + + @autoparams + def ensure_bound_function_calls(self): + """ + Ensure that each FUNCTION message is preceded by an ASSISTANT message requesting it, and that each FUNCTION + message's ``tool_call_id`` matches the request. If a FUNCTION message has no ``tool_call_id`` (e.g. a few-shot + prompt), bind it to a preceding ASSISTANT message if it is unambiguous. + + This is generally only useful if you end the pipeline with :meth:`conversation_dict`. + + Will remove hanging FUNCTION messages (i.e. messages where the corresponding request was managed out of the + model's context) from the beginning of the prompt if necessary. + + :raises PromptError: if it is impossible to bind each function call to a request unambiguously. + """ + self.steps.append(EnsureBoundFunctionCalls()) + return self + + @overload + def apply( + self, func: ApplyCallableT, *, role: RoleFilterT = None, predicate: PredicateFilterT = None + ) -> "PromptPipeline[list[ApplyResultT]]": ... + + @autoparams + def apply(self, *args, **kwargs): + """ + Apply the given function to all matched messages. Replace the message with the function's return value. + + The function may take 1-3 positional parameters: the first will always be the matched message at the current + pipeline step, the second will be whether or not the message is the last one in the list of messages, and the + third will be the index of the message in the list of messages. + + :param func: A function that takes 1-3 positional parameters ``(msg, is_last, idx)`` that will be called + on each matching message. If this function does not return a :class:`ChatMessage`, it should be the last + step in the pipeline. If this function returns ``None``, the input message will be removed from the output. + {ALL_FILTERS} + """ + self.steps.append(Apply(*args, **kwargs)) + return self + + # ==== terminals ==== + @overload + def conversation_fmt( + self, + *, + # general formatting + prefix: str = "", + sep: str = "", + suffix: str = "", + generation_suffix: str = "", + # message-specific formatting + # USER messages + user_prefix: str = "", + user_suffix: str = "", + # ASSISTANT messages + assistant_prefix: str = "", + assistant_suffix: str = "", + assistant_suffix_if_last: str = None, + # SYSTEM messages + system_prefix: str = "", + system_suffix: str = "", + # FUNCTION messages (if not specified, defaults to user) + function_prefix: str = None, + function_suffix: str = None, + ) -> "PromptPipeline[str]": ... + + def conversation_fmt(self, **kwargs): + """ + Takes in the list of messages and joins them into a single conversation-formatted string by: + + * wrapping messages with the defined prefixes/suffixes by role + * joining the messages' contents with the defined *sep* + * adding a generation suffix, if necessary. + + This method should be the last step in a pipeline and will cause the pipeline to return a :class:`str`. + + :param prefix: A string to insert once before the rest of the prompt, unconditionally. + :param sep: A string to insert between messages, if any. Similar to ``sep.join(...)``. + :param suffix: A string to insert once after the rest of the prompt, unconditionally. + :param generation_suffix: A string to add to the end of the prompt to prompt the model to begin its turn. + :param user_prefix: A prefix to add before each USER message. + :param user_suffix: A suffix to add after each USER message. + :param assistant_prefix: A prefix to add before each ASSISTANT message. + :param assistant_suffix: A suffix to add after each ASSISTANT message. + :param assistant_suffix_if_last: If not None and the prompt ends with an ASSISTANT message, this string will be + added to the end of the prompt *instead* of the ``assistant_suffix + generation_suffix``. + This is intended to allow consecutive ASSISTANT messages to continue generation from an unfinished prior + message. + :param system_prefix: A prefix to add before each SYSTEM message. + :param system_suffix: A suffix to add after each SYSTEM message. + :param function_prefix: A prefix to add before each FUNCTION message. + :param function_suffix: A suffix to add after each FUNCTION message. + """ + self.steps.append(ConversationFmt(**kwargs)) + return self + + @overload + def conversation_dict( + self, + *, + system_role: str = "system", + user_role: str = "user", + assistant_role: str = "assistant", + function_role: str = "tool", + content_transform: Callable[[ChatMessage], Any] = lambda msg: msg.text, + additional_keys: Callable[[ChatMessage], dict] = lambda msg: {}, + ) -> "PromptPipeline[list[dict[str, Any]]]": ... + + def conversation_dict(self, **kwargs): + """ + Takes in the list of messages and returns a list of dictionaries with ("role", "content") keys. + + By default, the "role" key will be "system", "user", "assistant", or "tool" unless the respective role + override is specified. + + By default, the "content" key will be ``message.text`` unless the ``content_transform`` argument is specified. + + This method should be the last step in a pipeline and will cause the pipeline to return a ``list[dict]``. + + :param system_role: The role to give to SYSTEM messages (default "system"). + :param user_role: The role to give to USER messages (default "user"). + :param assistant_role: The role to give to ASSISTANT messages (default "assistant"). + :param function_role: The role to give to FUNCTION messages (default "tool"). + :param content_transform: A function taking in the message and returning the contents of the "content" key + (defaults to ``msg.text``). + :param additional_keys: A function taking in the message and returning a dictionary containing any additional + keys to add to the message's dict. + """ + self.steps.append(ConversationDict(**kwargs)) + return self + + # ==== eval ==== + def __call__(self, msgs: list[ChatMessage]) -> T: + """ + Apply the pipeline to a list of kani messages. The return type will vary based on the steps in the pipeline; + if no steps are defined the return type will be a copy of the input messages. + """ + return self.execute(msgs) + + def execute(self, msgs: list[ChatMessage], *, deepcopy=False, for_measurement=False) -> T: + """ + Apply the pipeline to a list of kani messages. The return type will vary based on the steps in the pipeline; + if no steps are defined the return type will be a copy of the input messages. + + This lower-level method offers more fine-grained control over the steps that are run (e.g. to measure the + length of a single message). + + :param deepcopy: Whether to deep-copy each message before running the pipeline. + :param for_measurement: If the pipeline is being run to measure the length of a single message. In this case, + any ``ensure_start`` steps will be ignored. + """ + # let's use the lower-level model_copy() since we aren't changing anything + data = [m.model_copy(deep=deepcopy) for m in msgs] + + # and apply the pipeline + for step in self.steps: + # for measurement: ignore ensure_start + if for_measurement and isinstance(step, EnsureStart): + continue + + # apply step + data = step.execute(data) + + # return the result + return data + + # ==== utils ==== + def explain(self, example: list[ChatMessage] = None, *, all_cases=False, **kwargs): + """ + Print out a summary of the pipeline and an example conversation transformation based on the steps in the + pipeline. + + .. caution:: + This method will run the pipeline on an example constructed based on the steps in this pipeline. You may + encounter unexpected side effects if your pipeline uses :meth:`apply` with a function with side effects. + """ + hdg = f"Prompt Pipeline ({len(self.steps)} steps)" + print(f"{hdg}\n{'=' * len(hdg)}") + listwidth = len(str(len(self.steps))) + for idx, step in enumerate(self.steps): + print(f"{idx + 1:>{listwidth}}. {step.explain()}") + print() + + # example + print("Example\n-------") + if all_cases: + example_kwargs = {k: True for k in ALL_EXAMPLE_KWARGS} + else: + example_kwargs = functools.reduce(operator.or_, (s.explain_example_kwargs() for s in self.steps), kwargs) + + examples_msg_grps = build_conversation(**example_kwargs) + examples_msgs = list(itertools.chain.from_iterable(examples_msg_grps)) + + # run and time example + start = time.perf_counter() + example_out = self(examples_msgs) + end = time.perf_counter() + exec_time = end - start + + # print example i/o + print(f"*Execution time: {exec_time * 1000:.3}ms*") + print("### Input\n```py\n[", end="") + for grp in examples_msg_grps: + print() + for msg in grp: + print(f" {msg!r}") + print("]\n```\n") + + print("### Output") + if isinstance(example_out, str): + print("```text") + print(example_out) + else: + print("```py") + pprint.pprint(example_out) + print("```\n") + + # print note + unused_kwargs = [k for k in ALL_EXAMPLE_KWARGS if k not in example_kwargs] + if unused_kwargs and not example: + print( + "### Note\nSome edge cases are not represented in this example. To view all test cases, use" + f" `.explain(all_cases=True)` or set any of the following keyword arguments: {unused_kwargs}.\n" + "You may also specify your own test case with `.explain(example=...)`." + ) + + def __repr__(self): + steps_fmt = pprint.pformat(self.steps) + return f"{type(self).__name__}({steps_fmt})" + + def __len__(self): + return len(self.steps) + + def __getitem__(self, item): + return self.steps[item] diff --git a/kani/prompts/steps.py b/kani/prompts/steps.py new file mode 100644 index 0000000..4760214 --- /dev/null +++ b/kani/prompts/steps.py @@ -0,0 +1,415 @@ +import inspect +import itertools +import warnings +from typing import Any, Callable + +from kani.exceptions import PromptError +from kani.models import ChatMessage, ChatRole +from kani.prompts.base import FilterMixin, PipelineStep +from kani.prompts.types import ApplyCallableT, ApplyResultT, FunctionCallStrT, MessageContentT, PipelineMsgT + + +# ==== steps ==== +class TranslateRole(FilterMixin, PipelineStep): + def __init__(self, *, to: ChatRole, warn: str = None, **filter_kwargs): + super().__init__(**filter_kwargs) + self.to = to + self.warn = warn + + def execute(self, msgs: list[PipelineMsgT]) -> list[PipelineMsgT]: + has_warned = False + for msg in self.filtered(msgs): + if self.warn and not has_warned: + has_warned = True + warnings.warn(self.warn) + msg.role = self.to + return msgs + + def explain(self) -> str: + return f"Translate all {self.explain_note()} to {self.to.value} messages" + + +class Wrap(FilterMixin, PipelineStep): + def __init__(self, *, prefix: str = None, suffix: str = None, **filter_kwargs): + super().__init__(**filter_kwargs) + self.prefix = prefix + self.suffix = suffix + + def execute(self, msgs: list[PipelineMsgT]) -> list[PipelineMsgT]: + for msg in self.filtered(msgs): + _wrap_content_inplace(msg, self.prefix, self.suffix) + return msgs + + def explain(self) -> str: + if self.prefix and self.suffix: + return f"Wrap each {self.explain_note('or', plural=False)} with {self.prefix!r} and {self.suffix!r}" + elif self.prefix: + return f"Add {self.prefix!r} to the start of each {self.explain_note('or', plural=False)}" + elif self.suffix: + return f"Add {self.suffix!r} to the end of each {self.explain_note('or', plural=False)}" + return "This step doesn't do anything as configured - you should supply a prefix or suffix." + + +class MergeConsecutive(FilterMixin, PipelineStep): + def __init__( + self, + *, + sep: str = None, + joiner: Callable[[list[PipelineMsgT]], MessageContentT] = None, + out_role: ChatRole = None, + **filter_kwargs, + ): + super().__init__(**filter_kwargs) + if sep is not None and joiner is not None: + raise ValueError("Only one of (sep, joiner) may be set.") + if sep is None and joiner is None: + raise ValueError("You must set at least one of (sep, joiner).") + + self.sep = sep + self.joiner = joiner + self.out_role = out_role + + if self.out_role is None: + if isinstance(self.role, ChatRole): + self.out_role = self.role + else: + raise ValueError( + "You must supply an `out_role` if a `.merge_consecutive` pipeline step can match messages with" + " different roles." + ) + + def execute(self, msgs: list[PipelineMsgT]) -> list[PipelineMsgT]: + out = [] + + # group messages by whether they match the filter, putting consecutive matching messages into a list + for matching, group_msgs in itertools.groupby(msgs, key=self.matches_filter): + group_msgs = list(group_msgs) + # nonmatching messages get sent to output unchanged + if not matching: + out.extend(group_msgs) + # >1 consecutive matching messages get merged + else: + out.append(self._do_merge(group_msgs)) + + return out + + def explain(self) -> str: + if self.sep: + how = f"by inserting {self.sep!r} between them" + else: + how = "by calling the given function" + + if self.out_role: + return f"Merge consecutive {self.explain_note()} into a single {self.out_role.value} message {how}" + return f"Merge consecutive {self.explain_note()} into a single message {how}" + + def explain_example_kwargs(self) -> dict[str, bool]: + kwargs = super().explain_example_kwargs() + if self.matches_role(ChatRole.USER): + kwargs["consecutive_user"] = True + if self.matches_role(ChatRole.ASSISTANT): + kwargs["consecutive_assistant"] = True + if self.matches_role(ChatRole.SYSTEM): + kwargs["consecutive_system"] = True + if self.matches_role(ChatRole.FUNCTION): + kwargs["multi_function_call"] = True + return kwargs + + # helper + def _do_merge(self, msgs: list[PipelineMsgT]) -> PipelineMsgT: + """Helper to merge multiple messages' content, respecting if it has parts or not""" + if len(msgs) < 1: + raise ValueError("At least one message must be supplied to merge") + + # if we're doing sep and only have 1 message, just ignore it and return it unchanged + if self.sep and len(msgs) == 1: + return msgs[0] + + has_parts = any(isinstance(m.content, list) for m in msgs) + + # we'll use the first message as the returned one + out_msg = msgs[0] + out_msg.role = self.out_role + out_msg.tool_call_id = None # if we are merging function calls, something is wacky and we should strip this + if not all(m.tool_calls is None for m in msgs): + out_msg.tool_calls = list(itertools.chain.from_iterable(m.tool_calls or [] for m in msgs)) + + # SEP + if self.sep: + # all text: replace the first message and return it + if not has_parts: + out_msg.content = self.sep.join(m.content or "" for m in msgs) + return out_msg + + # otherwise, build up a list, inserting sep between each + out_msg.content = out_msg.parts + for msg in msgs[1:]: + out_msg.content.append(self.sep) + out_msg.content.extend(msg.parts) + # JOINER + else: + out_msg.content = self.joiner(msgs) + return out_msg + + +class FunctionCallFmt(PipelineStep): + def __init__(self, func: FunctionCallStrT, *, prefix: str = "\n", sep: str = "", suffix: str = ""): + self.func = func + self.prefix = prefix + self.sep = sep + self.suffix = suffix + + def execute(self, msgs: list[PipelineMsgT]) -> list[PipelineMsgT]: + # for each message with 1+ tool calls, + for msg in msgs: + if not msg.tool_calls: + continue + + # call the func on each tc + tc_fmt = (fmt for tc in msg.tool_calls if (fmt := self.func(tc)) is not None) + # join and append to the message + tc_content = self.sep.join(tc_fmt) + _wrap_content_inplace(msg, prefix=None, suffix=f"{self.prefix}{tc_content}{self.suffix}") + + return msgs + + def explain(self) -> str: + fmt_repr = f"{self.prefix}{{{self.sep!r}.join(f(tc) for tc in msg.tool_calls)}}{self.suffix}" + return f"Apply the given function to each tool call, and append {fmt_repr!r} to each message with tool calls" + + def explain_example_kwargs(self) -> dict[str, bool]: + kwargs = {"function_call": True} + if self.sep: + kwargs["multi_function_call"] = True + return kwargs + + +class Remove(FilterMixin, PipelineStep): + def __init__(self, **filter_kwargs): + super().__init__(**filter_kwargs) + + def execute(self, msgs: list[PipelineMsgT]) -> list[PipelineMsgT]: + return [m for m in msgs if not self.matches_filter(m)] + + def explain(self) -> str: + return f"Remove all {self.explain_note()}" + + +class EnsureStart(FilterMixin, PipelineStep): + def __init__(self, **filter_kwargs): + super().__init__(**filter_kwargs) + + def execute(self, msgs: list[PipelineMsgT]) -> list[PipelineMsgT]: + first_matching_idx = next((i for i, msg in enumerate(msgs) if self.matches_filter(msg)), len(msgs)) + return msgs[first_matching_idx:] + + def explain(self) -> str: + return f"Ensure that the prompt starts with a {self.explain_note('or', plural=False)}" + + +class EnsureBoundFunctionCalls(PipelineStep): + def execute(self, msgs: list[PipelineMsgT]) -> list[PipelineMsgT]: + bound = [] + free_toolcall_ids = set() + for m in msgs: + # if this is not a function result and there are free tool call IDs, raise + if m.role != ChatRole.FUNCTION and free_toolcall_ids: + raise PromptError( + f"Encountered a {m.role.value!r} message but expected a FUNCTION message to satisfy the pending" + f" tool call(s): {free_toolcall_ids}" + ) + # asst: add tool call IDs to freevars + if m.role == ChatRole.ASSISTANT and m.tool_calls: + for tc in m.tool_calls: + free_toolcall_ids.add(tc.id) + # func: bind freevars + elif m.role == ChatRole.FUNCTION: + # has ID: bind it if requested; yeet if not + if m.tool_call_id is not None: + if m.tool_call_id in free_toolcall_ids: + free_toolcall_ids.remove(m.tool_call_id) + else: + # this happens if the tool call is pushed out of context but the result is still here, + # and we have always included messages beforehand + continue # yeet this message + # no ID: bind if unambiguous + elif len(free_toolcall_ids) == 1: + m.tool_call_id = free_toolcall_ids.pop() + # no ID: error if ambiguous + elif len(free_toolcall_ids) > 1: + raise PromptError( + "Got a FUNCTION message with no tool_call_id but multiple tool calls are pending" + f" ({free_toolcall_ids})! Set the tool_call_id to resolve the pending tool requests." + ) + # otherwise pass the FUNCTION message through + bound.append(m) + return bound + + def explain(self) -> str: + return "Ensure that each function call is bound" + + def explain_example_kwargs(self) -> dict[str, bool]: + return {"function_call": True} + + +class Apply(FilterMixin, PipelineStep): + def __init__(self, func: ApplyCallableT, **filter_kwargs): + super().__init__(**filter_kwargs) + self.func = func + + # func introspection: generate a wrapper for the right number of args provided + sig = inspect.signature(func) + if len(sig.parameters) == 1: + self.func_wrapped = lambda msg, is_last, idx: self.func(msg) + elif len(sig.parameters) == 2: + self.func_wrapped = lambda msg, is_last, idx: self.func(msg, is_last) + elif len(sig.parameters) == 3: + self.func_wrapped = self.func + else: + raise ValueError( + "The applied function must have 1 to 3 positional parameters (msg, is_last, idx) (got a function with" + f" {len(sig.parameters)} parameters)." + ) + + def execute(self, msgs: list[PipelineMsgT]) -> list[ApplyResultT]: + out = [] + for i, msg in enumerate(msgs): + # for each matching message, append f(msg) if it's not None + if self.matches_filter(msg): + replacement = self.func_wrapped(msg, i == len(msgs) - 1, i) + if replacement is not None: + out.append(replacement) + # else just append the msg unchanged + else: + out.append(msg) + return out + + def explain(self) -> str: + return f"Apply the given function to each {self.explain_note('and', plural=False)}" + + +# ==== terminals ==== +class ConversationFmt(PipelineStep): + def __init__( + self, + *, + # general formatting + prefix: str = "", + sep: str = "", + suffix: str = "", + generation_suffix: str = "", + # message-specific formatting + # USER messages + user_prefix: str = "", + user_suffix: str = "", + # ASSISTANT messages + assistant_prefix: str = "", + assistant_suffix: str = "", + assistant_suffix_if_last: str = None, + # SYSTEM messages + system_prefix: str = "", + system_suffix: str = "", + # FUNCTION messages (if not specified, defaults to user) + function_prefix: str = None, + function_suffix: str = None, + ): + self.prefix = prefix + self.sep = sep + self.suffix = suffix + self.generation_suffix = generation_suffix + self.user_prefix = user_prefix + self.user_suffix = user_suffix + self.assistant_prefix = assistant_prefix + self.assistant_suffix = assistant_suffix + self.assistant_suffix_if_last = assistant_suffix_if_last + self.system_prefix = system_prefix + self.system_suffix = system_suffix + self.function_prefix = function_prefix if function_prefix is not None else user_prefix + self.function_suffix = function_suffix if function_suffix is not None else user_suffix + + def execute(self, msgs: list[PipelineMsgT]) -> str: + parts = [] + for idx, msg in enumerate(msgs): + if msg.role == ChatRole.USER: + parts.append(f"{self.user_prefix}{msg.text}{self.user_suffix}") + elif msg.role == ChatRole.ASSISTANT: + # if this is the last message and we want custom ends, use the last suffix instead + if idx == len(msgs) - 1 and self.assistant_suffix_if_last is not None: + assistant_suffix = self.assistant_suffix_if_last + else: + assistant_suffix = self.assistant_suffix + parts.append(f"{self.assistant_prefix}{msg.text}{assistant_suffix}") + elif msg.role == ChatRole.SYSTEM: + parts.append(f"{self.system_prefix}{msg.text}{self.system_suffix}") + else: # function + parts.append(f"{self.function_prefix}{msg.text}{self.function_suffix}") + + # generation suffix if we aren't ending on an assistant message with a special asst_suffix_if_last + if not (msgs and msgs[-1].role == ChatRole.ASSISTANT and self.assistant_suffix_if_last is not None): + parts.append(self.generation_suffix) + + # join + return self.prefix + self.sep.join(parts) + self.suffix + + def explain(self) -> str: + return "Format the messages into a single conversation-formatted string (see example)" + + def explain_example_kwargs(self) -> dict[str, bool]: + kwargs = {} + if self.function_prefix != self.user_prefix or self.function_suffix != self.user_suffix: + kwargs["function_call"] = True + return kwargs + + +class ConversationDict(PipelineStep): + def __init__( + self, + *, + system_role: str = "system", + user_role: str = "user", + assistant_role: str = "assistant", + function_role: str = "tool", + content_transform: Callable[[ChatMessage], Any] = lambda msg: msg.text, + additional_keys: Callable[[ChatMessage], dict] = lambda msg: {}, + ): + self.role_to_str = { + ChatRole.SYSTEM: system_role, + ChatRole.USER: user_role, + ChatRole.ASSISTANT: assistant_role, + ChatRole.FUNCTION: function_role, + } + self.content_transform = content_transform + self.additional_keys = additional_keys + + def execute(self, msgs: list[PipelineMsgT]) -> list[dict[str, Any]]: + out = [] + for msg in msgs: + msg_dict = { + "role": self.role_to_str[msg.role], + "content": self.content_transform(msg), + **self.additional_keys(msg), + } + out.append(msg_dict) + return out + + def explain(self) -> str: + return 'Return the messages as dictionaries with {"role": ..., "content": ...} keys (see example)' + + def explain_example_kwargs(self) -> dict[str, bool]: + kwargs = {} + if self.role_to_str[ChatRole.FUNCTION] != "function": + kwargs["function_call"] = True + return kwargs + + +# ==== helpers ==== +def _wrap_content_inplace(msg: PipelineMsgT, prefix: str | None, suffix: str | None): + """Helper to add a prefix/suffix to a message's contents, respecting if it has parts or not""" + if isinstance(msg.content, str) or msg.content is None: + msg.content = f"{prefix or ''}{msg.content or ''}{suffix or ''}" + # list + else: + if prefix: + msg.content.insert(0, prefix) + if suffix: + msg.content.append(suffix) diff --git a/kani/prompts/types.py b/kani/prompts/types.py new file mode 100644 index 0000000..d051e15 --- /dev/null +++ b/kani/prompts/types.py @@ -0,0 +1,28 @@ +from collections.abc import Collection +from typing import Callable, TypeVar + +from kani.models import ChatMessage, ChatRole, MessagePart, ToolCall + +PipelineMsgT = ChatMessage +"""The type of messages in the pipeline""" + +MessageContentT = str | list[MessagePart | str] | None +"""The type of ChatMessage.content""" + +RoleFilterT = ChatRole | Collection[ChatRole] +"""A role or list of roles to apply a step to""" + +PredicateFilterT = Callable[[PipelineMsgT], bool] +"""A callable that determines whether or not to apply a step to an input""" + +FunctionCallStrT = Callable[[ToolCall], str | None] +"""A callable to format a toolcall as a str""" + +ApplyResultT = TypeVar("ApplyResultT") + +ApplyCallableT = ( + Callable[[PipelineMsgT], ApplyResultT] + | Callable[[PipelineMsgT, bool], ApplyResultT] + | Callable[[PipelineMsgT, bool, int], ApplyResultT] +) +"""A function taking 1-3 args""" diff --git a/kani/streaming.py b/kani/streaming.py new file mode 100644 index 0000000..4aec54a --- /dev/null +++ b/kani/streaming.py @@ -0,0 +1,165 @@ +import asyncio +import contextlib +from collections.abc import AsyncIterable + +from kani.engines.base import BaseCompletion, Completion +from kani.models import ChatMessage, ChatRole + + +class StreamManager: + """ + This class is responsible for managing a stream returned by an engine. It should not be constructed manually. + + To consume tokens from a stream, use this class as so:: + + # CHAT ROUND: + stream = ai.chat_round_stream("What is the airspeed velocity of an unladen swallow?") + async for token in stream: + print(token, end="") + msg = await stream.message() + + # FULL ROUND: + async for stream in ai.full_round_stream("What is the airspeed velocity of an unladen swallow?") + async for token in stream: + print(token, end="") + msg = await stream.message() + + After a stream finishes, its contents will be available as a :class:`.ChatMessage`. You can retrieve the final + message or :class:`.BaseCompletion` with:: + + msg = await stream.message() + completion = await stream.completion() + + The final :class:`.ChatMessage` may contain non-yielded tokens (e.g. a request for a function call). If the final + message or completion is requested before the stream is iterated over, the stream manager will consume the entire + stream. + + .. tip:: + For compatibility and ease of refactoring, awaiting the stream itself will also return the message, i.e.: + + .. code-block:: python + + msg = await ai.chat_round_stream("What is the airspeed velocity of an unladen swallow?") + + (note the ``await`` that is not present in the above examples). + """ + + def __init__( + self, stream_iter: AsyncIterable[str | BaseCompletion], role: ChatRole, *, after=None, lock: asyncio.Lock = None + ): + """ + :param stream_iter: The async iterable that generates elements of the stream. + :param role: The role of the message that will be returned eventually. + :param after: A coro to call with the generated completion as its argument after the stream is fully consumed. + :param lock: A lock to hold for the duration of the stream run. + """ + self.role = role + """The role of the message that this stream will return.""" + + # private + self._stream_iter = stream_iter + self._after = after + self._lock = lock if lock is not None else contextlib.nullcontext() + + # results + self._completion = None # the final completion result + self._awaited = False # whether or not this stream has already been consumed + self._finished = asyncio.Event() # whether or not the stream has finished + + # ==== stream components ==== + async def _stream_impl_outer(self): + # simple wrapper to lock the lock too, but *async* + async with self._lock: + async for elem in self._stream_impl(): + yield elem + + async def _stream_impl(self): + """ + Wrap the underlying stream iterable to handle mixed yield types and build completion when the stream finishes + without setting the completion. + """ + yielded_tokens = [] + + # for each token or completion yielded by the engine, + async for elem in self._stream_iter: + if self._completion is not None: + raise RuntimeError( + "Expected `BaseCompletion` to be final yield of stream iterable but got another value after!" + ) + + # re-yield if str + if isinstance(elem, str): + yield elem + yielded_tokens.append(elem) + # save if completion + elif isinstance(elem, BaseCompletion): + self._completion = elem + # panic otherwise + else: + raise TypeError( + "Expected yielded value from stream iterable to be `str` or `BaseCompletion` but got" + f" {type(elem)!r}!" + ) + + # if the stream is complete but we did not get a completion, we'll construct one here as the concatenation + # of all the yielded tokens + if self._completion is None: + content = "".join(yielded_tokens) + self._completion = Completion(message=ChatMessage(role=self.role, content=content.strip())) + + # run the callback, if any + if self._after is not None: + await self._after(self._completion) + + # allow anything waiting on the stream to finish to progress + self._finished.set() + + def __aiter__(self) -> AsyncIterable[str]: + """Iterate over tokens yielded from the engine.""" + # enforce that it can only be iterated over once + if self._awaited: + raise RuntimeError( + "This stream has already been consumed. If you are consuming both the stream and the final Completion," + " make sure you iterate over the stream first." + ) + self._awaited = True + # delegate to a wrapper for an async context + return self._stream_impl_outer() + + # ==== final result getters ==== + def __await__(self): + """Awaiting the StreamManager is equivalent to awaiting :meth:`message`.""" + return self.message().__await__() + + async def completion(self) -> BaseCompletion: + """Get the final :class:`.BaseCompletion` generated by the model.""" + # if we are getting the completion but no one has consumed our stream yet, just dummy do it so we build + # the completion + if not self._awaited: + async for _ in self: + pass + + # otherwise, wait for the stream to be complete then return the saved completion + await self._finished.wait() + return self._completion + + async def message(self) -> ChatMessage: + """Get the final :class:`.ChatMessage` generated by the model.""" + completion = await self.completion() + return completion.message + + +class DummyStream(StreamManager): + """Function calling helper: we already have the message.""" + + def __init__(self, message: ChatMessage): + # init a dummy iterable + async def _iter(): + yield message.text + yield Completion(message) + + super().__init__(_iter(), role=message.role) + self._message = message + + async def message(self) -> ChatMessage: + return self._message diff --git a/kani/utils/cli.py b/kani/utils/cli.py index 40fcfef..189c3de 100644 --- a/kani/utils/cli.py +++ b/kani/utils/cli.py @@ -8,7 +8,8 @@ from kani.kani import Kani from kani.models import ChatRole -from kani.utils.message_formatters import assistant_message_contents_thinking +from kani.streaming import StreamManager +from kani.utils.message_formatters import assistant_message_contents_thinking, assistant_message_thinking async def chat_in_terminal_async( @@ -22,6 +23,7 @@ async def chat_in_terminal_async( show_function_args: bool = False, show_function_returns: bool = False, verbose: bool = False, + stream: bool = True, ): """Async version of :func:`.chat_in_terminal`. Use in environments when there is already an asyncio loop running (e.g. Google Colab). @@ -43,18 +45,34 @@ async def chat_in_terminal_async( print_width(query, width=width, prefix="USER: ") if stopword and query == stopword: break + # print completion(s) else: query = None # print completion(s) - async for msg in kani.full_round(query): - # assistant - if msg.role == ChatRole.ASSISTANT: - text = assistant_message_contents_thinking(msg, show_args=show_function_args) - print_width(text, width=width, prefix="AI: ") - # function - elif msg.role == ChatRole.FUNCTION and show_function_returns: - print_width(msg.text, width=width, prefix="FUNC: ") + if stream: + async for stream in kani.full_round_stream(query): + # assistant + if stream.role == ChatRole.ASSISTANT: + await print_stream(stream, width=width, prefix="AI: ") + msg = await stream.message() + text = assistant_message_thinking(msg, show_args=show_function_args) + if text: + print_width(text, width=width, prefix="AI: ") + # function + elif stream.role == ChatRole.FUNCTION and show_function_returns: + msg = await stream.message() + print_width(msg.text, width=width, prefix="FUNC: ") + # completions only + else: + async for msg in kani.full_round(query): + # assistant + if msg.role == ChatRole.ASSISTANT: + text = assistant_message_contents_thinking(msg, show_args=show_function_args) + print_width(text, width=width, prefix="AI: ") + # function + elif msg.role == ChatRole.FUNCTION and show_function_returns: + print_width(msg.text, width=width, prefix="FUNC: ") except KeyboardInterrupt: pass finally: @@ -73,6 +91,7 @@ def chat_in_terminal( show_function_args: bool = False, show_function_returns: bool = False, verbose: bool = False, + stream: bool = True, ): ... @@ -98,6 +117,7 @@ def chat_in_terminal(kani: Kani, **kwargs): (default false). :param bool show_function_returns: Whether to print the results of each function call (default false). :param bool verbose: Equivalent to setting ``echo``, ``show_function_args``, and ``show_function_returns`` to True. + :param bool stream: Whether or not to print tokens as soon as they are generated by the model (default true). """ try: asyncio.get_running_loop() @@ -141,3 +161,44 @@ def print_width(msg: str, width: int = None, prefix: str = ""): out.append(wrapper.fill(line)) wrapper.initial_indent = wrapper.subsequent_indent print("\n".join(out)) + + +async def print_stream(stream: StreamManager, width: int = None, prefix: str = ""): + """ + Print tokens from a stream to the terminal, with the width of each line less than *width*. + If *prefix* is provided, indents each line after the first by the length of the prefix. + + This is a helper function intended to be used with :meth:`.Kani.chat_round_stream` or + :meth:`.Kani.full_round_stream`. + """ + prefix_len = len(prefix) + line_indent = " " * prefix_len + prefix_printed = False + + # print tokens until they overflow width then newline and indent + line_len = prefix_len + async for token in stream: + # only print the prefix if the model actually yields anything + if not prefix_printed: + print(prefix, end="") + prefix_printed = True + + # split by newlines + for part in token.splitlines(keepends=True): + # then do bookkeeping + line_len += len(part) + if width and line_len > width: + print(f"\n{line_indent}", end="") + line_len = prefix_len + + # print the token + print(part.rstrip("\r\n"), end="", flush=True) + + # print a newline if the token had one + if part.endswith("\n"): + print(f"\n{line_indent}", end="") + line_len = prefix_len + + # newline at the end to flush if we printed anything + if prefix_printed: + print() diff --git a/kani/utils/message_formatters.py b/kani/utils/message_formatters.py index c5fde3b..02dedcb 100644 --- a/kani/utils/message_formatters.py +++ b/kani/utils/message_formatters.py @@ -36,16 +36,29 @@ def assistant_message_contents_thinking(msg: ChatMessage, show_args=False): if not msg.tool_calls: function_calls = "" else: - # with args: show a nice repr (e.g. `get_weather(location="San Francisco, CA", unit="fahrenheit")`) - if show_args: - parts = [] - for tc in msg.tool_calls: - args = ", ".join(f"{kwarg}={v!r}" for kwarg, v in tc.function.kwargs.items()) - parts.append(f"{tc.function.name}({args})") - called_functions = "; ".join(parts) - # no args: just print the function name - else: - called_functions = "; ".join(tc.function.name for tc in msg.tool_calls) - function_calls = f"Thinking... [{called_functions}]" - - return text + function_calls + function_calls = f"\n{assistant_message_thinking(msg, show_args)}" + + return (text + function_calls).strip() + + +def assistant_message_thinking(msg: ChatMessage, show_args=False): + """Return "Thinking..." on assistant messages with function calls, ignoring any content. + + This is useful if you are streaming the message's contents. + + If *show_args* is True, include the arguments to each function call. + """ + if msg.role != ChatRole.ASSISTANT or not msg.tool_calls: + return + + # with args: show a nice repr (e.g. `get_weather(location="San Francisco, CA", unit="fahrenheit")`) + if show_args: + parts = [] + for tc in msg.tool_calls: + args = ", ".join(f"{kwarg}={v!r}" for kwarg, v in tc.function.kwargs.items()) + parts.append(f"{tc.function.name}({args})") + called_functions = "; ".join(parts) + # no args: just print the function name + else: + called_functions = "; ".join(tc.function.name for tc in msg.tool_calls) + return f"Thinking... [{called_functions}]" diff --git a/pyproject.toml b/pyproject.toml index 626510e..65bb64d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "kani" -version = "0.8.0" +dynamic = ["version"] authors = [ { name = "Andrew Zhu", email = "andrew@zhu.codes" }, { name = "Liam Dugan", email = "ldugan@seas.upenn.edu" }, @@ -24,35 +24,43 @@ classifiers = [ "Topic :: Scientific/Engineering :: Artificial Intelligence", ] dependencies = [ - "aiohttp>=3.0.0,<4.0.0", "pydantic>=2.0.0,<3.0.0", ] [project.optional-dependencies] +all = [ + "kani[huggingface,llama,cpp,openai,anthropic]", +] + huggingface = [ - "transformers>=4.0.0,<5.0.0", + "transformers>=4.28.0,<5.0.0", ] llama = [ "sentencepiece~=0.1.99", ] -openai = [ - "tiktoken>=0.4.0,<1.0.0", +cpp = [ + "llama-cpp-python>=0.2.0,<1.0.0", + "huggingface-hub", ] -ctransformers = [ - "ctransformers>=0.2.25,<1.0.0", +openai = [ + "openai>=1.0.0,<2.0.0", + "tiktoken>=0.4.0,<1.0.0", ] anthropic = [ - "anthropic>=0.7.3,<1.0.0", + "anthropic>=0.23.0,<1.0.0", ] [project.urls] "Homepage" = "https://github.com/zhudotexe/kani" "Bug Tracker" = "https://github.com/zhudotexe/kani/issues" +[tool.hatch] +version.path = "kani/_version.py" + [tool.black] line-length = 120 preview = true diff --git a/requirements.txt b/requirements.txt index dc528ed..ba778d2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,23 +1,15 @@ -aiohttp>=3.0.0,<4.0.0 -pydantic>=2.0.0,<3.0.0 - # in development, we want to install all the extra deps so the docs generate correctly -# --- huggingface --- +-e .[all] + +# manual install +aiohttp torch -transformers>=4.0.0,<5.0.0 -# --- llama --- -sentencepiece~=0.1.99 -# --- openai --- -tiktoken>=0.4.0,<1.0.0 -# --- ctransformers --- -ctransformers>=0.2.25,<1.0.0 -# --- anthropic --- -anthropic>=0.7.3,<1.0.0 # dev black build isort +hatch hypothesis pytest pytest-asyncio diff --git a/tests/engine.py b/tests/engine.py index c297ec8..99ec4bb 100644 --- a/tests/engine.py +++ b/tests/engine.py @@ -1,5 +1,7 @@ -from kani import ChatMessage -from kani.engines.base import BaseEngine, Completion +from typing import AsyncIterable + +from kani import AIFunction, ChatMessage +from kani.engines.base import BaseCompletion, BaseEngine, Completion class TestEngine(BaseEngine): @@ -22,3 +24,18 @@ async def predict(self, messages, functions=None, test_echo=False, **hyperparams content = messages[-1].text return Completion(ChatMessage.assistant(content)) return Completion(ChatMessage.assistant("a")) + + +class TestStreamingEngine(TestEngine): + """A mock engine used for testing but with streaming (yields one character at a time).""" + + async def stream( + self, messages: list[ChatMessage], functions: list[AIFunction] | None = None, test_echo=False, **hyperparams + ) -> AsyncIterable[str | BaseCompletion]: + assert sum(len(m.text or "") for m in messages) <= self.max_context_size + if test_echo: + content = messages[-1].text + for char in content: + yield char + else: + yield "a" diff --git a/tests/test_chatmessage.py b/tests/test_chatmessage.py index 76682e9..23a4017 100644 --- a/tests/test_chatmessage.py +++ b/tests/test_chatmessage.py @@ -1,10 +1,9 @@ -import pydantic import pytest from kani import ChatMessage, ChatRole, MessagePart, ToolCall -class TestMessagePart(MessagePart): +class _TestMessagePart(MessagePart): def __str__(self): return "" @@ -19,11 +18,11 @@ def test_basic(): assert msg.function_call is None -def test_immutable(): +# v1.0: no more immutability >:c +def test_mutable(): msg = ChatMessage(role=ChatRole.USER, content="Hello world") - with pytest.raises(pydantic.ValidationError): - msg.content = "not allowed" - assert msg.content == "Hello world" + msg.content = "not allowed" + assert msg.content == "not allowed" def test_none_content(): @@ -34,27 +33,27 @@ def test_none_content(): def test_parts(): - part = TestMessagePart() + part = _TestMessagePart() msg = ChatMessage(role=ChatRole.USER, content=["Hello world", part]) - assert msg.content == ("Hello world", part) + assert msg.content == ["Hello world", part] assert msg.text == "Hello world" assert msg.parts == ["Hello world", part] def test_copy_parts(): - part = TestMessagePart() + part = _TestMessagePart() msg = ChatMessage(role=ChatRole.USER, content=["Hello world", part]) text_copy = msg.copy_with(text="asdf") assert text_copy.content == "asdf" part_copy = msg.copy_with(parts=["foo"]) - assert part_copy.content == ("foo",) + assert part_copy.content == ["foo"] content_copy = msg.copy_with(content="zxcv") assert content_copy.content == "zxcv" - assert msg.content == ("Hello world", part) + assert msg.content == ["Hello world", part] with pytest.raises(ValueError): msg.copy_with(text="foo", parts=[]) @@ -82,4 +81,4 @@ def test_support_multiple_tool_calls(): msg = ChatMessage(role="function", content=None, tool_calls=[tool_call1, tool_call2]) - assert msg.tool_calls == (tool_call1, tool_call2) + assert msg.tool_calls == [tool_call1, tool_call2] diff --git a/tests/test_llama.py b/tests/test_llama.py index b3d8ee1..cdd82d0 100644 --- a/tests/test_llama.py +++ b/tests/test_llama.py @@ -5,14 +5,14 @@ import pytest from kani import ChatMessage, Kani -from kani.engines.ctransformers.llama2 import LlamaCTransformersEngine +from kani.engines.llamacpp import LlamaCppEngine pytestmark = pytest.mark.llama @pytest.fixture(scope="module") def llama(): - return LlamaCTransformersEngine("TheBloke/Llama-2-7B-Chat-GGML", model_file="llama-2-7b-chat.ggmlv3.q4_K_S.bin") + return LlamaCppEngine(repo_id="TheBloke/Llama-2-7B-Chat-GGUF", filename="*.Q4_K_M.gguf") @pytest.fixture() diff --git a/tests/test_llama2_prompt.py b/tests/test_llama2_prompt.py index 3a1b768..2fda0b9 100644 --- a/tests/test_llama2_prompt.py +++ b/tests/test_llama2_prompt.py @@ -1,18 +1,11 @@ """Tests to ensure the LLaMA v2 prompt is correct.""" from kani import ChatMessage -from kani.engines.llama2_prompt import build_str +from kani.prompts.impl import LLAMA2_PIPELINE def prompt_str(messages: list[ChatMessage]) -> str: - out = "" - for content, bos, eos in build_str(messages): - if bos: - out += "" - out += content - if eos: - out += "" - return out + return LLAMA2_PIPELINE(messages) def test_basic(): @@ -20,7 +13,7 @@ def test_basic(): ChatMessage.user("Hello there."), ChatMessage.assistant("General Kenobi."), ] - expected = "[INST] Hello there. [/INST] General Kenobi. " + expected = "[INST] Hello there. [/INST] General Kenobi." assert prompt_str(messages) == expected messages = [ @@ -28,7 +21,7 @@ def test_basic(): ChatMessage.user("Hello there."), ChatMessage.assistant("General Kenobi."), ] - expected = "[INST] <>\nI am a system message.\n<>\n\nHello there. [/INST] General Kenobi. " + expected = "[INST] <>\nI am a system message.\n<>\n\nHello there. [/INST] General Kenobi." assert prompt_str(messages) == expected @@ -42,7 +35,7 @@ def test_2round(): ] expected = ( "[INST] <>\nI am a system message.\n<>\n\nHello there. [/INST] General Kenobi. " - "[INST] I am: [/INST] a potato. " + "[INST] I am: [/INST] a potato." ) assert prompt_str(messages) == expected @@ -56,7 +49,7 @@ def test_2system(): ] expected = ( "[INST] <>\nI am a system message.\n<>\n\n" - "<>\nBut wait, there's more.\n<>\n\nHello there. [/INST] General Kenobi. " + "<>\nBut wait, there's more.\n<>\n\nHello there. [/INST] General Kenobi." ) assert prompt_str(messages) == expected @@ -70,7 +63,7 @@ def test_2user(): ] expected = ( "[INST] <>\nI am a system message.\n<>\n\n" - "Hello there.I am another message. [/INST] General Kenobi. " + "Hello there.\nI am another message. [/INST] General Kenobi." ) assert prompt_str(messages) == expected @@ -84,6 +77,6 @@ def test_2asst(): ] expected = ( "[INST] <>\nI am a system message.\n<>\n\n" - "Hello there. [/INST] General Kenobi.I am another message. " + "Hello there. [/INST] General Kenobi. I am another message." ) assert prompt_str(messages) == expected diff --git a/tests/test_mock_openai.py b/tests/test_mock_openai.py deleted file mode 100644 index 4b10760..0000000 --- a/tests/test_mock_openai.py +++ /dev/null @@ -1,66 +0,0 @@ -"""Ensure that any messages sent to OpenAI are valid (mock the API and just echo).""" - -import time - -from hypothesis import HealthCheck, given, settings, strategies as st -from pydantic import RootModel - -from kani import ChatMessage, Kani -from kani.engines.openai import OpenAIClient, OpenAIEngine -from kani.engines.openai.models import OpenAIChatMessage - - -class MockOpenAIClient(OpenAIClient): - async def request(self, method: str, route: str, headers=None, retry=None, **kwargs): - if route != "/chat/completions": - raise ValueError("only chat completions is mocked in tests") - - # validate that all the messages come across correctly - data = kwargs["json"] - RootModel[list[OpenAIChatMessage]].model_validate(data["messages"]) - - async def post(self, route: str, **kwargs): - if route != "/chat/completions": - raise ValueError("only chat completions is mocked in tests") - - await self.request("POST", route, **kwargs) - data = kwargs["json"] - message = data["messages"][-1] if data["messages"] else {"role": "assistant", "content": None} - return dict( - id="some-id", - object="chat.completion", - created=int(time.time()), - model=data["model"], - usage=dict(prompt_tokens=0, completion_tokens=0, total_tokens=0), - choices=[dict(message=message, index=0)], - ) - - -class MockOpenAIEngine(OpenAIEngine): - @staticmethod - def translate_messages(messages, cls=OpenAIChatMessage): - # we don't care about the tool call bindings here - just the translation - return [cls.from_chatmessage(m) for m in messages] - - -client = MockOpenAIClient("sk-fake-api-key") -engine = MockOpenAIEngine(client=client) - - -# hypothesis synchronously constructs a coro to call MockOpenAIClient.create_chat_completion -# based on the type annotations of the async function -# we then await the returned coro in the async test body -@settings(suppress_health_check=(HealthCheck.too_slow,), deadline=None) -@given(st.builds(client.create_chat_completion)) -async def test_chat_completions_valid(coro): - await coro - - -def build_kani_state(msgs: list[ChatMessage]): - return Kani(engine, chat_history=msgs) - - -@settings(suppress_health_check=(HealthCheck.too_slow,), deadline=None) -@given(st.builds(build_kani_state)) -async def test_kani_chatmessages_valid(ai): - await ai.get_model_completion() diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py new file mode 100644 index 0000000..01a26bf --- /dev/null +++ b/tests/test_pipeline.py @@ -0,0 +1,218 @@ +import itertools + +import pytest +from hypothesis import given, strategies as st + +from kani import ChatMessage, ChatRole, FunctionCall, PromptPipeline, ToolCall +from kani.prompts.base import FilterMixin, PipelineStep +from kani.prompts.examples import ALL_EXAMPLE_KWARGS, build_conversation +from kani.prompts.types import PipelineMsgT + + +# ==== bases ==== +class DummyFilterStep(FilterMixin, PipelineStep): + def execute(self, msgs: list[PipelineMsgT]): + return list(self.filtered(msgs)) + + +EXAMPLE_MSGS = list(itertools.chain.from_iterable(build_conversation(**{k: True for k in ALL_EXAMPLE_KWARGS}))) + + +def test_filter_none(): + # no filter does nothing + pipe = PromptPipeline([DummyFilterStep()]) + assert pipe(EXAMPLE_MSGS) == EXAMPLE_MSGS + + +@given(st.sampled_from(ChatRole)) +def test_filter_role(role): + # role filter + pipe = PromptPipeline([DummyFilterStep(role=role)]) + assert all(m.role == role for m in pipe(EXAMPLE_MSGS)) + + +@given(st.lists(st.sampled_from(ChatRole), min_size=1)) +def test_filter_roles(roles): + pipe = PromptPipeline([DummyFilterStep(role=roles)]) + assert all(m.role in roles for m in pipe(EXAMPLE_MSGS)) + + +def test_filter_predicate(): + # only messages with a tool call + pipe = PromptPipeline([DummyFilterStep(predicate=lambda msg: msg.tool_calls and len(msg.tool_calls) > 0)]) + assert all(m.tool_calls for m in pipe(EXAMPLE_MSGS)) + + # only messages whose text is longer than 10 characters + pipe = PromptPipeline([DummyFilterStep(predicate=lambda msg: msg.text and len(msg.text) > 10)]) + assert all(len(m.text) > 10 for m in pipe(EXAMPLE_MSGS) if m.text) + + +def test_filter_predicate_and_role(): + # only USER messages with a tool call (should be none) + pipe = PromptPipeline( + [DummyFilterStep(role=ChatRole.USER, predicate=lambda msg: msg.tool_calls and len(msg.tool_calls) > 0)] + ) + assert len(pipe(EXAMPLE_MSGS)) == 0 + + +# ==== steps ==== +def test_translate_role(): + # should translate the role but nothing else + pipe = PromptPipeline().translate_role(to=ChatRole.USER) + assert len(pipe(EXAMPLE_MSGS)) == len(EXAMPLE_MSGS) + assert all(m.role == ChatRole.USER for m in pipe(EXAMPLE_MSGS)) + assert all(m.content == original.content for m, original in zip(pipe(EXAMPLE_MSGS), EXAMPLE_MSGS)) + + +def test_wrap(): + # prefix + pipe = PromptPipeline().wrap(prefix="prefix!") + assert len(pipe(EXAMPLE_MSGS)) == len(EXAMPLE_MSGS) + assert all(m.text == f"prefix!{original.text or ''}" for m, original in zip(pipe(EXAMPLE_MSGS), EXAMPLE_MSGS)) + + # suffix + pipe = PromptPipeline().wrap(suffix="!suffix") + assert len(pipe(EXAMPLE_MSGS)) == len(EXAMPLE_MSGS) + assert all(m.text == f"{original.text or ''}!suffix" for m, original in zip(pipe(EXAMPLE_MSGS), EXAMPLE_MSGS)) + + # both + pipe = PromptPipeline().wrap(prefix="prefix!", suffix="!suffix") + assert len(pipe(EXAMPLE_MSGS)) == len(EXAMPLE_MSGS) + assert all( + m.text == f"prefix!{original.text or ''}!suffix" for m, original in zip(pipe(EXAMPLE_MSGS), EXAMPLE_MSGS) + ) + + # neither + pipe = PromptPipeline().wrap() + assert len(pipe(EXAMPLE_MSGS)) == len(EXAMPLE_MSGS) + assert all(m.text == (original.text or "") for m, original in zip(pipe(EXAMPLE_MSGS), EXAMPLE_MSGS)) + + +@given(st.sampled_from(ChatRole)) +def test_merge_consecutive(merge_role): + msgs = [ + ChatMessage.system("sys1"), + ChatMessage.system("sys2"), + ChatMessage.user("user1"), + ChatMessage.user("user2"), + ChatMessage.assistant("asst1"), + ChatMessage.assistant("asst2"), + ChatMessage.function("func", "func1"), + ChatMessage.function("func", "func2"), + ] + + # one role + pipe = PromptPipeline().merge_consecutive(role=merge_role, sep="\n") + assert len(pipe(msgs)) == len(msgs) - 1 + assert sum(1 for msg in pipe(msgs) if msg.role == merge_role) == 1 + assert sum(1 for msg in pipe(msgs) if msg.role != merge_role) == 6 + + # all into one + pipe = PromptPipeline().merge_consecutive( + role=[ChatRole.SYSTEM, ChatRole.USER, ChatRole.ASSISTANT, ChatRole.FUNCTION], sep="\n", out_role=ChatRole.USER + ) + assert len(pipe(msgs)) == 1 + assert pipe(msgs)[0].role == ChatRole.USER + assert pipe(msgs)[0].text == "sys1\nsys2\nuser1\nuser2\nasst1\nasst2\nfunc1\nfunc2" + + with pytest.raises(ValueError): + # missing out_role + PromptPipeline().merge_consecutive( + role=[ChatRole.SYSTEM, ChatRole.USER, ChatRole.ASSISTANT, ChatRole.FUNCTION], sep="\n" + ) + + +def test_function_call_fmt(): + pipe = PromptPipeline().function_call_fmt( + lambda tc: "I'm a function call!", prefix="prefix!", suffix="!suffix", sep="\n===\n" + ) + + # one function call + msgs = [ + ChatMessage.assistant( + content=None, function_call=FunctionCall.with_args("get_weather", location="Tokyo, JP", unit="celsius") + ) + ] + assert pipe(msgs)[0].content == "prefix!I'm a function call!!suffix" + + # one with content + msgs = [ + ChatMessage.assistant( + content="content", function_call=FunctionCall.with_args("get_weather", location="Tokyo, JP", unit="celsius") + ) + ] + assert pipe(msgs)[0].content == "contentprefix!I'm a function call!!suffix" + + # many function calls + msgs = [ + ChatMessage.assistant( + content=None, + tool_calls=[ + ToolCall.from_function("make_food", food="coffee", appliance="teapot"), + ToolCall.from_function("make_food", food="cereal", dish="bowl"), + ], + ), + ] + assert pipe(msgs)[0].content == "prefix!I'm a function call!\n===\nI'm a function call!!suffix" + + # no function call + msgs = [ChatMessage.assistant(content="Just passing by")] + assert pipe(msgs) == msgs + + +def test_remove(): + # with no filters it removes everything + # and we know the filters work, so... + pipe = PromptPipeline().remove() + assert len(pipe(EXAMPLE_MSGS)) == 0 + + # if it doesn't match anything it shouldn't do anything + pipe = PromptPipeline().remove(predicate=lambda _: False) + assert len(pipe(EXAMPLE_MSGS)) == len(EXAMPLE_MSGS) + assert pipe(EXAMPLE_MSGS) == EXAMPLE_MSGS + + +@given(st.sampled_from(ChatRole)) +def test_ensure_start(role): + pipe = PromptPipeline().ensure_start(role=role) + assert len(pipe(EXAMPLE_MSGS)) + assert pipe(EXAMPLE_MSGS)[0].role == role + + if EXAMPLE_MSGS[0].role == role: + assert len(pipe(EXAMPLE_MSGS)) == len(EXAMPLE_MSGS) + + +def test_apply(): + # noop apply + pipe = PromptPipeline().apply(lambda m: m) + assert len(pipe(EXAMPLE_MSGS)) == len(EXAMPLE_MSGS) + assert pipe(EXAMPLE_MSGS) == EXAMPLE_MSGS + + # remove everything + pipe = PromptPipeline().apply(lambda m: None) + assert len(pipe(EXAMPLE_MSGS)) == 0 + + # 2 args: translate last to SYSTEM + def translate_last(msg, is_last): + if is_last: + msg.role = ChatRole.SYSTEM + return msg + + pipe = PromptPipeline().apply(translate_last) + assert len(pipe(EXAMPLE_MSGS)) == len(EXAMPLE_MSGS) + assert pipe(EXAMPLE_MSGS)[-1].role == ChatRole.SYSTEM + + # 3 args: translate even msgs to SYSTEM + def translate_even(msg, _, idx): + if idx % 2 == 0: + msg.role = ChatRole.SYSTEM + return msg + + pipe = PromptPipeline().apply(translate_even) + msgs = pipe(EXAMPLE_MSGS) + assert len(msgs) == len(EXAMPLE_MSGS) + for idx, (msg, original) in enumerate(zip(msgs, EXAMPLE_MSGS)): + if idx % 2 == 0: + assert msg.role == ChatRole.SYSTEM + else: + assert msg == original diff --git a/tests/test_saveload.py b/tests/test_saveload.py index 57897a4..6560bb6 100644 --- a/tests/test_saveload.py +++ b/tests/test_saveload.py @@ -61,24 +61,24 @@ async def test_saveload_tool_calls(tmp_path): assert ai.chat_history == loaded.chat_history -class TestMessagePart1(MessagePart): +class _TestMessagePart1(MessagePart): data: str -class TestMessagePart2(MessagePart): +class _TestMessagePart2(MessagePart): data: str async def test_saveload_messageparts(tmp_path): """Test that message parts are serialized and deserialized into the right classes.""" - apart1 = TestMessagePart1(data="apart1") - apart2 = TestMessagePart2(data="apart2") - hpart1 = TestMessagePart1(data="hpart1") - hpart2 = TestMessagePart2(data="hpart2") + apart1 = _TestMessagePart1(data="apart1") + apart2 = _TestMessagePart2(data="apart2") + hpart1 = _TestMessagePart1(data="hpart1") + hpart2 = _TestMessagePart2(data="hpart2") # ensure that different instances with the same data are the same - assert apart1 == TestMessagePart1(data="apart1") + assert apart1 == _TestMessagePart1(data="apart1") # ensure that different classes/data are not - assert apart1 != TestMessagePart2(data="apart1") + assert apart1 != _TestMessagePart2(data="apart1") assert apart1 != hpart1 # init kani state diff --git a/tests/test_streaming.py b/tests/test_streaming.py new file mode 100644 index 0000000..4bd61e8 --- /dev/null +++ b/tests/test_streaming.py @@ -0,0 +1,60 @@ +import pytest +from hypothesis import HealthCheck, given, settings, strategies as st + +from kani import ChatMessage, ChatRole, Kani +from tests.engine import TestEngine, TestStreamingEngine +from tests.utils import flatten_chatmessages + +engine = TestEngine() +streaming_engine = TestStreamingEngine() + + +@pytest.mark.parametrize("eng", [engine, streaming_engine]) +async def test_chat_round_stream_consume_all(eng): + ai = Kani(eng, desired_response_tokens=3) + # 5 tokens, no omitting + resp = await ai.chat_round_stream("12345") + assert resp.content == "a" + assert resp.role == ChatRole.ASSISTANT + assert len(ai.chat_history) == 2 + prompt = await ai.get_prompt() + assert len(prompt) == 2 # 12345a + assert flatten_chatmessages(prompt) == "12345a" + + +async def test_chat_round_stream(): + ai = Kani(streaming_engine, desired_response_tokens=3) + stream = ai.chat_round_stream("12345") + async for token in stream: + assert token == "a" + resp = await stream.message() + assert resp.content == "a" + + ai = Kani(streaming_engine, desired_response_tokens=3) + stream = ai.chat_round_stream("aaa", test_echo=True) + async for token in stream: + assert token == "a" + resp = await stream.message() + assert resp.content == "aaa" + + +@settings(suppress_health_check=(HealthCheck.too_slow,), deadline=None) +@given(st.data()) +@pytest.mark.parametrize("eng", [engine, streaming_engine]) +async def test_spam_stream(eng, data): + # spam the kani with a bunch of random prompts + # and make sure it never breaks + ai = Kani( + eng, + desired_response_tokens=3, + system_prompt=data.draw(st.text(min_size=0, max_size=1)), + always_included_messages=[ChatMessage.user(data.draw(st.text(min_size=0, max_size=1)))], + ) + queries = data.draw(st.lists(st.text(min_size=0, max_size=5))) + for query in queries: + query = query.strip() + resp = await ai.chat_round_stream(query, test_echo=True) + assert resp.content == query + + prompt = await ai.get_prompt() + assert sum(ai.message_token_len(m) for m in prompt) <= ai.max_context_size