Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion spnl/src/ir/bulk.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
use super::Generate;
use super::{Generate, GenerateMetadata};

#[derive(Debug, Clone, PartialEq, serde::Deserialize, serde::Serialize)]
pub enum Bulk {
Repeat(Repeat),

Map(Map),
}

/// Bulk operation: generate `n` outputs for the given `generate` specification
Expand All @@ -14,3 +16,14 @@ pub struct Repeat {
/// The specification of what to generate
pub generate: Generate,
}

/// Bulk operation: generate `n` outputs using the given `metadata`
/// specification, one output per given input
#[derive(Debug, Clone, PartialEq, serde::Deserialize, serde::Serialize)]
pub struct Map {
/// The metadata governing the content generation
pub metadata: GenerateMetadata,

/// Generate one output for each input in this list
pub inputs: Vec<String>,
}
6 changes: 5 additions & 1 deletion spnl/src/ir/pretty_print.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ impl ptree::TreeItem for Query {
)),
Query::Monad(_) => style.paint("\x1b[2mMonad\x1b[0m".to_string()),
Query::Bulk(Bulk::Repeat(Repeat { n, .. })) => style.paint(format!("Repeat {n}")),
Query::Bulk(Bulk::Map(Map { inputs, .. })) =>
style.paint(format!("Map {}", inputs.len())),
Query::Ask(m) => style.paint(format!("Ask {m}")),
Query::Print(m) => style.paint(format!("Print {}", truncate(m, 700))),
#[cfg(feature = "rag")]
Expand All @@ -83,7 +85,9 @@ impl ptree::TreeItem for Query {
}
fn children(&self) -> ::std::borrow::Cow<'_, [Self::Child]> {
::std::borrow::Cow::from(match self {
Query::Ask(_) | Query::Message(_) | Query::Print(_) => vec![],
Query::Ask(_) | Query::Message(_) | Query::Print(_) | Query::Bulk(Bulk::Map(_)) => {
vec![]
}
Query::Par(v) | Query::Seq(v) | Query::Plus(v) | Query::Cross(v) => v.clone(),
Query::Monad(q) => vec![*q.clone()],
Query::Bulk(Bulk::Repeat(Repeat { generate, .. })) => vec![*generate.input.clone()],
Expand Down
58 changes: 56 additions & 2 deletions spnl/src/optimizer/hlo/simplify.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::ir::{Bulk, Generate, Query, Repeat};
use crate::ir::{Bulk, Generate, Map, Message, Query, Repeat};

pub fn simplify(query: &Query) -> Query {
simplify_iter(query).into()
Expand All @@ -8,11 +8,26 @@ pub fn simplify(query: &Query) -> Query {
/// e.g. Plus-of-Plus or Cross with a tail Cross.
fn simplify_iter(query: &Query) -> Vec<Query> {
match query {
// Unroll repeats
// Unroll repeats. TODO: move this into the query executor backend, to expose server-side support for Repeat
Query::Bulk(Bulk::Repeat(Repeat { n, generate })) => {
::std::iter::repeat_n(Query::Generate(generate.clone()), *n).collect::<Vec<_>>()
}

// Unroll batch. TODO: move this into the query executor backend, to expose server-side support for Map
Query::Bulk(Bulk::Map(Map { metadata, inputs })) => {
vec![Query::Plus(
inputs
.iter()
.map(|input| {
Query::Generate(Generate {
metadata: metadata.clone(),
input: Query::Message(Message::User(input.clone())).into(),
})
})
.collect(),
)]
}

Query::Seq(v) => match &v[..] {
// One-entry sequence
[q] => simplify_iter(q),
Expand Down Expand Up @@ -127,4 +142,43 @@ mod tests {
Seq(::std::iter::repeat_n(Query::Generate(g), n).collect())
);
}

#[test]
fn simplify_batch_expansion() {
let inputs = ["a".into(), "b".into(), "c".into()];
let metadata = GenerateMetadataBuilder::default()
.model("does not matter for this test")
.build()
.unwrap();
let q = Bulk(Bulk::Map(Map {
metadata: metadata.clone(),
inputs: inputs.to_vec(),
}));
assert_eq!(
simplify(&q),
Plus(vec![
Generate(
GenerateBuilder::default()
.metadata(metadata.clone())
.input(Message(User(inputs[0].clone())).into())
.build()
.unwrap()
),
Generate(
GenerateBuilder::default()
.metadata(metadata.clone())
.input(Message(User(inputs[1].clone())).into())
.build()
.unwrap()
),
Generate(
GenerateBuilder::default()
.metadata(metadata.clone())
.input(Message(User(inputs[2].clone())).into())
.build()
.unwrap()
),
])
);
}
}
Loading