
integrate ollama for basic chat
neindochoh committed Jan 23, 2024
1 parent 1788b76 commit a80feb0
Showing 3 changed files with 76 additions and 35 deletions.
38 changes: 28 additions & 10 deletions renumics/spotlight/backend/websockets.py
@@ -3,6 +3,7 @@
 """
 
 import asyncio
+import json
 from typing import (
     Any,
     Coroutine,
@@ -21,6 +22,7 @@
 from loguru import logger
 from pydantic import BaseModel
 from typing_extensions import Literal
+import httpx
 
 from renumics.spotlight.data_store import DataStore
 
@@ -337,13 +339,29 @@ class ChatData(BaseModel):
 
 @message_handler("chat", ChatData)
 async def _(data: ChatData, connection: WebsocketConnection) -> None:
-    # TODO: integrate LLM
-    await asyncio.sleep(3)
-    llm_response = f"llm response for {data.message}"
-    print(llm_response)
-    await connection.send_async(
-        Message(
-            type="chat.response",
-            data={"chat_id": data.chat_id, "message": llm_response},
-        )
-    )
+    async with httpx.AsyncClient(
+        base_url="http://localhost:11434/api/"
+    ) as ollama_client:
+        async with ollama_client.stream(
+            "POST",
+            "chat",
+            json={
+                "model": "openhermes2",
+                "stream": True,
+                "messages": [{"role": "user", "content": data.message}],
+            },
+            timeout=None,
+        ) as stream:
+            async for chunk in stream.aiter_text():
+                try:
+                    response = json.loads(chunk)
+                except json.JSONDecodeError:
+                    break
+                llm_response = response["message"]["content"]
+
+                await connection.send_async(
+                    Message(
+                        type="chat.response",
+                        data={"chat_id": data.chat_id, "message": llm_response},
+                    )
+                )
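
A note on the stream handling above: Ollama's streaming /api/chat endpoint emits newline-delimited JSON, one object per line, ending with a done chunk whose message content is empty — roughly (illustrative values, extra fields trimmed):

{"model": "openhermes2", "message": {"role": "assistant", "content": "Hel"}, "done": false}
{"model": "openhermes2", "message": {"role": "assistant", "content": "lo!"}, "done": false}
{"model": "openhermes2", "message": {"role": "assistant", "content": ""}, "done": true}

That final empty-content chunk is what terminates the frontend loop in chat.ts below. The handler assumes each aiter_text() chunk is exactly one such object; httpx does not guarantee that read boundaries fall on newlines, so a read that splits (or merges) objects would trip the JSONDecodeError branch and silently end the response. A more defensive variant could iterate stream.aiter_lines() instead.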
29 changes: 17 additions & 12 deletions src/services/chat.ts
@@ -41,19 +41,24 @@ class ChatService {
         );
     }
 
-    async chat(message: string) {
-        // eslint-disable-next-line @typescript-eslint/no-explicit-any
-        return new Promise<any>((resolve, reject) => {
-            const chat_id = crypto.randomUUID();
-            this.dispatchTable.set(chat_id, { resolve, reject });
-            websocketService.send({
-                type: 'chat',
-                data: {
-                    chat_id: chat_id,
-                    message,
-                },
-            });
-        });
-    }
+    async *stream(message: string) {
+        const chat_id = crypto.randomUUID();
+        websocketService.send({
+            type: 'chat',
+            data: {
+                chat_id: chat_id,
+                message,
+            },
+        });
+
+        let msg = '';
+        do {
+            const promise = new Promise<string>((resolve, reject) => {
+                this.dispatchTable.set(chat_id, { resolve, reject });
+            });
+            msg = await promise;
+            yield msg;
+        } while (msg !== '');
+    }
 }
 
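Since chat() is now the async generator stream(), callers pull tokens with for await. A minimal consumer sketch (hypothetical helper, import path as seen from src/widgets):

// Hypothetical consumer: collect the streamed tokens into one reply.
import chatService from '../../services/chat';

async function ask(question: string): Promise<string> {
    let reply = '';
    for await (const token of chatService.stream(question)) {
        reply += token; // each yield is one chat.response message from the backend
    }
    return reply;
}

Each loop pass in stream() re-arms dispatchTable with a fresh resolver, so every incoming chat.response resolves exactly one awaited promise.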
44 changes: 31 additions & 13 deletions src/widgets/LLMWidget/index.tsx
@@ -5,34 +5,56 @@ import WidgetContent from '../../components/ui/WidgetContent';
 import BrainIcon from '../../icons/Brain';
 import DeleteIcon from '../../icons/Delete';
 import tw from 'twin.macro';
-import { KeyboardEvent, useRef, useState } from 'react';
+import { KeyboardEvent, useCallback, useRef, useState } from 'react';
 import Spinner from '../../components/ui/Spinner';
 import Button from '../../components/ui/Button';
 import chatService from '../../services/chat';
 
+interface Message {
+    content: string;
+    processing?: boolean;
+}
+
 const LLMWidget: Widget = () => {
-    const [chat, setChat] = useState<Array<string>>([]);
+    const [chat, setChat] = useState<Array<Message>>([]);
     const [processing, setProcessing] = useState(false);
 
     const queryInputRef = useRef<HTMLInputElement>(null);
 
-    const handleKeyUp = (e: KeyboardEvent) => {
+    const handleKeyUp = useCallback((e: KeyboardEvent) => {
         if (!queryInputRef.current) return;
 
         if (e.key == 'Enter') {
             const query = queryInputRef.current.value;
             queryInputRef.current.value = '';
             setProcessing(true);
-            setChat((state) => [...state, query]);
+            setChat((state) => [...state, { content: query }]);
 
             const processQuery = async () => {
-                const response = await chatService.chat(query);
-                setChat((state) => [...state, response]);
+                setChat((messages) => [...messages, { content: '', processing: true }]);
+
+                const stream = chatService.stream(query);
+                for await (const response of stream) {
+                    setChat((messages) => {
+                        const lastMsg = messages[messages.length - 1];
+                        return [
+                            ...messages.slice(0, messages.length - 1),
+                            { content: lastMsg.content + response, processing: true },
+                        ];
+                    });
+                }
+                setChat((messages) => {
+                    const lastMsg = messages[messages.length - 1];
+                    return [
+                        ...messages.slice(0, messages.length - 1),
+                        { content: lastMsg.content, processing: false },
+                    ];
+                });
                 setProcessing(false);
             };
             processQuery();
         }
-    };
+    }, []);
 
     const clearChat = () => setChat([]);
 
@@ -47,14 +69,10 @@ const LLMWidget: Widget = () => {
             <div tw="flex-grow flex flex-col p-1 space-y-1">
                 {chat.map((message, i) => (
                     <div tw="bg-gray-100 px-1 py-0.5 rounded" key={i}>
-                        {message}
+                        {message.content}
+                        {message.processing && <Spinner tw="w-4 h-4" />}
                     </div>
                 ))}
-                {processing && (
-                    <div tw="bg-gray-100 px-1 py-0.5 rounded">
-                        <Spinner tw="w-4 h-4" />
-                    </div>
-                )}
             </div>
             <div tw="p-1 relative">
                 <input
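One detail worth noting in the widget: the setChat calls use functional updates because streamed tokens can arrive faster than React re-renders, and reading chat from the closure would drop tokens. Condensed, the append step is (helper name hypothetical):

// Hypothetical helper: fold one streamed token into the entry being generated.
const appendToken = (messages: Message[], token: string): Message[] => {
    const last = messages[messages.length - 1];
    return [
        ...messages.slice(0, -1),
        { content: last.content + token, processing: true },
    ];
};
// In the stream loop: setChat((messages) => appendToken(messages, response));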
