diff --git a/engine/baml-lib/jinja-runtime/src/output_format/mod.rs b/engine/baml-lib/jinja-runtime/src/output_format/mod.rs index a4a96c18f9..0b03d84e27 100644 --- a/engine/baml-lib/jinja-runtime/src/output_format/mod.rs +++ b/engine/baml-lib/jinja-runtime/src/output_format/mod.rs @@ -4,6 +4,7 @@ use std::str::FromStr; use minijinja::{value::Kwargs, ErrorKind, Value}; use strum::VariantNames; +use types::HoistClasses; use crate::{types::RenderOptions, RenderContext}; @@ -132,6 +133,28 @@ impl minijinja::value::Object for OutputFormat { None }; + let hoist_classes = if kwargs.has("hoist_classes") { + // true | false + match kwargs.get::("hoist_classes") { + Ok(true) => Some(HoistClasses::All), + Ok(false) => Some(HoistClasses::Auto), + // auto + Err(_) => match kwargs.get::("hoist_classes") { + Ok(s) if s == "auto" => Some(HoistClasses::Auto), + // subset + _ => match kwargs.get::>("hoist_classes") { + Ok(classes) => Some(HoistClasses::Subset(classes)), + Err(e) => return Err(Error::new( + ErrorKind::SyntaxError, + format!("Invalid value for hoist_classes (expected one of bool | \"auto\" | string[]): {e}") + )) + } + } + } + } else { + None + }; + let map_style = if kwargs.has("map_style") { match kwargs .get::("map_style") @@ -177,6 +200,7 @@ impl minijinja::value::Object for OutputFormat { always_hoist_enums, map_style, hoisted_class_prefix, + hoist_classes, ))?; match content { diff --git a/engine/baml-lib/jinja-runtime/src/output_format/types.rs b/engine/baml-lib/jinja-runtime/src/output_format/types.rs index f8b0b8d60b..140947dbf5 100644 --- a/engine/baml-lib/jinja-runtime/src/output_format/types.rs +++ b/engine/baml-lib/jinja-runtime/src/output_format/types.rs @@ -1,4 +1,4 @@ -use std::sync::Arc; +use std::{ops::Deref, sync::Arc}; use anyhow::Result; use baml_types::{Constraint, FieldType, StreamingBehavior, TypeValue}; @@ -157,11 +157,27 @@ pub(crate) enum MapStyle { ObjectLiteral, } +/// Hoist classes setting. +/// +/// Recursive classes are always hoisted. +pub(crate) enum HoistClasses { + /// Hoist all classes. + All, + /// Hoist only the specified subset. + Subset(Vec), + /// Default behavior, hoist only recursive classes. + Auto, +} + +/// Maximum number of variants in the enum that we render without hoisting. +const INLINE_RENDER_ENUM_MAX_VALUES: usize = 6; + pub struct RenderOptions { prefix: RenderSetting, pub(crate) or_splitter: String, enum_value_prefix: RenderSetting, hoisted_class_prefix: RenderSetting, + hoist_classes: HoistClasses, always_hoist_enums: RenderSetting, map_style: MapStyle, } @@ -173,6 +189,7 @@ impl Default for RenderOptions { or_splitter: Self::DEFAULT_OR_SPLITTER.to_string(), enum_value_prefix: RenderSetting::Auto, hoisted_class_prefix: RenderSetting::Auto, + hoist_classes: HoistClasses::Auto, always_hoist_enums: RenderSetting::Auto, map_style: MapStyle::TypeParameters, } @@ -183,6 +200,12 @@ impl RenderOptions { const DEFAULT_OR_SPLITTER: &'static str = " or "; const DEFAULT_TYPE_PREFIX_IN_RENDER_MESSAGE: &'static str = "schema"; + /// Option> Basically means that we can have a paremeter which + /// 1. the user can completely omit: None + /// 2. the user can set to null: Some(None) + /// + /// This might be a little annoying, maybe we can change the code in mod.rs + /// to flatten the types Option> => Option pub(crate) fn new( prefix: Option>, or_splitter: Option, @@ -190,6 +213,7 @@ impl RenderOptions { always_hoist_enums: Option, map_style: Option, hoisted_class_prefix: Option>, + hoist_classes: Option, ) -> Self { Self { prefix: prefix.map_or(RenderSetting::Auto, |p| { @@ -205,6 +229,7 @@ impl RenderOptions { hoisted_class_prefix: hoisted_class_prefix.map_or(RenderSetting::Auto, |p| { p.map_or(RenderSetting::Never, RenderSetting::Always) }), + hoist_classes: hoist_classes.unwrap_or(HoistClasses::Auto), } } @@ -215,6 +240,14 @@ impl RenderOptions { ..Default::default() } } + + // TODO: Might need a builder pattern for this as well. + pub(crate) fn hoist_classes(hoist_classes: HoistClasses) -> Self { + Self { + hoist_classes, + ..Default::default() + } + } } struct Attribute { @@ -311,8 +344,9 @@ fn indefinite_article_a_or_an(word: &str) -> &str { } } -struct RenderState { +struct RenderCtx { hoisted_enums: IndexSet, + hoisted_classes: IndexSet, } impl OutputFormatContent { @@ -335,10 +369,11 @@ impl OutputFormatContent { } } - fn prefix(&self, options: &RenderOptions) -> Option { + fn prefix(&self, options: &RenderOptions, render_state: &RenderCtx) -> Option { fn auto_prefix( ft: &FieldType, options: &RenderOptions, + render_state: &RenderCtx, output_format_content: &OutputFormatContent, ) -> Option { match ft { @@ -356,7 +391,7 @@ impl OutputFormatContent { }; // Line break if schema else just inline the name. - let end = if output_format_content.recursive_classes.contains(cls) { + let end = if render_state.hoisted_classes.contains(cls) { " " } else { "\n" @@ -382,7 +417,7 @@ impl OutputFormatContent { FieldType::Map(_, _) => Some(String::from("Answer in JSON using this schema:\n")), FieldType::Tuple(_) => None, FieldType::WithMetadata { base, .. } => { - auto_prefix(base, options, output_format_content) + auto_prefix(base, options, render_state, output_format_content) } FieldType::Arrow(_) => None, // TODO: Error? Arrow shouldn't appear here. } @@ -391,7 +426,7 @@ impl OutputFormatContent { match &options.prefix { RenderSetting::Always(prefix) => Some(prefix.to_owned()), RenderSetting::Never => None, - RenderSetting::Auto => auto_prefix(&self.target, options, self), + RenderSetting::Auto => auto_prefix(&self.target, options, render_state, self), } } @@ -411,42 +446,146 @@ impl OutputFormatContent { .to_string(options) } - /// Recursive classes are rendered using their name instead of schema. + /// Renders either the schema or the name of a type. + /// + /// Prompt rendering is somewhat confusing because of hoisted types, so + /// let's give a little explanation. + /// + /// The [`Self::inner_type_render`] function renders schemas only, say we + /// have these classes: + /// + /// ```baml + /// class Example { + /// a string + /// b string + /// c Nested + /// } + /// + /// class Nested { + /// n int + /// m int + /// } + /// ``` + /// + /// then [`Self::inner_type_render`] will return this string: + /// + /// ```ts + /// { + /// a: string, + /// b: string, + /// c: { + /// n: int, + /// m: int, + /// }, + /// } + /// ``` + /// + /// Basically it renders all schemas recursively into one single schema. + /// That becomes a problem when you define recursive classes, because + /// there's no way to render them "inline" as above. Here's an example: + /// + /// ```baml + /// class Node { + /// data int + /// next Node? + /// } + /// ``` /// - /// The schema must be hoisted and named, otherwise there's no way to refer - /// to a recursive class. + /// If we wanted to render this as above we'd stack overflow: /// - /// This function stops the recursion if it finds a recursive class and - /// simply returns its name. It acts as wrapper for - /// [`Self::inner_type_render`] and must be called wherever we could - /// encounter a recursive type when rendering. + /// ```ts + /// { + /// data: int, + /// next: { + /// data: int, + /// next: { + /// data: int, + /// next: <<< STACK OVERFLOW >>> + /// }, + /// }, + /// } + /// ``` /// - /// Do not call this function as an entry point because if the target type - /// is recursive itself you own't get any rendering! You'll just get the - /// name of the type. Instead call [`Self::inner_type_render`] as an entry - /// point and that will render the schema considering recursive fields. - fn render_possibly_recursive_type( + /// So the solution is to hoist the class and use its name instead. This is + /// how the complete prompt would look like: + /// + /// ```text + /// Node { + /// data: int, + /// next: Node, + /// } + /// + /// Answer in JSON using this schema: Node + /// ``` + /// + /// Obviously, we want to be able to embed recursive classes in other + /// non-recursive classes, something like this: + /// + /// ```baml + /// class Example { + /// a string + /// b string + /// c Nested + /// d LinkedList + /// } + /// ``` + /// + /// Which requires this prompt: + /// + /// ```text + /// Node { + /// data: int, + /// next: Node, + /// } + /// + /// Answer in JSON using this schema: + /// { + /// a: string, + /// b: string, + /// c: { + /// n: int, + /// m: int, + /// }, + /// d: Node, + /// } + /// ``` + /// + /// We need to render both schemas and names, which makes deciding when to + /// "stop" recursion complicated. And that's what this function does, it + /// saves us from writing if statements in every case where we might + /// encounter a nested recursive type in [`Self::inner_type_render`]. + /// + /// Users can also decide to hoist non-recursive classes for other reasons + /// such as saving tokens or improve the adherence to the schema of the + /// model response. + /// + /// Rule of thumb is, call [`Self::inner_type_render`] as an entry point + /// and inside [`Self::inner_type_render`] call this function for each + /// nested/inner type and let it handle the rest of recursion. + fn render_possibly_hoisted_type( &self, options: &RenderOptions, field_type: &FieldType, - render_state: &mut RenderState, - group_hoisted_literals: bool, + render_ctx: &RenderCtx, ) -> Result { match field_type { - FieldType::Class(nested_class) if self.recursive_classes.contains(nested_class) => { + FieldType::Class(nested_class) if render_ctx.hoisted_classes.contains(nested_class) => { Ok(nested_class.to_owned()) } - _ => self.inner_type_render(options, field_type, render_state, group_hoisted_literals), + _ => self.inner_type_render(options, field_type, render_ctx), } } + /// This function is the entry point for recursive schema rendering. + /// + /// Read the documentation of [`Self::render_possibly_hoisted_type`] for + /// more details. fn inner_type_render( &self, options: &RenderOptions, field: &FieldType, - render_state: &mut RenderState, - group_hoisted_literals: bool, + render_ctx: &RenderCtx, ) -> Result { Ok(match field { FieldType::Primitive(t) => match t { @@ -463,12 +602,9 @@ impl OutputFormatContent { } }, FieldType::Literal(v) => v.to_string(), - FieldType::WithMetadata { base, .. } => self.render_possibly_recursive_type( - options, - base, - render_state, - group_hoisted_literals, - )?, + FieldType::WithMetadata { base, .. } => { + self.render_possibly_hoisted_type(options, base, render_ctx)? + } FieldType::Enum(e) => { let Some(enm) = self.enums.get(e) else { return Err(minijinja::Error::new( @@ -477,22 +613,14 @@ impl OutputFormatContent { )); }; - if enm.values.len() <= 6 - && enm.values.iter().all(|(_, d)| d.is_none()) - && !group_hoisted_literals - && !matches!(options.always_hoist_enums, RenderSetting::Always(true)) - { - let values = enm - .values + if render_ctx.hoisted_enums.contains(&enm.name.name) { + enm.name.rendered_name().to_string() + } else { + enm.values .iter() .map(|(n, _)| format!("'{}'", n.rendered_name())) .collect::>() - .join(&options.or_splitter); - - values - } else { - render_state.hoisted_enums.insert(enm.name.name.clone()); - enm.name.rendered_name().to_string() + .join(&options.or_splitter) } } FieldType::Class(cls) => { @@ -512,11 +640,8 @@ impl OutputFormatContent { Ok(ClassFieldRender { name: name.rendered_name().to_string(), description: description.clone(), - r#type: self.render_possibly_recursive_type( - options, - field_type, - render_state, - false, + r#type: self.render_possibly_hoisted_type( + options, field_type, render_ctx, )?, }) }) @@ -526,18 +651,19 @@ impl OutputFormatContent { } FieldType::RecursiveTypeAlias(name) => name.to_owned(), FieldType::List(inner) => { - let is_recursive = match inner.as_ref() { - FieldType::Class(nested_class) => self.recursive_classes.contains(nested_class), + let is_hoisted = match inner.as_ref() { + FieldType::Class(nested_class) => { + render_ctx.hoisted_classes.contains(nested_class) + } FieldType::RecursiveTypeAlias(name) => { self.structural_recursive_aliases.contains_key(name) } _ => false, }; - let inner_str = - self.render_possibly_recursive_type(options, inner, render_state, false)?; + let inner_str = self.render_possibly_hoisted_type(options, inner, render_ctx)?; - if !is_recursive + if !is_hoisted && match inner.as_ref() { FieldType::Primitive(_) => false, FieldType::Optional(t) => !t.is_primitive(), @@ -554,12 +680,11 @@ impl OutputFormatContent { } FieldType::Union(items) => items .iter() - .map(|t| self.render_possibly_recursive_type(options, t, render_state, false)) + .map(|t| self.render_possibly_hoisted_type(options, t, render_ctx)) .collect::, minijinja::Error>>()? .join(&options.or_splitter), FieldType::Optional(inner) => { - let inner_str = - self.render_possibly_recursive_type(options, inner, render_state, false)?; + let inner_str = self.render_possibly_hoisted_type(options, inner, render_ctx)?; if inner.is_optional() { inner_str } else { @@ -574,20 +699,8 @@ impl OutputFormatContent { } FieldType::Map(key_type, value_type) => MapRender { style: &options.map_style, - // NOTE: Key can't be recursive because we only support strings - // as keys. - key_type: self.render_possibly_recursive_type( - options, - key_type, - render_state, - false, - )?, - value_type: self.render_possibly_recursive_type( - options, - value_type, - render_state, - false, - )?, + key_type: self.render_possibly_hoisted_type(options, key_type, render_ctx)?, + value_type: self.render_possibly_hoisted_type(options, value_type, render_ctx)?, } .to_string(), FieldType::Arrow(_) => { @@ -600,12 +713,77 @@ impl OutputFormatContent { } pub fn render(&self, options: RenderOptions) -> Result, minijinja::Error> { - let prefix = self.prefix(&options); - - let mut render_state = RenderState { + // Render context. Only contains hoisted types for now. + let mut render_ctx = RenderCtx { hoisted_enums: IndexSet::new(), + // Recursive classes are always hoisted so we start with those as base. + // TODO: Figure out memory gymnastics to avoid this clone. + hoisted_classes: self.recursive_classes.deref().clone(), + }; + + // Precompute hoisted enums. + // + // Original code had the "group_hoisted_literals" logic here but it + // was always false, so not actually used. See this code: + // https://github.com/BoundaryML/baml/blob/ee15d0f379f53a93f2d80b39909c74495b19930b/engine/baml-lib/jinja-runtime/src/output_format/types.rs#L480-L496 + for enm in self.enums.values() { + if enm.values.len() > INLINE_RENDER_ENUM_MAX_VALUES + || enm.values.iter().any(|(_, desc)| desc.is_some()) + || matches!(options.always_hoist_enums, RenderSetting::Always(true)) + // || group_hoisted_literals + { + render_ctx.hoisted_enums.insert(enm.name.name.clone()); + } + } + + // Now figure out what to hoist besides recursive classes. + match &options.hoist_classes { + // Nothing here, default behavior. + HoistClasses::Auto => {} + + // Hoist all classes. + HoistClasses::All => render_ctx + .hoisted_classes + .extend(self.classes.keys().cloned()), + + // Hoist only the specified subset. + HoistClasses::Subset(classes) => { + let mut not_found = IndexSet::new(); + + for cls in classes { + if self.classes.contains_key(cls) { + render_ctx.hoisted_classes.insert(cls.to_owned()); + } else { + not_found.insert(cls.to_owned()); + } + } + + // Error message if class/classes not found. + if !not_found.is_empty() { + let (class_or_classes, it_does_or_they_do) = if not_found.len() == 1 { + ("class", "it does") + } else { + ("classes", "they do") + }; + + return Err(minijinja::Error::new( + minijinja::ErrorKind::BadSerialization, + format!( + "Cannot hoist {class_or_classes} {} because {it_does_or_they_do} not exist", + not_found + .iter() + .map(|cls| format!("\"{cls}\"")) + .collect::>() + .join(", "), + ), + )); + } + } }; + // Schema prefix (Answer in JSON using...) + let prefix = self.prefix(&options, &render_ctx); + let mut message = match &self.target { FieldType::Primitive(TypeValue::String) if prefix.is_none() => None, FieldType::Enum(e) => { @@ -618,13 +796,13 @@ impl OutputFormatContent { Some(self.enum_to_string(enm, &options)) } - _ => Some(self.inner_type_render(&options, &self.target, &mut render_state, false)?), + _ => Some(self.inner_type_render(&options, &self.target, &render_ctx)?), }; // Top level recursive classes will just use their name instead of the // entire schema which should already be hoisted. if let FieldType::Class(class) = &self.target { - if self.recursive_classes.contains(class) { + if render_ctx.hoisted_classes.contains(class) { message = Some(class.to_owned()); } } @@ -632,16 +810,11 @@ impl OutputFormatContent { let mut class_definitions = Vec::new(); let mut type_alias_definitions = Vec::new(); - // Hoist recursive classes. The render_state struct doesn't need to - // contain these classes because we already know that we're gonna hoist - // them beforehand. Recursive cycles are computed after the AST - // validation stage. - for class_name in self.recursive_classes.iter() { + for class_name in &render_ctx.hoisted_classes { let schema = self.inner_type_render( &options, &FieldType::Class(class_name.to_owned()), - &mut render_state, - false, + &render_ctx, )?; class_definitions.push(match &options.hoisted_class_prefix { @@ -653,8 +826,7 @@ impl OutputFormatContent { } for (alias, target) in self.structural_recursive_aliases.iter() { - let recursive_pointer = - self.inner_type_render(&options, target, &mut render_state, false)?; + let recursive_pointer = self.inner_type_render(&options, target, &render_ctx)?; type_alias_definitions.push(match &options.hoisted_class_prefix { RenderSetting::Always(prefix) if !prefix.is_empty() => { @@ -664,9 +836,7 @@ impl OutputFormatContent { }); } - // once render_state.hoisted_enums is used, we shouldn't write to it again, hence why into_iter() over iter(). - // We want a compile-time error if render_state.hoisted_enums is used again. - let enum_definitions = Vec::from_iter(render_state.hoisted_enums.into_iter().map(|e| { + let enum_definitions = Vec::from_iter(render_ctx.hoisted_enums.into_iter().map(|e| { let enm = self.enums.get(&e).expect("Enum not found"); // TODO: Jinja Err self.enum_to_string(enm, &options) })); @@ -796,7 +966,12 @@ mod tests { assert_eq!( rendered, Some(String::from( - "Answer with any of the categories:\nColor\n----\n- Red\n- Green\n- Blue" + "Answer with any of the categories: +Color +---- +- Red +- Green +- Blue" )) ); } @@ -830,7 +1005,13 @@ mod tests { assert_eq!( rendered, Some(String::from( - "Answer in JSON using this schema:\n{\n // The person's name\n name: string,\n // The person's age\n age: int,\n}" + r#"Answer in JSON using this schema: +{ + // The person's name + name: string, + // The person's age + age: int, +}"# )) ); } @@ -865,7 +1046,181 @@ mod tests { assert_eq!( rendered, Some(String::from( - "Answer in JSON using this schema:\n{\n // 111\n // \n school: string or null,\n // 2222222\n degree: string,\n year: int,\n}" + r#"Answer in JSON using this schema: +{ + // 111 + // + school: string or null, + // 2222222 + degree: string, + year: int, +}"# + )) + ); + } + + #[test] + fn hoist_enum_if_more_than_max_values() { + let enums = vec![Enum { + name: Name::new("Enm".to_string()), + values: vec![ + (Name::new("A".to_string()), None), + (Name::new("B".to_string()), None), + (Name::new("C".to_string()), None), + (Name::new("D".to_string()), None), + (Name::new("E".to_string()), None), + (Name::new("F".to_string()), None), + (Name::new("G".to_string()), None), + ], + constraints: Vec::new(), + }]; + + let classes = vec![Class { + name: Name::new("Output".to_string()), + fields: vec![( + Name::new("output".to_string()), + FieldType::Enum("Enm".to_string()), + None, + false, + )], + constraints: Vec::new(), + streaming_behavior: StreamingBehavior::default(), + }]; + + let content = OutputFormatContent::target(FieldType::class("Output")) + .enums(enums) + .classes(classes) + .build(); + let rendered = content.render(RenderOptions::default()).unwrap(); + assert_eq!( + rendered, + Some(String::from( + r#"Enm +---- +- A +- B +- C +- D +- E +- F +- G + +Answer in JSON using this schema: +{ + output: Enm, +}"# + )) + ); + } + + #[test] + fn hoist_enum_if_variant_has_description() { + let enums = vec![Enum { + name: Name::new("Enm".to_string()), + values: vec![ + ( + Name::new("A".to_string()), + Some("A description".to_string()), + ), + (Name::new("B".to_string()), None), + (Name::new("C".to_string()), None), + (Name::new("D".to_string()), None), + (Name::new("E".to_string()), None), + (Name::new("F".to_string()), None), + ], + constraints: Vec::new(), + }]; + + let classes = vec![Class { + name: Name::new("Output".to_string()), + fields: vec![( + Name::new("output".to_string()), + FieldType::Enum("Enm".to_string()), + None, + false, + )], + constraints: Vec::new(), + streaming_behavior: StreamingBehavior::default(), + }]; + + let content = OutputFormatContent::target(FieldType::class("Output")) + .enums(enums) + .classes(classes) + .build(); + let rendered = content.render(RenderOptions::default()).unwrap(); + assert_eq!( + rendered, + Some(String::from( + r#"Enm +---- +- A: A description +- B +- C +- D +- E +- F + +Answer in JSON using this schema: +{ + output: Enm, +}"# + )) + ); + } + + #[test] + fn hoist_enum_if_setting_always_hoist_enum() { + let enums = vec![Enum { + name: Name::new("Enm".to_string()), + values: vec![ + (Name::new("A".to_string()), None), + (Name::new("B".to_string()), None), + (Name::new("C".to_string()), None), + (Name::new("D".to_string()), None), + (Name::new("E".to_string()), None), + (Name::new("F".to_string()), None), + ], + constraints: Vec::new(), + }]; + + let classes = vec![Class { + name: Name::new("Output".to_string()), + fields: vec![( + Name::new("output".to_string()), + FieldType::Enum("Enm".to_string()), + None, + false, + )], + constraints: Vec::new(), + streaming_behavior: StreamingBehavior::default(), + }]; + + let content = OutputFormatContent::target(FieldType::class("Output")) + .enums(enums) + .classes(classes) + .build(); + let rendered = content + .render(RenderOptions { + always_hoist_enums: RenderSetting::Always(true), + ..Default::default() + }) + .unwrap(); + assert_eq!( + rendered, + Some(String::from( + r#"Enm +---- +- A +- B +- C +- D +- E +- F + +Answer in JSON using this schema: +{ + output: Enm, +}"# )) ); } @@ -2598,4 +2953,186 @@ Answer in JSON using this type: A"# )) ); } + + #[test] + fn render_hoisted_classes_subset() { + let classes = vec![ + Class { + name: Name::new("A".to_string()), + fields: vec![(Name::new("prop".to_string()), FieldType::int(), None, false)], + constraints: Vec::new(), + streaming_behavior: StreamingBehavior::default(), + }, + Class { + name: Name::new("B".to_string()), + fields: vec![( + Name::new("prop".to_string()), + FieldType::string(), + None, + false, + )], + constraints: Vec::new(), + streaming_behavior: StreamingBehavior::default(), + }, + Class { + name: Name::new("C".to_string()), + fields: vec![( + Name::new("prop".to_string()), + FieldType::float(), + None, + false, + )], + constraints: Vec::new(), + streaming_behavior: StreamingBehavior::default(), + }, + Class { + name: Name::new("Ret".to_string()), + fields: vec![ + ( + Name::new("a".to_string()), + FieldType::class("A"), + None, + false, + ), + ( + Name::new("b".to_string()), + FieldType::class("B"), + None, + false, + ), + ( + Name::new("c".to_string()), + FieldType::class("C"), + None, + false, + ), + ], + constraints: Vec::new(), + streaming_behavior: StreamingBehavior::default(), + }, + ]; + + let content = OutputFormatContent::target(FieldType::class("Ret")) + .classes(classes) + .build(); + let rendered = content + .render(RenderOptions::hoist_classes(HoistClasses::Subset(vec![ + "A".to_string(), + "B".to_string(), + ]))) + .unwrap(); + #[rustfmt::skip] + assert_eq!( + rendered, + Some(String::from( +r#"A { + prop: int, +} + +B { + prop: string, +} + +Answer in JSON using this schema: +{ + a: A, + b: B, + c: { + prop: float, + }, +}"# + )) + ); + } + + #[test] + fn render_hoist_all_classes() { + let classes = vec![ + Class { + name: Name::new("A".to_string()), + fields: vec![(Name::new("prop".to_string()), FieldType::int(), None, false)], + constraints: Vec::new(), + streaming_behavior: StreamingBehavior::default(), + }, + Class { + name: Name::new("B".to_string()), + fields: vec![( + Name::new("prop".to_string()), + FieldType::string(), + None, + false, + )], + constraints: Vec::new(), + streaming_behavior: StreamingBehavior::default(), + }, + Class { + name: Name::new("C".to_string()), + fields: vec![( + Name::new("prop".to_string()), + FieldType::float(), + None, + false, + )], + constraints: Vec::new(), + streaming_behavior: StreamingBehavior::default(), + }, + Class { + name: Name::new("Ret".to_string()), + fields: vec![ + ( + Name::new("a".to_string()), + FieldType::class("A"), + None, + false, + ), + ( + Name::new("b".to_string()), + FieldType::class("B"), + None, + false, + ), + ( + Name::new("c".to_string()), + FieldType::class("C"), + None, + false, + ), + ], + constraints: Vec::new(), + streaming_behavior: StreamingBehavior::default(), + }, + ]; + + let content = OutputFormatContent::target(FieldType::class("Ret")) + .classes(classes) + .build(); + let rendered = content + .render(RenderOptions::hoist_classes(HoistClasses::All)) + .unwrap(); + #[rustfmt::skip] + assert_eq!( + rendered, + Some(String::from( +r#"A { + prop: int, +} + +B { + prop: string, +} + +C { + prop: float, +} + +Ret { + a: A, + b: B, + c: C, +} + +Answer in JSON using this schema: Ret"# + )) + ); + } } diff --git a/engine/baml-lib/jinja/src/evaluate_type/mod.rs b/engine/baml-lib/jinja/src/evaluate_type/mod.rs index 40f30b7aec..602c8d36e9 100644 --- a/engine/baml-lib/jinja/src/evaluate_type/mod.rs +++ b/engine/baml-lib/jinja/src/evaluate_type/mod.rs @@ -140,6 +140,10 @@ impl TypeError { } } + // TODO: There's a bug with the suggestions, they are not consistent due to + // either some ordering issue or closest match algorithm does weird stuff + // and returns results non-deterministically. See commented test in + // baml-lib/jinja/src/evaluate_type/test_expr.rs fn new_unknown_arg(func: &str, span: Span, name: &str, valid_args: HashSet<&String>) -> Self { let names = valid_args.into_iter().collect::>(); let mut close_names = sort_by_match(name, &names, Some(3)); diff --git a/engine/baml-lib/jinja/src/evaluate_type/test_expr.rs b/engine/baml-lib/jinja/src/evaluate_type/test_expr.rs index d33361656a..3a41585ba1 100644 --- a/engine/baml-lib/jinja/src/evaluate_type/test_expr.rs +++ b/engine/baml-lib/jinja/src/evaluate_type/test_expr.rs @@ -270,9 +270,22 @@ fn test_output_format() { ); assert_eq!( - assert_fails_to!("ctx.output_format(prefix='1', unknown=1)", &types), - vec!["Function 'baml::OutputFormat' does not have an argument 'unknown'. Did you mean one of these: 'always_hoist_enums', 'enum_value_prefix', 'or_splitter'?"] + assert_fails_to!( + "ctx.output_format(hoist_classes=1)", + &types + ), + vec!["Function 'baml::OutputFormat' expects argument 'hoist_classes' to be of type (none | bool | literal[\"auto\"] | list[string]), but got literal[1]"] ); + + // TODO: There's a bug here, suggestions are not always the same, maybe some + // ordering issue or algorithm used to determine closest strings is not + // consistent. Code is in baml-lib/jinja/src/evaluate_type/mod.rs, + // TypeError::new_unknown_arg + // + // assert_eq!( + // assert_fails_to!("ctx.output_format(prefix='1', unknown=1)", &types), + // vec!["Function 'baml::OutputFormat' does not have an argument 'unknown'. Did you mean one of these: 'enum_value_prefix', 'hoist_classes', 'or_splitter'?"] + // ); } #[test] diff --git a/engine/baml-lib/jinja/src/evaluate_type/types.rs b/engine/baml-lib/jinja/src/evaluate_type/types.rs index 0839bebcac..693629d449 100644 --- a/engine/baml-lib/jinja/src/evaluate_type/types.rs +++ b/engine/baml-lib/jinja/src/evaluate_type/types.rs @@ -303,6 +303,15 @@ impl PredefinedTypes { "hoisted_class_prefix".into(), Type::merge(vec![Type::String, Type::None]), ), + ( + "hoist_classes".into(), + Type::merge(vec![ + Type::None, + Type::Bool, + Type::Literal(LiteralValue::String(String::from("auto"))), + Type::List(Box::new(Type::String)), + ]), + ), ], ), ), diff --git a/fern/03-reference/baml/prompt-syntax/output-format.mdx b/fern/03-reference/baml/prompt-syntax/output-format.mdx index 2a95417560..534232ad69 100644 --- a/fern/03-reference/baml/prompt-syntax/output-format.mdx +++ b/fern/03-reference/baml/prompt-syntax/output-format.mdx @@ -127,13 +127,127 @@ BAML renders it as `property: string or null` as we have observed some LLMs have You can always set it to ` | ` or something else for a specific model you use. + + +**Default: `"auto"`** + + +Requires BAML Version 0.89+ + + +Controls which classes are hoisted in the prompt. Recursive classes are +**always** hoisted because they need to be referenced by name. + +Let's use this as an example to visualize the different options: + +```baml +class Example { + a string + b string + c NestedClass + d Node +} + +class NestedClass { + m int + n int +} + +class Node { + data int + next Node? +} + +function UseExample() -> Example { + client GPT4 + prompt #"{{ctx.output_format}}"# +} +``` + +**"auto"** + +Only recursive classes are hoisted: + +```baml +Node { + data: int, + next: Node or null +} + +Answer in JSON using this schema: +{ + a: string, + b: string, + c: { + m: int, + n: int, + }, + d: Node, +} +``` + +**false** + +Same as `"auto"`. + +**true** + +Hoist all classes. + +```baml +Node { + data: int, + next: Node or null +} + +Example { + a: string, + b: string, + c: NestedClass, + d: Node, +} + +NestedClass { + m: int, + n: int, +} + +Answer in JSON using this schema: Example +``` + +**list[string]** + +Hoist only recursive classes and the classes specified in the list. For example +`ctx.output_format(hoist_classes=["NestedClass"])` will hoist `NestedClass`. + +```baml +Node { + data: int, + next: Node or null +} + +NestedClass { + m: int, + n: int, +} + +Answer in JSON using this schema: +{ + a: string, + b: string, + c: NestedClass, + d: Node, +} +``` + + + Prefix of hoisted classes in the prompt. **Default: ``** -Recursive classes are hoisted in the prompt so that any class field can -reference them using their name. This parameter controls the prefix used for -hoisted classes as well as the word used in the render message to refer to the -output type, which defaults to `"schema"`: +This parameter controls the prefix used for hoisted classes as well as the word +used in the render message to refer to the output type, which defaults to +`"schema"`: ``` Answer in JSON using this schema: