diff --git a/src/main.rs b/src/main.rs index 9d8d3bb..aaa8e69 100644 --- a/src/main.rs +++ b/src/main.rs @@ -204,14 +204,16 @@ impl Server { .map_err(|err| anyhow!("Invalid request body, {err}"))?; let is_stream = req_body["stream"].as_bool().unwrap_or_default(); + let mut invalid = false; let mut new_messages = vec![]; - let mut invalid_messages = false; + let mut system_prompt = None; if let Some(messages) = req_body["messages"].as_array() { + let has_history = messages.len() > 2; for v in messages { let role = match v["role"].as_str() { Some(v) => v, None => { - invalid_messages = true; + invalid = true; break; } }; @@ -228,27 +230,50 @@ impl Server { _ => "", }; if text.is_empty() { - invalid_messages = true; + invalid = true; break; } text }; - new_messages.push(json!({ - "id": random_id(), - "author": { "role": role }, - "content": { "content_type": "text", "parts": [content] }, - "metadata": {}, - })); + if role == "system" { + if system_prompt.is_some() { + invalid = true; + break; + } + system_prompt = Some(content.to_string()); + } else if role == "user" && has_history { + new_messages.push(format!("[INST]{content}[/INST]")); + } else { + new_messages.push(content.to_string()); + } } } - if invalid_messages { + if invalid { bail!("Invalid request messages"); } + let mut messages = vec![]; + if let Some(system_prompt) = system_prompt { + messages.push(json!({ + "id": random_id(), + "author": { "role": "system" }, + "content": { "content_type": "text", "parts": [system_prompt] }, + "metadata": {}, + })) + } + + let combine_message = new_messages.join("\n"); + messages.push(json!({ + "id": random_id(), + "author": { "role": "user" }, + "content": { "content_type": "text", "parts": [combine_message] }, + "metadata": {}, + })); + let req_body = json!({ "action": "next", - "messages": new_messages, + "messages": messages, "parent_message_id": random_id(), "model": "text-davinci-002-render-sha", "timezone_offset_min": 0,