Skip to content

Commit edf2fb9

Browse files
committed
feat: for Par and Plus, sort results by finish time
BREAKING CHANGE: this changes the return type of `execute::execute()`.

Signed-off-by: Nick Mitchell <[email protected]>
1 parent 65eb5a1 commit edf2fb9

File tree

7 files changed: +75 / -35 lines changed

benchmarks/haystack/src/main.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,10 @@ async fn main() -> Result<(), SpnlError> {
178178
return Ok(());
179179
}
180180

181-
match execute(&query, &ExecuteOptions { prepare: None }).await? {
181+
match execute(&query, &ExecuteOptions { prepare: None })
182+
.await?
183+
.result
184+
{
182185
Query::Message(User(ss)) => {
183186
// oof, be gracious here. sometimes the model wraps the
184187
// requested json array with markdown even though we asked

cli/src/main.rs

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -77,12 +77,15 @@ async fn main() -> Result<(), SpnlError> {
7777
ptree::write_tree(&query, ::std::io::stderr())?;
7878
}
7979

80-
let res = execute(&query, &rp).await.map(|res| {
81-
if !res.to_string().is_empty() {
82-
println!("{res}");
83-
}
84-
Ok(())
85-
})?;
80+
let res = execute(&query, &rp)
81+
.await
82+
.map(|res| res.result)
83+
.map(|res| {
84+
if !res.to_string().is_empty() {
85+
println!("{res}");
86+
}
87+
Ok(())
88+
})?;
8689

8790
if let Some(time) = time {
8891
eprintln!("{}", time.elapsed().as_millis());

spnl/src/augment/index/raptor.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ use crate::{
1111
AugmentOptions,
1212
embed::{EmbedData, embed},
1313
},
14+
execute::TimestampedQuery,
1415
generate::generate,
1516
};
1617

@@ -153,7 +154,10 @@ async fn cross_index_fragment(
153154
)
154155
.await?
155156
{
156-
Query::Message(User(s)) => s,
157+
TimestampedQuery {
158+
result: Query::Message(User(s)),
159+
..
160+
} => s,
157161
_ => "".into(),
158162
};
159163

spnl/src/execute.rs

Lines changed: 54 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,44 @@
11
use crate::{Generate, Message::*, Query};
2+
use ::std::time::Instant;
23
use indicatif::MultiProgress;
34

45
pub struct ExecuteOptions {
56
/// Prepare query?
67
pub prepare: Option<bool>,
78
}
89

10+
pub struct TimestampedQuery {
11+
pub finish_time: Instant,
12+
pub result: Query,
13+
}
14+
15+
impl From<&Query> for TimestampedQuery {
16+
fn from(result: &Query) -> Self {
17+
Self {
18+
finish_time: Instant::now(),
19+
result: result.clone(),
20+
}
21+
}
22+
}
23+
24+
impl From<Query> for TimestampedQuery {
25+
fn from(result: Query) -> Self {
26+
Self {
27+
finish_time: Instant::now(),
28+
result,
29+
}
30+
}
31+
}
32+
933
pub type SpnlError = anyhow::Error;
10-
pub type SpnlResult = anyhow::Result<crate::Query>;
34+
pub type SpnlResult = anyhow::Result<TimestampedQuery>;
1135

12-
async fn seq(
36+
async fn run_sequentially(
1337
units: &[Query],
1438
rp: &ExecuteOptions,
1539
mm: Option<&MultiProgress>,
16-
) -> anyhow::Result<Vec<Query>> {
40+
f: fn(Vec<Query>) -> Query,
41+
) -> anyhow::Result<TimestampedQuery> {
1742
let mym = MultiProgress::new();
1843
let m = if let Some(m) = mm { m } else { &mym };
1944

@@ -22,32 +47,37 @@ async fn seq(
2247
evaluated.push(run_subtree(u, rp, Some(m)).await?);
2348
}
2449

25-
Ok(evaluated)
26-
}
27-
28-
async fn par(units: &[Query], rp: &ExecuteOptions) -> SpnlResult {
29-
let m = MultiProgress::new();
30-
let evaluated =
31-
futures::future::try_join_all(units.iter().map(|u| run_subtree(u, rp, Some(&m)))).await?;
32-
3350
if evaluated.len() == 1 {
3451
// the unwrap() is safe here, due to the len() == 1 guard
3552
Ok(evaluated.into_iter().next().unwrap())
3653
} else {
37-
Ok(Query::Par(evaluated))
54+
Ok(f(evaluated.into_iter().map(|q| q.result).collect()).into())
3855
}
3956
}
4057

41-
async fn plus(units: &[Query], rp: &ExecuteOptions) -> SpnlResult {
58+
async fn run_in_parallel(
59+
units: &[Query],
60+
rp: &ExecuteOptions,
61+
f: fn(Vec<Query>) -> Query,
62+
) -> SpnlResult {
4263
let m = MultiProgress::new();
43-
let evaluated =
64+
let mut evaluated =
4465
futures::future::try_join_all(units.iter().map(|u| run_subtree(u, rp, Some(&m)))).await?;
4566

4667
if evaluated.len() == 1 {
4768
// the unwrap() is safe here, due to the len() == 1 guard
4869
Ok(evaluated.into_iter().next().unwrap())
4970
} else {
50-
Ok(Query::Plus(evaluated))
71+
// Reverse sort the children output so that the first to finish is at the end
72+
evaluated.sort_by_key(|q| ::std::cmp::Reverse(q.finish_time));
73+
let max_finish_time = evaluated
74+
.first()
75+
.map(|q| q.finish_time)
76+
.unwrap_or_else(Instant::now);
77+
Ok(TimestampedQuery {
78+
finish_time: max_finish_time,
79+
result: f(evaluated.into_iter().map(|q| q.result).collect()),
80+
})
5181
}
5282
}
5383

@@ -61,12 +91,12 @@ async fn run_subtree(query: &Query, rp: &ExecuteOptions, m: Option<&MultiProgres
6191
crate::pull::pull_if_needed(query).await?;
6292

6393
match query {
64-
Query::Message(_) => Ok(query.clone()),
94+
Query::Message(_) => Ok(query.into()),
6595

66-
Query::Par(u) => par(u, rp).await,
67-
Query::Seq(u) => Ok(Query::Seq(seq(u, rp, m).await?)),
68-
Query::Cross(u) => Ok(Query::Cross(seq(u, rp, m).await?)),
69-
Query::Plus(u) => plus(u, rp).await,
96+
Query::Par(u) => run_in_parallel(u, rp, Query::Par).await,
97+
Query::Plus(u) => run_in_parallel(u, rp, Query::Plus).await,
98+
Query::Seq(u) => run_sequentially(u, rp, m, Query::Seq).await,
99+
Query::Cross(u) => run_sequentially(u, rp, m, Query::Cross).await,
70100

71101
Query::Generate(Generate {
72102
model,
@@ -76,7 +106,7 @@ async fn run_subtree(query: &Query, rp: &ExecuteOptions, m: Option<&MultiProgres
76106
}) => {
77107
crate::generate::generate(
78108
model.as_str(),
79-
&run_subtree(input, rp, m).await?,
109+
&run_subtree(input, rp, m).await?.result,
80110
max_tokens,
81111
temperature,
82112
m,
@@ -88,7 +118,7 @@ async fn run_subtree(query: &Query, rp: &ExecuteOptions, m: Option<&MultiProgres
88118
#[cfg(feature = "cli_support")]
89119
Query::Print(m) => {
90120
println!("{m}");
91-
Ok(Query::Message(User("".into())))
121+
Ok(Query::Message(User("".into())).into())
92122
}
93123
#[cfg(feature = "cli_support")]
94124
Query::Ask(message) => {
@@ -106,7 +136,7 @@ async fn run_subtree(query: &Query, rp: &ExecuteOptions, m: Option<&MultiProgres
106136
Err(err) => panic!("{}", err), // TODO this only works in a CLI
107137
};
108138
rl.append_history("history.txt").unwrap();
109-
Ok(Query::Message(User(prompt)))
139+
Ok(Query::Message(User(prompt)).into())
110140
}
111141

112142
// TODO: should not happen; we need to improve the typing of runnable queries
@@ -123,7 +153,7 @@ mod tests {
123153
#[tokio::test]
124154
async fn it_works() -> Result<(), SpnlError> {
125155
let result = execute(&"hello".into(), &ExecuteOptions { prepare: None }).await?;
126-
assert_eq!(result, Query::Message(User("hello".to_string())));
156+
assert_eq!(result.result, Query::Message(User("hello".to_string())));
127157
Ok(())
128158
}
129159
}

spnl/src/generate/backend/openai.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ pub async fn generate(
128128
stdout.write_all(b"\n").await?;
129129
}
130130

131-
Ok(Query::Message(Assistant(response_string)))
131+
Ok(Query::Message(Assistant(response_string)).into())
132132
}
133133

134134
pub fn messagify(input: &Query) -> Vec<ChatCompletionRequestMessage> {

spnl/src/generate/backend/spnl.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,5 +62,5 @@ pub async fn generate(
6262
stdout.write_all(b"\n").await?;
6363
}
6464

65-
Ok(Query::Message(Assistant(response_string)))
65+
Ok(Query::Message(Assistant(response_string)).into())
6666
}

spnl/src/python.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ pub async fn execute(q: String) -> Result<ChatResponse, PyErr> {
5353
));
5454

5555
res.map(|res| ChatResponse {
56-
data: res.to_string(),
56+
data: res.result.to_string(),
5757
model_id: None,
5858
usage: None,
5959
})

0 commit comments

Comments
 (0)