Skip to content

Commit

Permalink
feat(llm): generalize openai-style completion, add openai option
Browse files Browse the repository at this point in the history
  • Loading branch information
dev-msp committed Apr 27, 2024
1 parent 25856dc commit 88767d8
Show file tree
Hide file tree
Showing 8 changed files with 311 additions and 166 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ log = "0.4.21"
serde = { version = "1.0.197", features = ["derive"] }
serde_json = "1.0.115"
tokio = { version = "1.37.0", features = ["full"] }
derive_builder = "0.20.0"

[dependencies]
cpal = "0.14.0"
Expand All @@ -37,4 +38,4 @@ env_logger = { workspace = true }
log = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
derive_builder = "0.20.0"
derive_builder = { workspace = true }
1 change: 1 addition & 0 deletions llm/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,5 @@ log = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
tokio = { workspace = true }
derive_builder = { workspace = true }
anyhow = "1.0.82"
43 changes: 38 additions & 5 deletions llm/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,22 +1,55 @@
mod vendor;

use clap::Parser;
use clap::{Parser, ValueEnum};
use vendor::openai::compat::Response;

// Which LLM backend to send the completion request to.
// Parsed from the `--provider` CLI flag via clap's `ValueEnum` derive.
// NOTE: comments here are deliberately `//`, not `///` — doc comments on a
// ValueEnum become clap help text, which would change runtime output.
#[derive(Debug, Clone, Copy, ValueEnum)]
enum Provider {
    OpenAi,
    Groq,
}

impl Provider {
    /// Run one chat completion against the selected backend.
    ///
    /// Resolves the backend-specific API key from the environment, then
    /// delegates to the matching vendor client. Errors bubble up from the
    /// missing-env-var check or from the vendor call itself.
    async fn completion(
        &self,
        system_message: Option<String>,
        user_message: String,
    ) -> Result<Response, anyhow::Error> {
        // Pick the env var first so the key lookup is written once.
        let key_var = match self {
            Self::OpenAi => "OPENAI_API_KEY",
            Self::Groq => "GROQ_API_KEY",
        };
        let api_key = get_env(key_var)?;
        match self {
            Self::OpenAi => vendor::openai::completion(api_key, system_message, user_message).await,
            Self::Groq => vendor::groq::completion(api_key, system_message, user_message).await,
        }
    }
}

/// Read `key` from the process environment, turning any lookup failure
/// (unset or non-UTF-8 value) into a human-readable error.
fn get_env(key: &str) -> Result<String, anyhow::Error> {
    match std::env::var(key) {
        Ok(value) => Ok(value),
        Err(_) => Err(anyhow::anyhow!("{} is not set", key)),
    }
}

// Command-line arguments for the tool.
// NOTE: comments are deliberately `//`, not `///` — clap turns doc comments
// into --help text, which would change the program's output.
#[derive(Debug, clap::Parser)]
struct App {
    // Which backend to use (`-p` / `--provider`).
    #[clap(short, long)]
    provider: Provider,

    // Optional system prompt (`-s` / `--system-message`).
    #[clap(short, long)]
    system_message: Option<String>,

    // The user's prompt; positional argument.
    user_message: String,
}

/// The async main entry point of the application.
///
/// Parses CLI arguments, dispatches the completion to the selected
/// provider (which resolves its own API key from the environment),
/// and prints the response to stdout.
#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
    // clap exits with a usage message on parse failure.
    let app = App::parse();

    // The diff-scrape residue (a stale GROQ_API_KEY lookup and a duplicate
    // `let response` binding from the pre-commit version) is removed here;
    // key resolution now lives inside Provider::completion.
    let response = app
        .provider
        .completion(app.system_message, app.user_message)
        .await?;
    println!("{response}");
    Ok(())
}
173 changes: 13 additions & 160 deletions llm/src/vendor/groq.rs
Original file line number Diff line number Diff line change
@@ -1,125 +1,18 @@
use std::collections::VecDeque;

use itertools::Itertools;
use serde::{Deserialize, Serialize};

use super::openai::{self, compat::Response};

/// Groq's OpenAI-compatible chat-completions endpoint.
const GROQ_CHAT_API: &str = "https://api.groq.com/openai/v1/chat/completions";

// Model identifier strings exactly as Groq's API expects them.
const MIXTRAL: &str = "mixtral-8x7b-32768";
const LLAMA3_70B: &str = "llama3-70b-8192";
const LLAMA3_8B: &str = "llama3-8b-8192";

/// Request body for Groq's OpenAI-style chat-completions endpoint.
#[derive(Serialize, Deserialize)]
pub struct CompletionRequest {
    // Conversation (system prompt + turns); serialized to/from a flat
    // message array via Chat's serde conversions.
    messages: Chat,
    model: Model,
    // Sampling temperature; the Default impl uses 0.0.
    temperature: f32,
    // NOTE(review): Default sets this to 0 — presumably "provider default"
    // rather than "no tokens"; confirm against the Groq API docs.
    max_tokens: i32,
    top_p: f32,
    // Whether to request a streamed response.
    stream: bool,
    // Optional stop sequence; None omits the constraint.
    stop: Option<String>,
}

impl Default for CompletionRequest {
    /// Baseline request: the default chat, Llama3-70B, non-streaming,
    /// zeroed sampling parameters, and no stop sequence.
    fn default() -> Self {
        Self {
            messages: Chat::default(),
            model: Model::Llama3_70B,
            temperature: 0.0,
            max_tokens: 0,
            top_p: 0.0,
            stream: false,
            stop: None,
        }
    }
}

/// Raw, schema-less completion response: the JSON body as returned by the
/// API, kept as an untyped `serde_json::Value`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompletionResponse(serde_json::Value);

impl std::fmt::Display for CompletionResponse {
    /// Render the response as the inner JSON value's textual form.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Delegate straight to the wrapped value's Display impl;
        // equivalent to `write!(f, "{}", self.0)`.
        std::fmt::Display::fmt(&self.0, f)
    }
}

/// A conversation: one system prompt plus the ordered user/assistant turns.
///
/// On the wire this is a flat message array; the `#[serde(into/try_from)]`
/// attributes route (de)serialization through the conversions below.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(into = "VecDeque<Message>", try_from = "VecDeque<Message>")]
struct Chat {
    // The system prompt, kept separate from the turns.
    system: String,
    // Conversation turns, oldest first.
    messages: VecDeque<Message>,
}

impl Default for Chat {
fn default() -> Self {
let msg = Message {
role: Role::User,
content: "Hello!".to_string(),
};
let mut messages = VecDeque::new();
messages.push_back(msg);
Self {
system: "You are a helpful assistant.".to_string(),
messages,
}
}
}

impl From<Chat> for VecDeque<Message> {
fn from(chat: Chat) -> Self {
let mut messages = chat.messages;
messages.push_front(Message {
role: Role::System,
content: chat.system,
});
messages
}
}

/// Ways a flat message list can fail to form a valid `Chat`.
#[derive(Debug, Clone, Copy, thiserror::Error)]
enum ConversionError {
    // The list did not start with a system message.
    #[error("no system message found")]
    NoSystem,

    // A second system message appeared after the first.
    #[error("multiple system messages found")]
    MultipleSystem,

    // No conversation turns followed the system message.
    #[error("no user messages found")]
    Empty,
}

impl TryFrom<VecDeque<Message>> for Chat {
    type Error = ConversionError;

    /// Rebuild a `Chat` from a flat message list (inverse of the `From`
    /// conversion above): the first element must be the one-and-only system
    /// message, and at least one non-system message must follow.
    fn try_from(messages: VecDeque<Message>) -> Result<Self, Self::Error> {
        let mut msgs = messages.into_iter();
        // The system prompt must come first; anything else is an error.
        let system = match msgs.next() {
            Some(ref msg @ Message { ref content, .. }) if msg.is_system() => content.clone(),
            _ => return Err(ConversionError::NoSystem),
        };

        // Collect the remaining turns, rejecting any further system message.
        // `try_collect` (itertools) short-circuits on the first Err.
        let messages: VecDeque<Message> = msgs
            .map(|msg| {
                if msg.is_system() {
                    Err(ConversionError::MultipleSystem)
                } else {
                    Ok(msg)
                }
            })
            .try_collect()?;

        // An empty conversation is not a valid chat.
        if messages.is_empty() {
            return Err(ConversionError::Empty);
        }

        Ok(Self { system, messages })
    }
}

/// Groq model selector, (de)serialized via its string identifier.
///
/// The scrape left both the pre-commit and post-commit `#[derive]` lines
/// stacked, which would produce conflicting derived impls; only the
/// post-commit derive (with `Default`) is kept.
#[derive(Debug, Default, Clone, Copy, Serialize, Deserialize)]
#[serde(try_from = "&str", into = "&str")]
enum Model {
    Mixtral,
    // Llama3-70B is the default model for requests that don't pick one.
    #[default]
    Llama3_70B,
    Llama3_8B,
}
Expand Down Expand Up @@ -164,57 +57,17 @@ impl Model {
}
}

/// Who authored a chat message, serialized in the API's lowercase form
/// ("system" / "assistant" / "user").
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    System,
    Assistant,
    User,
}

/// One chat turn: who said it and what was said.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Message {
    role: Role,
    content: String,
}

impl Message {
    /// True when this message carries the system prompt.
    pub fn is_system(&self) -> bool {
        matches!(self.role, Role::System)
    }
}

pub async fn completion(
api_key: String,
system_message: Option<String>,
user_message: String,
) -> Result<CompletionResponse, anyhow::Error> {
let system_message =
system_message.unwrap_or_else(|| "You are a helpful assistant.".to_string());
let chat = Chat {
system: system_message,
messages: vec![Message {
role: Role::User,
content: user_message,
}]
.into(),
};
let req = CompletionRequest {
messages: chat,
..Default::default()
};

let response: reqwest::Response = reqwest::Client::new()
.post(GROQ_CHAT_API)
.bearer_auth(api_key)
.json(&req)
.send()
.await?;

let hey: CompletionResponse = response.json().await?;
Ok(hey)
) -> Result<Response, anyhow::Error> {
openai::compat::completion(
GROQ_CHAT_API,
api_key,
Model::default(),
system_message,
user_message,
)
.await
}
1 change: 1 addition & 0 deletions llm/src/vendor/mod.rs
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
pub mod groq;
pub mod openai;
Loading

0 comments on commit 88767d8

Please sign in to comment.