Big update to llm crate and configuration of same
dev-msp committed Jun 12, 2024
1 parent ee2af4a commit 418ff8e
Showing 11 changed files with 291 additions and 126 deletions.
62 changes: 62 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions llm/Cargo.toml
@@ -21,3 +21,4 @@ toml = { workspace = true }
 xdg = { workspace = true }
 derive_builder = { workspace = true }
 anyhow = "1.0.82"
+chrono = "0.4.38"
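The diff does not show where chrono is consumed, so the sketch below is only a guess at a typical use (timestamping); the function name is hypothetical.

    use chrono::Utc;

    // Hypothetical call site; the commit does not show how chrono is actually used.
    fn timestamp() -> String {
        Utc::now().to_rfc3339()
    }

    fn main() {
        println!("started at {}", timestamp());
    }
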
1 change: 1 addition & 0 deletions llm/src/config.rs
@@ -6,6 +6,7 @@ use crate::vendor::{ollama, openai::compat};
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct Config {
+    pub default_system_message: Option<String>,
     pub providers: Providers,
 }
 
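The new `default_system_message` field lets the config file supply a fallback used when `--system-message` is omitted on the CLI. A minimal deserialization sketch, using a trimmed stand-in struct because the full `Providers` shape is not shown in this hunk (assumes serde, toml, and anyhow are available):

    use serde::Deserialize;

    // Trimmed stand-in for the crate's Config; the real struct also carries `providers`.
    #[derive(Debug, Deserialize)]
    struct ConfigSketch {
        default_system_message: Option<String>,
    }

    fn main() -> anyhow::Result<()> {
        let raw = r#"default_system_message = "You are a helpful assistant.""#;
        let cfg: ConfigSketch = toml::from_str(raw)?;
        assert_eq!(
            cfg.default_system_message.as_deref(),
            Some("You are a helpful assistant.")
        );
        Ok(())
    }
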
57 changes: 19 additions & 38 deletions llm/src/main.rs
@@ -1,39 +1,36 @@
 use clap::Parser;
 use itertools::Itertools;
-use llm::{
-    vendor::{self, openai},
-    Config,
-};
+use llm::{vendor, Config};
 
 #[derive(Debug, clap::Parser)]
 #[command(version, about, long_about = None)]
 struct App {
+    #[clap(short, long)]
+    provider: vendor::Provider,
+
     #[command(subcommand)]
     command: Commands,
 }
 
 #[derive(Debug, clap::Subcommand)]
 enum Commands {
-    Completion(Completion),
+    Completion(vendor::Completion),
     ListModels,
 }
 
-#[derive(Debug, clap::Args)]
-struct Completion {
-    #[clap(short, long)]
-    provider: vendor::Provider,
-
-    #[clap(short, long)]
-    system_message: Option<String>,
-
-    user_message: String,
-}
-
-impl Completion {
-    async fn run(self, config: &Config) -> Result<openai::compat::Response, anyhow::Error> {
-        self.provider
-            .completion(config, self.system_message, self.user_message)
-            .await
+impl App {
+    async fn run(self, config: &Config) -> Result<(), anyhow::Error> {
+        match self.command {
+            Commands::Completion(c) => {
+                let response = self.provider.completion(config, c).await?;
+                println!("{}", response.content());
+            }
+            Commands::ListModels => {
+                let models = self.provider.list_models(config).await?;
+                println!("{}", models.iter().join("\n"));
+            }
+        }
+        Ok(())
    }
 }
 
@@ -43,21 +40,5 @@ async fn main() -> Result<(), anyhow::Error> {
 
     let app = App::parse();
     let config = Config::read()?;
-    match app.command {
-        Commands::Completion(c) => {
-            let response = c.run(&config).await?;
-            println!("{}", response.content());
-        }
-        Commands::ListModels => {
-            let models: Vec<_> = vendor::ollama::list_models().await?.into();
-            println!(
-                "{:?}",
-                models
-                    .iter()
-                    .map(|m| (m.name(), m.human_size()))
-                    .collect_vec()
-            );
-        }
-    }
-    Ok(())
+    app.run(&config).await
 }
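With this change, `--provider` moves from the `Completion` args to the top-level `App`, so it applies to `list-models` as well, and dispatch lives in `App::run`. A hedged sketch of the resulting CLI shape as a test that could sit in this file (subcommand names assume clap's default kebab-case renaming; the model name is made up):

    #[test]
    fn parses_completion_invocation() {
        // `ollama` comes from #[value(rename_all = "lower")] on Provider.
        let app = App::parse_from([
            "llm", "--provider", "ollama",
            "completion", "--model", "llama3", "Why is the sky blue?",
        ]);
        assert!(matches!(app.command, Commands::Completion(_)));
    }
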
14 changes: 4 additions & 10 deletions llm/src/vendor/groq.rs
@@ -2,7 +2,7 @@ use serde::{Deserialize, Serialize};
 
 use super::openai::{self, compat::Response};
 
-const GROQ_CHAT_API: &str = "https://api.groq.com/openai/v1/chat/completions";
+pub(crate) const GROQ_CHAT_API: &str = "https://api.groq.com/openai/v1";
 
 const MIXTRAL: &str = "mixtral-8x7b-32768";
 const LLAMA3_70B: &str = "llama3-70b-8192";
@@ -59,15 +59,9 @@
 
 pub async fn completion(
     api_key: String,
-    system_message: Option<String>,
+    model: String,
+    system_message: String,
     user_message: String,
 ) -> Result<Response, anyhow::Error> {
-    openai::compat::completion(
-        GROQ_CHAT_API,
-        api_key,
-        Model::default(),
-        system_message,
-        user_message,
-    )
-    .await
+    openai::compat::completion(GROQ_CHAT_API, api_key, model, system_message, user_message).await
 }
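`GROQ_CHAT_API` now holds the OpenAI-compatible base URL rather than the full chat endpoint, so the one constant can serve both completions and `compat::list_models`. Presumably the `openai::compat` helpers append the route themselves; a sketch of that assumption (the helpers' internals are not shown in this diff):

    // Assumed route construction inside the compat helpers.
    fn chat_url(base: &str) -> String {
        format!("{base}/chat/completions") // restores the old full endpoint
    }

    fn models_url(base: &str) -> String {
        format!("{base}/models")
    }

    fn main() {
        let base = "https://api.groq.com/openai/v1";
        assert_eq!(chat_url(base), "https://api.groq.com/openai/v1/chat/completions");
        assert_eq!(models_url(base), "https://api.groq.com/openai/v1/models");
    }
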
82 changes: 77 additions & 5 deletions llm/src/vendor/mod.rs
@@ -2,10 +2,23 @@ use anyhow::anyhow;
 
 use crate::Config;
 
+use self::openai::compat;
+
 pub mod groq;
 pub mod ollama;
 pub mod openai;
 
+#[derive(Debug, Clone, clap::Args)]
+pub struct Completion {
+    #[clap(short, long)]
+    model: Option<String>,
+
+    #[clap(short, long)]
+    system_message: Option<String>,
+
+    user_message: String,
+}
+
 #[derive(Debug, Clone, Copy, clap::ValueEnum)]
 #[value(rename_all = "lower")]
 pub enum Provider {
@@ -18,27 +31,86 @@ impl Provider {
     pub async fn completion(
         &self,
         config: &Config,
-        system_message: Option<String>,
-        user_message: String,
+        comp: Completion,
     ) -> Result<openai::compat::Response, anyhow::Error> {
+        let model = comp
+            .model
+            .clone()
+            .or_else(|| self.default_model(config))
+            .ok_or_else(|| {
+                anyhow!(
+                    "No model specified and no default configured for {}",
+                    self.name()
+                )
+            })?;
+
+        let system_message = comp
+            .system_message
+            .clone()
+            .or_else(|| config.default_system_message.clone())
+            .ok_or_else(|| anyhow!("No default system message configured"))?;
+
+        match self {
+            Provider::Groq => {
+                let provider = &config.providers.groq;
+                let api_key = provider
+                    .get_api_key()
+                    .await?
+                    .ok_or_else(|| anyhow!("no api key?"))?;
+                groq::completion(api_key, model, system_message, comp.user_message).await
+            }
+            Provider::OpenAi => {
+                let provider = &config.providers.openai;
+                let api_key = provider
+                    .get_api_key()
+                    .await?
+                    .ok_or_else(|| anyhow!("no api key?"))?;
+                openai::completion(api_key, model, system_message, comp.user_message).await
+            }
+            Provider::Ollama => {
+                let model = comp
+                    .model
+                    .ok_or_else(|| anyhow!("No model specified for Ollama"))?;
+                ollama::completion(model, system_message, comp.user_message).await
+            }
+        }
+    }
+
+    pub async fn list_models(&self, config: &Config) -> anyhow::Result<Vec<String>> {
         match self {
             Provider::Groq => {
                 let provider = &config.providers.groq;
                 let api_key = provider
                     .get_api_key()
                     .await?
                     .ok_or_else(|| anyhow!("no api key?"))?;
-                groq::completion(api_key, system_message, user_message).await
+                compat::list_models(groq::GROQ_CHAT_API, api_key).await
             }
             Provider::OpenAi => {
                 let provider = &config.providers.openai;
                 let api_key = provider
                     .get_api_key()
                     .await?
                     .ok_or_else(|| anyhow!("no api key?"))?;
-                openai::completion(api_key, system_message, user_message).await
+                openai::list_models(api_key).await
             }
-            Provider::Ollama => ollama::completion(system_message, user_message).await,
+            Provider::Ollama => ollama::list_models().await,
         }
     }
+
+    fn name(&self) -> &'static str {
+        match self {
+            Provider::Groq => "Groq",
+            Provider::OpenAi => "OpenAI",
+            Provider::Ollama => "Ollama",
+        }
+    }
+
+    fn default_model(&self, config: &Config) -> Option<String> {
+        match self {
+            Provider::Groq => config.providers.groq.default_model(),
+            Provider::OpenAi => config.providers.openai.default_model(),
+            Provider::Ollama => config.providers.ollama.default_model(),
+        }
+    }
 }
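Both lookups in `completion` resolve the same way: explicit CLI flag first, then config default, then a hard error. A standalone sketch of that `or_else`/`ok_or_else` chain (the names here are hypothetical; the real code threads `self` and `Config` through):

    use anyhow::anyhow;

    // `cli_flag` stands in for the clap value, `config_default` for the Config fallback.
    fn resolve(cli_flag: Option<String>, config_default: Option<String>) -> anyhow::Result<String> {
        cli_flag
            .or(config_default)
            .ok_or_else(|| anyhow!("no value given and no default configured"))
    }

    fn main() -> anyhow::Result<()> {
        assert_eq!(resolve(Some("cli".into()), None)?, "cli");
        assert_eq!(resolve(None, Some("cfg".into()))?, "cfg");
        assert!(resolve(None, None).is_err());
        Ok(())
    }
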
10 changes: 5 additions & 5 deletions llm/src/vendor/ollama/list.rs
@@ -25,8 +25,6 @@ impl Default for Host {
     }
 }
 
-struct Model(String);
-
 #[derive(Debug, Clone, Deserialize)]
 pub struct ListModelsResponse {
     models: Vec<LocalModel>,
@@ -94,11 +92,13 @@ impl LocalModel {
     }
 }
 
-pub async fn list_models() -> anyhow::Result<ListModelsResponse> {
-    let resp = reqwest::Client::new()
+pub async fn list_models() -> anyhow::Result<Vec<String>> {
+    let resp: ListModelsResponse = reqwest::Client::new()
         .get(format!("{OLLAMA_API}/tags"))
         .send()
         .await?
+        .json()
+        .await?;
 
-    Ok(resp.json().await?)
+    Ok(resp.models.into_iter().map(|m| m.name).collect())
 }
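`list_models` now deserializes the Ollama tags payload and projects it straight to model names, so callers no longer see `ListModelsResponse`. The same deserialize-then-project pattern in miniature (structs trimmed to the one field used here; `serde_json` and the sample payload are assumptions, not taken from this diff):

    use serde::Deserialize;

    // Trimmed mirror of the /api/tags response shape; illustration only.
    #[derive(Debug, Deserialize)]
    struct Tags {
        models: Vec<Tag>,
    }

    #[derive(Debug, Deserialize)]
    struct Tag {
        name: String,
    }

    fn names(raw: &str) -> anyhow::Result<Vec<String>> {
        let tags: Tags = serde_json::from_str(raw)?;
        Ok(tags.models.into_iter().map(|m| m.name).collect())
    }

    fn main() -> anyhow::Result<()> {
        // Made-up payload in the shape the code above expects.
        let raw = r#"{"models":[{"name":"llama3:latest"},{"name":"mistral:latest"}]}"#;
        assert_eq!(names(raw)?, vec!["llama3:latest", "mistral:latest"]);
        Ok(())
    }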