Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion benchmarks/haystack/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,10 @@ async fn main() -> Result<(), SpnlError> {
return Ok(());
}

match execute(&query, &ExecuteOptions { prepare: None }).await? {
match execute(&query, &ExecuteOptions { prepare: None })
.await?
.result
{
Query::Message(User(ss)) => {
// oof, be gracious here. sometimes the model wraps the
// requested json array with markdown even though we asked
Expand Down
15 changes: 9 additions & 6 deletions cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,15 @@ async fn main() -> Result<(), SpnlError> {
ptree::write_tree(&query, ::std::io::stderr())?;
}

let res = execute(&query, &rp).await.map(|res| {
if !res.to_string().is_empty() {
println!("{res}");
}
Ok(())
})?;
let res = execute(&query, &rp)
.await
.map(|res| res.result)
.map(|res| {
if !res.to_string().is_empty() {
println!("{res}");
}
Ok(())
})?;

if let Some(time) = time {
eprintln!("{}", time.elapsed().as_millis());
Expand Down
6 changes: 5 additions & 1 deletion spnl/src/augment/index/raptor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use crate::{
AugmentOptions,
embed::{EmbedData, embed},
},
execute::TimestampedQuery,
generate::generate,
};

Expand Down Expand Up @@ -153,7 +154,10 @@ async fn cross_index_fragment(
)
.await?
{
Query::Message(User(s)) => s,
TimestampedQuery {
result: Query::Message(User(s)),
..
} => s,
_ => "".into(),
};

Expand Down
78 changes: 54 additions & 24 deletions spnl/src/execute.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,44 @@
use crate::{Generate, Message::*, Query};
use ::std::time::Instant;
use indicatif::MultiProgress;

pub struct ExecuteOptions {
/// Prepare query?
pub prepare: Option<bool>,
}

pub struct TimestampedQuery {
pub finish_time: Instant,
pub result: Query,
}

impl From<&Query> for TimestampedQuery {
fn from(result: &Query) -> Self {
Self {
finish_time: Instant::now(),
result: result.clone(),
}
}
}

impl From<Query> for TimestampedQuery {
fn from(result: Query) -> Self {
Self {
finish_time: Instant::now(),
result,
}
}
}

pub type SpnlError = anyhow::Error;
pub type SpnlResult = anyhow::Result<crate::Query>;
pub type SpnlResult = anyhow::Result<TimestampedQuery>;

async fn seq(
async fn run_sequentially(
units: &[Query],
rp: &ExecuteOptions,
mm: Option<&MultiProgress>,
) -> anyhow::Result<Vec<Query>> {
f: fn(Vec<Query>) -> Query,
) -> anyhow::Result<TimestampedQuery> {
let mym = MultiProgress::new();
let m = if let Some(m) = mm { m } else { &mym };

Expand All @@ -22,32 +47,37 @@ async fn seq(
evaluated.push(run_subtree(u, rp, Some(m)).await?);
}

Ok(evaluated)
}

async fn par(units: &[Query], rp: &ExecuteOptions) -> SpnlResult {
let m = MultiProgress::new();
let evaluated =
futures::future::try_join_all(units.iter().map(|u| run_subtree(u, rp, Some(&m)))).await?;

if evaluated.len() == 1 {
// the unwrap() is safe here, due to the len() == 1 guard
Ok(evaluated.into_iter().next().unwrap())
} else {
Ok(Query::Par(evaluated))
Ok(f(evaluated.into_iter().map(|q| q.result).collect()).into())
}
}

async fn plus(units: &[Query], rp: &ExecuteOptions) -> SpnlResult {
async fn run_in_parallel(
units: &[Query],
rp: &ExecuteOptions,
f: fn(Vec<Query>) -> Query,
) -> SpnlResult {
let m = MultiProgress::new();
let evaluated =
let mut evaluated =
futures::future::try_join_all(units.iter().map(|u| run_subtree(u, rp, Some(&m)))).await?;

if evaluated.len() == 1 {
// the unwrap() is safe here, due to the len() == 1 guard
Ok(evaluated.into_iter().next().unwrap())
} else {
Ok(Query::Plus(evaluated))
// Reverse sort the children output so that the first to finish is at the end
evaluated.sort_by_key(|q| ::std::cmp::Reverse(q.finish_time));
let max_finish_time = evaluated
.first()
.map(|q| q.finish_time)
.unwrap_or_else(Instant::now);
Ok(TimestampedQuery {
finish_time: max_finish_time,
result: f(evaluated.into_iter().map(|q| q.result).collect()),
})
}
}

Expand All @@ -61,12 +91,12 @@ async fn run_subtree(query: &Query, rp: &ExecuteOptions, m: Option<&MultiProgres
crate::pull::pull_if_needed(query).await?;

match query {
Query::Message(_) => Ok(query.clone()),
Query::Message(_) => Ok(query.into()),

Query::Par(u) => par(u, rp).await,
Query::Seq(u) => Ok(Query::Seq(seq(u, rp, m).await?)),
Query::Cross(u) => Ok(Query::Cross(seq(u, rp, m).await?)),
Query::Plus(u) => plus(u, rp).await,
Query::Par(u) => run_in_parallel(u, rp, Query::Par).await,
Query::Plus(u) => run_in_parallel(u, rp, Query::Plus).await,
Query::Seq(u) => run_sequentially(u, rp, m, Query::Seq).await,
Query::Cross(u) => run_sequentially(u, rp, m, Query::Cross).await,

Query::Generate(Generate {
model,
Expand All @@ -76,7 +106,7 @@ async fn run_subtree(query: &Query, rp: &ExecuteOptions, m: Option<&MultiProgres
}) => {
crate::generate::generate(
model.as_str(),
&run_subtree(input, rp, m).await?,
&run_subtree(input, rp, m).await?.result,
max_tokens,
temperature,
m,
Expand All @@ -88,7 +118,7 @@ async fn run_subtree(query: &Query, rp: &ExecuteOptions, m: Option<&MultiProgres
#[cfg(feature = "cli_support")]
Query::Print(m) => {
println!("{m}");
Ok(Query::Message(User("".into())))
Ok(Query::Message(User("".into())).into())
}
#[cfg(feature = "cli_support")]
Query::Ask(message) => {
Expand All @@ -106,7 +136,7 @@ async fn run_subtree(query: &Query, rp: &ExecuteOptions, m: Option<&MultiProgres
Err(err) => panic!("{}", err), // TODO this only works in a CLI
};
rl.append_history("history.txt").unwrap();
Ok(Query::Message(User(prompt)))
Ok(Query::Message(User(prompt)).into())
}

// TODO: should not happen; we need to improve the typing of runnable queries
Expand All @@ -123,7 +153,7 @@ mod tests {
#[tokio::test]
async fn it_works() -> Result<(), SpnlError> {
let result = execute(&"hello".into(), &ExecuteOptions { prepare: None }).await?;
assert_eq!(result, Query::Message(User("hello".to_string())));
assert_eq!(result.result, Query::Message(User("hello".to_string())));
Ok(())
}
}
2 changes: 1 addition & 1 deletion spnl/src/generate/backend/openai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ pub async fn generate(
stdout.write_all(b"\n").await?;
}

Ok(Query::Message(Assistant(response_string)))
Ok(Query::Message(Assistant(response_string)).into())
}

pub fn messagify(input: &Query) -> Vec<ChatCompletionRequestMessage> {
Expand Down
2 changes: 1 addition & 1 deletion spnl/src/generate/backend/spnl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,5 +62,5 @@ pub async fn generate(
stdout.write_all(b"\n").await?;
}

Ok(Query::Message(Assistant(response_string)))
Ok(Query::Message(Assistant(response_string)).into())
}
2 changes: 1 addition & 1 deletion spnl/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ pub async fn execute(q: String) -> Result<ChatResponse, PyErr> {
));

res.map(|res| ChatResponse {
data: res.to_string(),
data: res.result.to_string(),
model_id: None,
usage: None,
})
Expand Down