11use crate :: { Generate , Message :: * , Query } ;
2+ use :: std:: time:: Instant ;
23use indicatif:: MultiProgress ;
34
45pub struct ExecuteOptions {
56 /// Prepare query?
67 pub prepare : Option < bool > ,
78}
89
10+ pub struct TimestampedQuery {
11+ pub finish_time : Instant ,
12+ pub result : Query ,
13+ }
14+
15+ impl From < & Query > for TimestampedQuery {
16+ fn from ( result : & Query ) -> Self {
17+ Self {
18+ finish_time : Instant :: now ( ) ,
19+ result : result. clone ( ) ,
20+ }
21+ }
22+ }
23+
24+ impl From < Query > for TimestampedQuery {
25+ fn from ( result : Query ) -> Self {
26+ Self {
27+ finish_time : Instant :: now ( ) ,
28+ result,
29+ }
30+ }
31+ }
32+
933pub type SpnlError = anyhow:: Error ;
10- pub type SpnlResult = anyhow:: Result < crate :: Query > ;
34+ pub type SpnlResult = anyhow:: Result < TimestampedQuery > ;
1135
12- async fn seq (
36+ async fn run_sequentially (
1337 units : & [ Query ] ,
1438 rp : & ExecuteOptions ,
1539 mm : Option < & MultiProgress > ,
16- ) -> anyhow:: Result < Vec < Query > > {
40+ f : fn ( Vec < Query > ) -> Query ,
41+ ) -> anyhow:: Result < TimestampedQuery > {
1742 let mym = MultiProgress :: new ( ) ;
1843 let m = if let Some ( m) = mm { m } else { & mym } ;
1944
@@ -22,32 +47,37 @@ async fn seq(
2247 evaluated. push ( run_subtree ( u, rp, Some ( m) ) . await ?) ;
2348 }
2449
25- Ok ( evaluated)
26- }
27-
28- async fn par ( units : & [ Query ] , rp : & ExecuteOptions ) -> SpnlResult {
29- let m = MultiProgress :: new ( ) ;
30- let evaluated =
31- futures:: future:: try_join_all ( units. iter ( ) . map ( |u| run_subtree ( u, rp, Some ( & m) ) ) ) . await ?;
32-
3350 if evaluated. len ( ) == 1 {
3451 // the unwrap() is safe here, due to the len() == 1 guard
3552 Ok ( evaluated. into_iter ( ) . next ( ) . unwrap ( ) )
3653 } else {
37- Ok ( Query :: Par ( evaluated) )
54+ Ok ( f ( evaluated. into_iter ( ) . map ( |q| q . result ) . collect ( ) ) . into ( ) )
3855 }
3956}
4057
41- async fn plus ( units : & [ Query ] , rp : & ExecuteOptions ) -> SpnlResult {
58+ async fn run_in_parallel (
59+ units : & [ Query ] ,
60+ rp : & ExecuteOptions ,
61+ f : fn ( Vec < Query > ) -> Query ,
62+ ) -> SpnlResult {
4263 let m = MultiProgress :: new ( ) ;
43- let evaluated =
64+ let mut evaluated =
4465 futures:: future:: try_join_all ( units. iter ( ) . map ( |u| run_subtree ( u, rp, Some ( & m) ) ) ) . await ?;
4566
4667 if evaluated. len ( ) == 1 {
4768 // the unwrap() is safe here, due to the len() == 1 guard
4869 Ok ( evaluated. into_iter ( ) . next ( ) . unwrap ( ) )
4970 } else {
50- Ok ( Query :: Plus ( evaluated) )
71+ // Reverse sort the children output so that the first to finish is at the end
72+ evaluated. sort_by_key ( |q| :: std:: cmp:: Reverse ( q. finish_time ) ) ;
73+ let max_finish_time = evaluated
74+ . first ( )
75+ . map ( |q| q. finish_time )
76+ . unwrap_or_else ( Instant :: now) ;
77+ Ok ( TimestampedQuery {
78+ finish_time : max_finish_time,
79+ result : f ( evaluated. into_iter ( ) . map ( |q| q. result ) . collect ( ) ) ,
80+ } )
5181 }
5282}
5383
@@ -61,12 +91,12 @@ async fn run_subtree(query: &Query, rp: &ExecuteOptions, m: Option<&MultiProgres
6191 crate :: pull:: pull_if_needed ( query) . await ?;
6292
6393 match query {
64- Query :: Message ( _) => Ok ( query. clone ( ) ) ,
94+ Query :: Message ( _) => Ok ( query. into ( ) ) ,
6595
66- Query :: Par ( u) => par ( u, rp) . await ,
67- Query :: Seq ( u) => Ok ( Query :: Seq ( seq ( u, rp, m ) . await ? ) ) ,
68- Query :: Cross ( u) => Ok ( Query :: Cross ( seq ( u, rp, m) . await ? ) ) ,
69- Query :: Plus ( u) => plus ( u, rp) . await ,
96+ Query :: Par ( u) => run_in_parallel ( u, rp, Query :: Par ) . await ,
97+ Query :: Plus ( u) => run_in_parallel ( u, rp, Query :: Plus ) . await ,
98+ Query :: Seq ( u) => run_sequentially ( u, rp, m, Query :: Seq ) . await ,
99+ Query :: Cross ( u) => run_sequentially ( u, rp, m , Query :: Cross ) . await ,
70100
71101 Query :: Generate ( Generate {
72102 model,
@@ -76,7 +106,7 @@ async fn run_subtree(query: &Query, rp: &ExecuteOptions, m: Option<&MultiProgres
76106 } ) => {
77107 crate :: generate:: generate (
78108 model. as_str ( ) ,
79- & run_subtree ( input, rp, m) . await ?,
109+ & run_subtree ( input, rp, m) . await ?. result ,
80110 max_tokens,
81111 temperature,
82112 m,
@@ -88,7 +118,7 @@ async fn run_subtree(query: &Query, rp: &ExecuteOptions, m: Option<&MultiProgres
88118 #[ cfg( feature = "cli_support" ) ]
89119 Query :: Print ( m) => {
90120 println ! ( "{m}" ) ;
91- Ok ( Query :: Message ( User ( "" . into ( ) ) ) )
121+ Ok ( Query :: Message ( User ( "" . into ( ) ) ) . into ( ) )
92122 }
93123 #[ cfg( feature = "cli_support" ) ]
94124 Query :: Ask ( message) => {
@@ -106,7 +136,7 @@ async fn run_subtree(query: &Query, rp: &ExecuteOptions, m: Option<&MultiProgres
106136 Err ( err) => panic ! ( "{}" , err) , // TODO this only works in a CLI
107137 } ;
108138 rl. append_history ( "history.txt" ) . unwrap ( ) ;
109- Ok ( Query :: Message ( User ( prompt) ) )
139+ Ok ( Query :: Message ( User ( prompt) ) . into ( ) )
110140 }
111141
112142 // TODO: should not happen; we need to improve the typing of runnable queries
@@ -123,7 +153,7 @@ mod tests {
123153 #[ tokio:: test]
124154 async fn it_works ( ) -> Result < ( ) , SpnlError > {
125155 let result = execute ( & "hello" . into ( ) , & ExecuteOptions { prepare : None } ) . await ?;
126- assert_eq ! ( result, Query :: Message ( User ( "hello" . to_string( ) ) ) ) ;
156+ assert_eq ! ( result. result , Query :: Message ( User ( "hello" . to_string( ) ) ) ) ;
127157 Ok ( ( ) )
128158 }
129159}
0 commit comments