
Commit

+ kg-chat
pchalasani committed Jan 25, 2024
1 parent c0cf2f9 commit 85779ce
Showing 17 changed files with 1,328 additions and 460 deletions.
71 changes: 71 additions & 0 deletions examples/basic/chat-local-numerical.py
@@ -0,0 +1,71 @@
"""
Test multi-round interaction with a local LLM, playing a simple "doubling game".
In each round:
- User gives a number
- LLM responds with the double of that number
Run like this --
python3 examples/basic/chat-local-numerical.py -m <model_name_with_formatter_after//>
Recommended local model setup:
- spin up an LLM with oobabooga at an endpoint like http://127.0.0.1:5000/v1
- run this script with -m local/127.0.0.1:5000/v1
- To ensure accurate chat formatting (and not use the defaults from ooba),
append the appropriate HuggingFace model name to the
-m arg, separated by //, e.g. -m local/127.0.0.1:5000/v1//mistral-instruct-v0.2
(no need to include the full model name, as long as you include enough to
uniquely identify the model's chat formatting template)
"""
import os
import fire

import langroid as lr
from langroid.utils.configuration import settings
import langroid.language_models as lm

# for best results:
DEFAULT_LLM = lm.OpenAIChatModel.GPT4_TURBO

os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Set up a number-doubling ChatAgent and run it in an interactive Task.


def app(
m: str = DEFAULT_LLM, # model name
d: bool = False, # debug
nc: bool = False, # no cache
):
settings.debug = d
settings.cache = not nc
# create LLM config
llm_cfg = lm.OpenAIGPTConfig(
chat_model=m or DEFAULT_LLM,
chat_context_length=4096, # set this based on model
max_output_tokens=100,
temperature=0.2,
stream=True,
timeout=45,
)

agent = lr.ChatAgent(
lr.ChatAgentConfig(
llm=llm_cfg,
system_message="""
You are a number-doubling expert. When user gives you a NUMBER,
simply respond with its DOUBLE and SAY NOTHING ELSE.
DO NOT EXPLAIN YOUR ANSWER OR YOUR THOUGHT PROCESS.
""",
)
)

task = lr.Task(agent)
task.run("15") # initial number


if __name__ == "__main__":
fire.Fire(app)
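The script above targets an oobabooga-style OpenAI-compatible endpoint, but nothing in it is tied to that backend. Below is a minimal sketch (not part of this commit) of the same doubling-game agent pointed at an Ollama-served model; the chat_model tag is the one used in chat-local.py further down, and everything else mirrors the script above.

import langroid as lr
import langroid.language_models as lm

# Same doubling-game agent, with the LLM routed via LiteLLM to a local Ollama server.
llm_cfg = lm.OpenAIGPTConfig(
    chat_model="litellm/ollama/mistral:7b-instruct-v0.2-q4_K_M",
    chat_context_length=4096,  # set based on the model
    max_output_tokens=100,
    temperature=0.2,
)
agent = lr.ChatAgent(
    lr.ChatAgentConfig(
        llm=llm_cfg,
        system_message="When the user gives you a NUMBER, respond with its DOUBLE and say nothing else.",
    )
)
lr.Task(agent).run("15")  # initial number, as in the script above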
2 changes: 1 addition & 1 deletion examples/basic/chat-local.py
@@ -23,7 +23,7 @@
# or you can explicitly specify it as `lm.OpenAIChatModel.GPT4` or `lm.OpenAIChatModel.GPT4_TURBO`

llm_config = lm.OpenAIGPTConfig(
chat_model="litellm/ollama/mistral",
chat_model="litellm/ollama/mistral:7b-instruct-v0.2-q4_K_M",
max_output_tokens=200,
chat_context_length=2048, # adjust based on your local LLM params
)
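This change pins the Ollama model to a specific quantized tag rather than the bare "mistral" name. The same chat_model string can instead point at an OpenAI-compatible local server, as described in chat-local-numerical.py above; a hypothetical sketch, with the endpoint and //-formatter suffix taken from that script's docstring rather than from this commit:

import langroid.language_models as lm

# Hypothetical alternative: an OpenAI-compatible local server (e.g. oobabooga) at port 5000,
# with a HuggingFace model name appended after // so the right chat-formatting template is used.
llm_config = lm.OpenAIGPTConfig(
    chat_model="local/127.0.0.1:5000/v1//mistral-instruct-v0.2",
    max_output_tokens=200,
    chat_context_length=2048,  # adjust based on your local LLM params
)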
109 changes: 35 additions & 74 deletions examples/basic/chat-search.py
@@ -9,9 +9,14 @@
python3 examples/basic/chat-search.py
You can specify which search provider to use with this optional flag:
There are optional args; note these in particular:
-p or --provider: google or sciphi (default: google)
-m <model_name>: to run with a different LLM model (default: gpt4-turbo)
See the comments at the top of this script for more on how to specify local LLMs:
https://github.com/langroid/langroid/blob/main/examples/docqa/rag-local-simple.py
NOTE:
@@ -29,53 +34,38 @@

import typer
from dotenv import load_dotenv
from pydantic import BaseSettings
from rich import print
from rich.prompt import Prompt

import langroid.language_models as lm
from langroid.agent.chat_agent import ChatAgent, ChatAgentConfig
from langroid.agent.task import Task
from langroid.agent.tools.google_search_tool import GoogleSearchTool
from langroid.agent.tools.sciphi_search_rag_tool import SciPhiSearchRAGTool
from langroid.language_models.openai_gpt import OpenAIGPTConfig
from langroid.utils.configuration import Settings, set_global
from langroid.utils.logging import setup_colored_logging

app = typer.Typer()

setup_colored_logging()

# create classes for other model configs
LiteLLMOllamaConfig = OpenAIGPTConfig.create(prefix="ollama")
litellm_ollama_config = LiteLLMOllamaConfig(
chat_model="ollama/llama2",
completion_model="ollama/llama2",
api_base="http://localhost:11434",
litellm=True,
chat_context_length=4096,
use_completion_for_chat=False,
)
OobaConfig = OpenAIGPTConfig.create(prefix="ooba")
ooba_config = OobaConfig(
chat_model="local", # doesn't matter
completion_model="local", # doesn't matter
api_base="http://localhost:8000/v1", # <- edit if running at a different port
chat_context_length=2048,
litellm=False,
use_completion_for_chat=False,
)


class CLIOptions(BaseSettings):
model: str = ""
provider: str = "google"

class Config:
extra = "forbid"
env_prefix = ""


def chat(opts: CLIOptions) -> None:
@app.command()
def main(
debug: bool = typer.Option(False, "--debug", "-d", help="debug mode"),
model: str = typer.Option("", "--model", "-m", help="model name"),
provider: str = typer.Option(
"google", "--provider", "-p", help="search provider name (Google, SciPhi)"
),
no_stream: bool = typer.Option(False, "--nostream", "-ns", help="no streaming"),
nocache: bool = typer.Option(False, "--nocache", "-nc", help="don't use cache"),
cache_type: str = typer.Option(
"redis", "--cachetype", "-ct", help="redis or momento"
),
) -> None:
set_global(
Settings(
debug=debug,
cache=not nocache,
stream=not no_stream,
cache_type=cache_type,
)
)
print(
"""
[blue]Welcome to the Web Search chatbot!
Expand All @@ -92,14 +82,9 @@ def chat(opts: CLIOptions) -> None:

load_dotenv()

# use the appropriate config instance depending on model name
if opts.model == "ooba":
llm_config = ooba_config
elif opts.model.startswith("ollama"):
llm_config = litellm_ollama_config
llm_config.chat_model = opts.model
else:
llm_config = OpenAIGPTConfig()
llm_config = lm.OpenAIGPTConfig(
chat_model=model or lm.OpenAIChatModel.GPT4_TURBO,
)

config = ChatAgentConfig(
system_message=sys_msg,
Expand All @@ -108,13 +93,14 @@ def chat(opts: CLIOptions) -> None:
)
agent = ChatAgent(config)

match opts.provider:
match provider:
case "google":
search_tool_class = GoogleSearchTool
case "sciphi":
from langroid.agent.tools.sciphi_search_rag_tool import SciPhiSearchRAGTool
search_tool_class = SciPhiSearchRAGTool
case _:
raise ValueError(f"Unsupported provider {opts.provider} specified.")
raise ValueError(f"Unsupported provider {provider} specified.")

agent.enable_message(search_tool_class)
search_tool_handler_method = search_tool_class.default_value("request")
@@ -142,34 +128,9 @@ def chat(opts: CLIOptions) -> None:
""",
)
# local models do not like the first message to be empty
user_message = "Hello." if (opts.model != "") else None
user_message = "Hello." if (model != "") else None
task.run(user_message)


@app.command()
def main(
debug: bool = typer.Option(False, "--debug", "-d", help="debug mode"),
model: str = typer.Option("", "--model", "-m", help="model name"),
provider: str = typer.Option(
"google", "--provider", "-p", help="search provider name (Google, SciPhi)"
),
no_stream: bool = typer.Option(False, "--nostream", "-ns", help="no streaming"),
nocache: bool = typer.Option(False, "--nocache", "-nc", help="don't use cache"),
cache_type: str = typer.Option(
"redis", "--cachetype", "-ct", help="redis or momento"
),
) -> None:
set_global(
Settings(
debug=debug,
cache=not nocache,
stream=not no_stream,
cache_type=cache_type,
)
)
opts = CLIOptions(model=model, provider=provider)
chat(opts)


if __name__ == "__main__":
app()
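With the per-backend LiteLLMOllamaConfig / OobaConfig classes and the CLIOptions wrapper removed, the choice of backend now lives entirely in the chat_model string passed via -m. A hypothetical sketch of what the two deleted configs reduce to in the new single-config style; the strings are assembled from the deleted code and the other examples in this commit, so treat them as illustrative rather than as part of the change:

import langroid.language_models as lm

# Roughly what the removed litellm_ollama_config becomes: route via LiteLLM to a local Ollama server.
ollama_style = lm.OpenAIGPTConfig(
    chat_model="litellm/ollama/llama2",
    chat_context_length=4096,
)

# Roughly what the removed ooba_config becomes: an OpenAI-compatible server at localhost:8000.
# The "local/<host:port>/v1" form is the one documented in chat-local-numerical.py; this exact
# string is an assumption, not something this commit contains.
ooba_style = lm.OpenAIGPTConfig(
    chat_model="local/localhost:8000/v1",
    chat_context_length=2048,
)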