Commit 21d9ebf
Plugins intermediate steps reporting (#3238)
This PR is extracted from my OA Browsing experimental branch and should improve the visual and functional usability of plugins. Currently, when a plugin is enabled, the front end renders a QueueInfo badge telling you that you are in queue position 0 for the entire time the plugin system is executing its steps. With this change, the worker sends intermediate steps for each call to the LLM while a plugin is enabled, so the QueueInfo badge is discarded as soon as the first plugin step executes. Users also get "real-time" tracking of the LLM's inner thoughts, plans, etc. Demo: https://github.com/LAION-AI/Open-Assistant/assets/13547364/30252e22-d5d9-4200-9ccf-3a209d8581a7 --------- Co-authored-by: Oliver Stanley <[email protected]>
1 parent 54f4408 commit 21d9ebf
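
For orientation, the sketch below mirrors the worker-side mechanism this commit adds: each intermediate plugin step is packaged as an inference.PluginIntermediateResponse and pushed over the worker websocket with utils.send_response, as done in chat_chain.py below. The report_step wrapper and its argument values are illustrative only; ws and work_request_id are assumed to come from handle_plugin_usage.

# Illustrative sketch only; mirrors the utils.send_response calls added in chat_chain.py below.
import utils  # worker-side helper module already used by chat_chain.py
import websocket

from oasst_shared.schemas import inference


def report_step(
    ws: websocket.WebSocket,
    work_request_id: str,
    thought: str,
    action: str = "",
    action_input: str = "",
    action_response: str = "",
) -> None:
    """Hypothetical helper: send one intermediate plugin step to the inference server."""
    utils.send_response(
        ws,
        inference.PluginIntermediateResponse(
            request_id=work_request_id,
            current_plugin_thought=thought,
            current_plugin_action_taken=action,
            current_plugin_action_input=action_input,
            current_plugin_action_response=action_response,
        ),
    )


# For example, the initial "thinking..." step that replaces the queue badge in the UI:
# report_step(ws, work_request.id, "thinking...")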

9 files changed: +176 -6 lines changed


inference/server/oasst_inference_server/routes/chats.py

Lines changed: 11 additions & 0 deletions
@@ -244,6 +244,17 @@ async def event_generator(chat_id: str, message_id: str, worker_compat_hash: str
             ).json(),
         }
 
+        if response_packet.response_type == "plugin_intermediate":
+            logger.info(f"Received plugin intermediate response {chat_id}")
+            yield {
+                "data": chat_schema.PluginIntermediateResponseEvent(
+                    current_plugin_thought=response_packet.current_plugin_thought,
+                    current_plugin_action_taken=response_packet.current_plugin_action_taken,
+                    current_plugin_action_input=response_packet.current_plugin_action_input,
+                    current_plugin_action_response=response_packet.current_plugin_action_response,
+                ).json(),
+            }
+
         if response_packet.response_type == "internal_error":
             yield {
                 "data": chat_schema.ErrorResponseEvent(

inference/server/oasst_inference_server/routes/workers.py

Lines changed: 19 additions & 0 deletions
@@ -218,6 +218,12 @@ def _add_receive(ftrs: set):
                         response=worker_response,
                         work_request_map=work_request_map,
                     )
+                case "plugin_intermediate":
+                    worker_response = cast(inference.PluginIntermediateResponse, worker_response)
+                    await handle_plugin_intermediate_response(
+                        work_request_map=work_request_map,
+                        response=worker_response,
+                    )
                 case _:
                     raise RuntimeError(f"Unknown response type: {worker_response.response_type}")
     finally:
@@ -338,6 +344,19 @@ async def handle_token_response(
     work_response_container.num_responses += 1
 
 
+async def handle_plugin_intermediate_response(
+    response: inference.PluginIntermediateResponse,
+    work_request_map: WorkRequestContainerMap,
+):
+    work_response_container = get_work_request_container(work_request_map, response.request_id)
+    message_queue = queueing.message_queue(
+        deps.redis_client,
+        message_id=work_response_container.message_id,
+    )
+    await message_queue.enqueue(response.json())
+    work_response_container.num_responses += 1
+
+
 async def handle_generated_text_response(
     response: inference.GeneratedTextResponse,
     work_request_map: WorkRequestContainerMap,

inference/server/oasst_inference_server/schemas/chat.py

Lines changed: 16 additions & 1 deletion
@@ -46,8 +46,23 @@ class SafePromptResponseEvent(pydantic.BaseModel):
     message: inference.MessageRead
 
 
+class PluginIntermediateResponseEvent(pydantic.BaseModel):
+    event_type: Literal["plugin_intermediate"] = "plugin_intermediate"
+    current_plugin_thought: str
+    current_plugin_action_taken: str
+    current_plugin_action_input: str
+    current_plugin_action_response: str
+    message: inference.MessageRead | None = None
+
+
 ResponseEvent = Annotated[
-    Union[TokenResponseEvent, ErrorResponseEvent, MessageResponseEvent, SafePromptResponseEvent],
+    Union[
+        TokenResponseEvent,
+        ErrorResponseEvent,
+        MessageResponseEvent,
+        SafePromptResponseEvent,
+        PluginIntermediateResponseEvent,
+    ],
     pydantic.Field(discriminator="event_type"),
 ]
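
As a concrete illustration of what this event looks like on the wire, the sketch below constructs the new model and serializes it the same way the chats route does; the field values are made up and the import path is inferred from the file location above.

# Sketch: serialize one plugin_intermediate SSE payload (values are illustrative).
from oasst_inference_server.schemas import chat as chat_schema

event = chat_schema.PluginIntermediateResponseEvent(
    current_plugin_thought="I should look this up with the weather tool",
    current_plugin_action_taken="weather",
    current_plugin_action_input="Berlin",
    current_plugin_action_response="",
)

print(event.json())
# -> {"event_type": "plugin_intermediate", "current_plugin_thought": "...",
#     "current_plugin_action_taken": "weather", ..., "message": null}

This is the data payload that handleChatEventStream in website/src/lib/chat_stream.ts parses and forwards to onPluginIntermediateResponse.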

inference/worker/chat_chain.py

Lines changed: 51 additions & 0 deletions
@@ -3,6 +3,7 @@
 import interface
 import transformers
 import utils
+import websocket
 from chat_chain_prompts import (
     ASSISTANT_PREFIX,
     HUMAN_PREFIX,
@@ -108,6 +109,8 @@ def handle_plugin_usage(
     tools: list[Tool],
     plugin: inference.PluginEntry | None,
     plugin_max_depth: int,
+    ws: websocket.WebSocket,
+    work_request_id: str,
 ) -> tuple[str, inference.PluginUsed]:
     execution_details = inference.PluginExecutionDetails(
         inner_monologue=[],
@@ -142,17 +145,46 @@ def handle_plugin_usage(
         tokenizer, worker_config, parameters, prompt_template, memory, tool_names, language, action_input_format
     )
 
+    # Send a "thinking..." intermediate step to the UI immediately; this discards queue position 0
+    utils.send_response(
+        ws,
+        inference.PluginIntermediateResponse(
+            request_id=work_request_id,
+            current_plugin_thought="thinking...",
+            current_plugin_action_taken="",
+            current_plugin_action_input="",
+            current_plugin_action_response="",
+        ),
+    )
+
     init_prompt = f"{input_prompt}{eos_token}{V2_ASST_PREFIX}"
     init_prompt, chain_response = chain.call(init_prompt)
 
     inner_monologue.append("In: " + str(init_prompt))
     inner_monologue.append("Out: " + str(chain_response))
 
+    current_action_thought = ""
+    if THOUGHT_SEQ in chain_response:
+        current_action_thought = chain_response.split(THOUGHT_SEQ)[1].split("\n")[0]
+
     # Tool name/assistant prefix, Tool input/assistant response
     prefix, response = extract_tool_and_input(llm_output=chain_response, ai_prefix=ASSISTANT_PREFIX)
     assisted = False if ASSISTANT_PREFIX in prefix else True
     chain_finished = not assisted
 
+    if assisted:
+        # The model decided to use a tool, so send that thought to the client
+        utils.send_response(
+            ws,
+            inference.PluginIntermediateResponse(
+                request_id=work_request_id,
+                current_plugin_thought=current_action_thought,
+                current_plugin_action_taken=prefix,
+                current_plugin_action_input=chain_response,
+                current_plugin_action_response=response,
+            ),
+        )
+
     while not chain_finished and assisted and achieved_depth < plugin_max_depth:
         tool_response = use_tool(prefix, response, tools)
 
@@ -165,6 +197,22 @@ def handle_plugin_usage(
         inner_monologue.append("In: " + str(new_prompt))
         inner_monologue.append("Out: " + str(chain_response))
 
+        current_action_thought = ""
+        if THOUGHT_SEQ in chain_response:
+            current_action_thought = chain_response.split(THOUGHT_SEQ)[1].split("\n")[0]
+
+        # Send deeper plugin intermediate steps to the UI
+        utils.send_response(
+            ws,
+            inference.PluginIntermediateResponse(
+                request_id=work_request_id,
+                current_plugin_thought=current_action_thought,
+                current_plugin_action_taken=prefix,
+                current_plugin_action_input=chain_response,
+                current_plugin_action_response=response,
+            ),
+        )
+
         prefix, response = extract_tool_and_input(llm_output=chain_response, ai_prefix=ASSISTANT_PREFIX)
         assisted = False if ASSISTANT_PREFIX in prefix else True
 
@@ -286,6 +334,7 @@ def handle_conversation(
     worker_config: inference.WorkerConfig,
     parameters: interface.GenerateStreamParameters,
     tokenizer: transformers.PreTrainedTokenizer,
+    ws: websocket.WebSocket,
 ) -> tuple[str, inference.PluginUsed | None]:
     try:
         original_prompt = work_request.thread.messages[-1].content
@@ -323,6 +372,8 @@ def handle_conversation(
                 tools,
                 plugin,
                 work_request.parameters.plugin_max_depth,
+                ws,
+                work_request.id,
             )
 
         return handle_standard_usage(original_prompt, prompt_template, language, memory, worker_config, tokenizer)
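
The thought string surfaced to the UI is simply the first line after the thought marker in the raw chain output. Below is a standalone sketch of that parse; THOUGHT_SEQ comes from chat_chain_prompts, and its "Thought:" value (and the sample chain output) are assumptions for illustration.

# Standalone sketch of the thought extraction used in handle_plugin_usage above.
THOUGHT_SEQ = "Thought:"  # assumed marker value; the real constant lives in chat_chain_prompts

chain_response = (
    "Thought: I need current weather data, so I should call the weather tool.\n"
    "Action: weather\n"
    "Action Input: Berlin"
)

current_action_thought = ""
if THOUGHT_SEQ in chain_response:
    # Keep only the first line after the marker, mirroring the code above
    current_action_thought = chain_response.split(THOUGHT_SEQ)[1].split("\n")[0]

print(current_action_thought.strip())
# -> I need current weather data, so I should call the weather tool.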

inference/worker/work.py

Lines changed: 1 addition & 1 deletion
@@ -84,7 +84,7 @@ def handle_work_request(
 
     for plugin in parameters.plugins:
         if plugin.enabled:
-            prompt, used_plugin = chat_chain.handle_conversation(work_request, worker_config, parameters, tokenizer)
+            prompt, used_plugin = chat_chain.handle_conversation(work_request, worker_config, parameters, tokenizer, ws)
     # When using plugins and final prompt is truncated due to length limit
     # LLaMA has tendency to leak internal prompts and generate bad continuations
     # So we add keywords/sequences to the stop sequences to reduce this

oasst-shared/oasst_shared/schemas/inference.py

Lines changed: 10 additions & 0 deletions
@@ -329,6 +329,15 @@ class SafePromptResponse(WorkerResponseBase):
     safety_rots: str
 
 
+class PluginIntermediateResponse(WorkerResponseBase):
+    response_type: Literal["plugin_intermediate"] = "plugin_intermediate"
+    text: str = ""
+    current_plugin_thought: str
+    current_plugin_action_taken: str
+    current_plugin_action_input: str
+    current_plugin_action_response: str
+
+
 class TokenResponse(WorkerResponseBase):
     response_type: Literal["token"] = "token"
     text: str
@@ -389,6 +398,7 @@ class GeneralErrorResponse(WorkerResponseBase):
         InternalFinishedMessageResponse,
         InternalErrorResponse,
         SafePromptResponse,
+        PluginIntermediateResponse,
     ],
     pydantic.Field(discriminator="response_type"),
 ]

website/src/components/Chat/ChatConversation.tsx

Lines changed: 37 additions & 2 deletions
@@ -7,7 +7,7 @@ import { UseFormGetValues } from "react-hook-form";
 import SimpleBar from "simplebar-react";
 import { useMessageVote } from "src/hooks/chat/useMessageVote";
 import { get, post } from "src/lib/api";
-import { handleChatEventStream, QueueInfo } from "src/lib/chat_stream";
+import { handleChatEventStream, QueueInfo, PluginIntermediateResponse } from "src/lib/chat_stream";
 import { OasstError } from "src/lib/oasst_api_client";
 import { API_ROUTES } from "src/lib/routes";
 import {
@@ -38,6 +38,7 @@ export const ChatConversation = memo(function ChatConversation({ chatId, getConf
 
   const [streamedResponse, setResponse] = useState<string | null>(null);
   const [queueInfo, setQueueInfo] = useState<QueueInfo | null>(null);
+  const [pluginIntermediateResponse, setPluginIntermediateResponse] = useState<PluginIntermediateResponse | null>(null);
   const [isSending, setIsSending] = useBoolean();
   const [showEncourageMessage, setShowEncourageMessage] = useBoolean(false);
 
@@ -98,9 +99,16 @@ export const ChatConversation = memo(function ChatConversation({ chatId, getConf
       message = await handleChatEventStream({
         stream: body!,
         onError: console.error,
-        onPending: setQueueInfo,
+        onPending: (data) => {
+          setQueueInfo(data);
+          setPluginIntermediateResponse(null);
+        },
+        onPluginIntermediateResponse: setPluginIntermediateResponse,
         onToken: async (text) => {
           setQueueInfo(null);
+          if (text != "") {
+            setPluginIntermediateResponse(null);
+          }
           setResponse(text);
           await new Promise(requestAnimationFrame);
         },
@@ -305,6 +313,33 @@ export const ChatConversation = memo(function ChatConversation({ chatId, getConf
             {t("queue_info", queueInfo)}
           </Badge>
         )}
+        {pluginIntermediateResponse && pluginIntermediateResponse.currentPluginThought && (
+          <Box
+            position="absolute"
+            bottom="0"
+            left="50%"
+            transform="translate(-50%)"
+            display="flex"
+            flexDirection="row"
+            gap="1"
+            justifyContent="center"
+            alignItems="center"
+          >
+            <Box
+              bg="purple.700"
+              color="white"
+              px="2"
+              py="2.5px"
+              borderRadius="8px"
+              maxWidth="50vw"
+              fontSize="11"
+              fontWeight="bold"
+              isTruncated
+            >
+              {pluginIntermediateResponse.currentPluginThought}
+            </Box>
+          </Box>
+        )}
       </Box>
       <ChatForm ref={inputRef} isSending={isSending} onSubmit={sendPrompterMessage}></ChatForm>
       <ChatWarning />

website/src/lib/chat_stream.ts

Lines changed: 17 additions & 1 deletion
@@ -5,18 +5,27 @@ export interface QueueInfo {
   queueSize: number;
 }
 
+export interface PluginIntermediateResponse {
+  currentPluginThought: string;
+  currentPluginAction: string;
+  currentPluginActionResponse: string;
+  currentPluginActionInput: string;
+}
+
 export interface ChatStreamHandlerOptions {
   stream: ReadableStream<Uint8Array>;
   onError: (err: unknown) => unknown;
   onPending: (info: QueueInfo) => unknown;
   onToken: (partialMessage: string) => unknown;
+  onPluginIntermediateResponse: (pluginIntermediateResponse: PluginIntermediateResponse) => unknown;
 }
 
 export async function handleChatEventStream({
   stream,
   onError,
   onPending,
   onToken,
+  onPluginIntermediateResponse,
 }: ChatStreamHandlerOptions): Promise<InferenceMessage | null> {
   let tokens = "";
   for await (const { event, data } of iteratorSSE(stream)) {
@@ -39,8 +48,15 @@ export async function handleChatEventStream({
         // handle error
         await onError(chunk.error);
         return chunk.message;
+      } else if (chunk.event_type === "plugin_intermediate") {
+        await onPluginIntermediateResponse({
+          currentPluginThought: chunk.current_plugin_thought,
+          currentPluginAction: chunk.current_plugin_action_taken,
+          currentPluginActionResponse: chunk.current_plugin_action_response,
+          currentPluginActionInput: chunk.current_plugin_action_input,
+        });
       } else {
-        console.error("Unexpected event", chunk);
+        console.log("Unexpected event", chunk);
       }
     } catch (e) {
       console.error(`Error parsing data: ${data}, error: ${e}`);

website/src/types/Chat.ts

Lines changed: 14 additions & 1 deletion
@@ -80,7 +80,20 @@ interface InferenceEventPending {
   queue_size: number;
 }
 
-export type InferenceEvent = InferenceEventMessage | InferenceEventError | InferenceEventToken | InferenceEventPending;
+interface InferenceEventPluginIntermediateStep {
+  event_type: "plugin_intermediate";
+  current_plugin_thought: string;
+  current_plugin_action_taken: string;
+  current_plugin_action_response: string;
+  current_plugin_action_input: string;
+}
+
+export type InferenceEvent =
+  | InferenceEventMessage
+  | InferenceEventError
+  | InferenceEventToken
+  | InferenceEventPending
+  | InferenceEventPluginIntermediateStep;
 
 export type ModelInfo = {
   name: string;