From 3f728c4737dd69eb09d2a98360f496a9da79edc7 Mon Sep 17 00:00:00 2001 From: xsigoking <152482559+xsigoking@users.noreply.github.com> Date: Sun, 21 Apr 2024 22:30:14 +0000 Subject: [PATCH] feat: support history messages (#15) --- src/main.rs | 47 ++++++++++++++++++++++++++++++++++++----------- 1 file changed, 36 insertions(+), 11 deletions(-) diff --git a/src/main.rs b/src/main.rs index 9d8d3bb..aaa8e69 100644 --- a/src/main.rs +++ b/src/main.rs @@ -204,14 +204,16 @@ impl Server { .map_err(|err| anyhow!("Invalid request body, {err}"))?; let is_stream = req_body["stream"].as_bool().unwrap_or_default(); + let mut invalid = false; let mut new_messages = vec![]; - let mut invalid_messages = false; + let mut system_prompt = None; if let Some(messages) = req_body["messages"].as_array() { + let has_history = messages.len() > 2; for v in messages { let role = match v["role"].as_str() { Some(v) => v, None => { - invalid_messages = true; + invalid = true; break; } }; @@ -228,27 +230,50 @@ impl Server { _ => "", }; if text.is_empty() { - invalid_messages = true; + invalid = true; break; } text }; - new_messages.push(json!({ - "id": random_id(), - "author": { "role": role }, - "content": { "content_type": "text", "parts": [content] }, - "metadata": {}, - })); + if role == "system" { + if system_prompt.is_some() { + invalid = true; + break; + } + system_prompt = Some(content.to_string()); + } else if role == "user" && has_history { + new_messages.push(format!("[INST]{content}[/INST]")); + } else { + new_messages.push(content.to_string()); + } } } - if invalid_messages { + if invalid { bail!("Invalid request messages"); } + let mut messages = vec![]; + if let Some(system_prompt) = system_prompt { + messages.push(json!({ + "id": random_id(), + "author": { "role": "system" }, + "content": { "content_type": "text", "parts": [system_prompt] }, + "metadata": {}, + })) + } + + let combine_message = new_messages.join("\n"); + messages.push(json!({ + "id": random_id(), + "author": { "role": "user" }, + "content": { "content_type": "text", "parts": [combine_message] }, + "metadata": {}, + })); + let req_body = json!({ "action": "next", - "messages": new_messages, + "messages": messages, "parent_message_id": random_id(), "model": "text-davinci-002-render-sha", "timezone_offset_min": 0,