Skip to content

Commit

Permalink
feat: search for built-ins inside functions (#430)
Browse files Browse the repository at this point in the history
  • Loading branch information
morgante authored Jul 20, 2024
1 parent 5923f64 commit 3de3c32
Show file tree
Hide file tree
Showing 20 changed files with 322 additions and 59 deletions.
2 changes: 1 addition & 1 deletion crates/cli/src/commands/apply_pattern.rs
Original file line number Diff line number Diff line change
Expand Up @@ -531,7 +531,7 @@ pub(crate) async fn run_apply_pattern(
}

let warn_uncommitted = !arg.dry_run && !arg.force && has_uncommitted_changes(cwd.clone()).await;
if warn_uncommitted && has_rewrite(&compiled.pattern, &compiled.pattern_definitions) {
if warn_uncommitted && has_rewrite(&compiled.pattern, &compiled.definitions()) {
let term = console::Term::stderr();
if !term.is_term() {
bail!("Error: Untracked changes detected. Grit will not proceed with rewriting files in non-TTY environments unless '--force' is used. Please commit all changes or use '--force' to override this safety check.");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,4 @@ variables:
sourceFile: "`function () { $body }`"
parsedPattern: "[..]"
valid: true
usesAi: false
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,4 @@ variables:
sourceFile: "or { bubble `function ($args) { $body }`, bubble `($args) => { $body }` }"
parsedPattern: "[..]"
valid: true
usesAi: false
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,4 @@ variables:
sourceFile: "engine marzano(0.1)\nlanguage js\n\nfunction adder() js {\n console.log(\"We are in JavaScript now!\");\n return 10 % 3\n}\n\n`x` => adder()"
parsedPattern: "[..]"
valid: true
usesAi: false
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,4 @@ variables:
sourceFile: "language js\n\n`console.log($msg)`\n"
parsedPattern: "[..]"
valid: true
usesAi: false
190 changes: 183 additions & 7 deletions crates/core/src/analysis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,16 @@ use std::collections::BTreeMap;
use std::ffi::OsStr;
use std::path::Path;

use grit_pattern_matcher::constants::DEFAULT_FILE_NAME;
use grit_pattern_matcher::{
constants::DEFAULT_FILE_NAME,
context::StaticDefinitions,
pattern::{DynamicPattern, Pattern, PatternOrPredicate},
};

use crate::pattern_compiler::compiler::{defs_to_filenames, DefsToFilenames};
use crate::{
pattern_compiler::compiler::{defs_to_filenames, DefsToFilenames},
problem::MarzanoQueryContext,
};

/// Walks the call tree and returns true if the predicate is true for any node.
/// This is potentially error-prone, so not entirely recommended
Expand Down Expand Up @@ -182,6 +189,36 @@ fn find_child_tree_definition(
Ok(None)
}

fn uses_named_function(
root: &Pattern<MarzanoQueryContext>,
definitions: &StaticDefinitions<MarzanoQueryContext>,
function_name: &str,
) -> bool {
for pattern in root.iter(definitions) {
if let PatternOrPredicate::Pattern(grit_pattern_matcher::pattern::Pattern::CallBuiltIn(
call,
)) = pattern
{
if call.name == function_name {
return true;
}
}
if let PatternOrPredicate::DynamicPattern(DynamicPattern::CallBuiltIn(call)) = pattern {
if call.name == function_name {
return true;
}
}
}
false
}

pub fn uses_ai(
root: &Pattern<MarzanoQueryContext>,
definitions: &StaticDefinitions<MarzanoQueryContext>,
) -> bool {
uses_named_function(root, definitions, "llm_chat")
}

#[cfg(test)]
mod tests {
use grit_pattern_matcher::has_rewrite;
Expand Down Expand Up @@ -277,7 +314,7 @@ mod tests {

println!("problem: {:?}", problem);

assert!(has_rewrite(&problem.pattern, &[]));
assert!(has_rewrite(&problem.pattern, &problem.definitions()));
}

#[test]
Expand All @@ -303,7 +340,7 @@ mod tests {

println!("problem: {:?}", problem);

assert!(!has_rewrite(&problem.pattern, &[]));
assert!(!has_rewrite(&problem.pattern, &problem.definitions()));
}

#[test]
Expand All @@ -330,7 +367,37 @@ mod tests {

println!("problem: {:?}", problem);

assert!(has_rewrite(&problem.pattern, &problem.pattern_definitions));
assert!(has_rewrite(&problem.pattern, &problem.definitions()));
}

#[test]
fn test_is_not_rewrite_with_pattern_call() {
let pattern_src = r#"
pattern pattern_with_rewrite() {
`console.log($msg)` => `console.error($msg)`
}
pattern pattern_without_rewrite() {
`console.log($msg)`
}
pattern_without_rewrite()
"#
.to_string();
let libs = BTreeMap::new();
let problem = src_to_problem_libs(
pattern_src.to_string(),
&libs,
TargetLanguage::default(),
None,
None,
None,
None,
)
.unwrap()
.problem;

println!("problem: {:?}", problem);

assert!(!has_rewrite(&problem.pattern, &problem.definitions()));
}

#[test]
Expand Down Expand Up @@ -361,7 +428,7 @@ mod tests {

println!("problem: {:?}", problem);

assert!(has_rewrite(&problem.pattern, &problem.pattern_definitions));
assert!(has_rewrite(&problem.pattern, &problem.definitions()));
}

#[test]
Expand Down Expand Up @@ -391,6 +458,115 @@ mod tests {

println!("problem: {:?}", problem);

assert!(has_rewrite(&problem.pattern, &problem.pattern_definitions));
assert!(has_rewrite(&problem.pattern, &problem.definitions()));
}

#[test]
fn test_is_rewrite_with_predicate() {
let pattern_src = r#"
pattern pattern_with_rewrite() {
`me` => `console.error(me)`
}
predicate predicate_with_rewrite() {
$program <: contains pattern_with_rewrite()
}
`you` where {
predicate_with_rewrite()
}
"#
.to_string();
let libs = BTreeMap::new();
let problem = src_to_problem_libs(
pattern_src.to_string(),
&libs,
TargetLanguage::default(),
None,
None,
None,
None,
)
.unwrap()
.problem;

println!("problem: {:?}", problem);

assert!(has_rewrite(&problem.pattern, &problem.definitions()));
}

#[test]
fn test_is_rewrite_with_function_call() {
let pattern_src = r#"
pattern pattern_with_rewrite() {
`me` => `console.error(me)`
}
function more_indirection_is_good() {
if ($program <: contains pattern_with_rewrite()) {
return `console.error($program)`
},
return `console.error($program)`
}
predicate predicate_with_function_call() {
$foo = more_indirection_is_good()
}
`you` where {
predicate_with_function_call()
}
"#
.to_string();
let libs = BTreeMap::new();
let problem = src_to_problem_libs(
pattern_src.to_string(),
&libs,
TargetLanguage::default(),
None,
None,
None,
None,
)
.unwrap()
.problem;

println!("problem: {:?}", problem);

assert!(has_rewrite(&problem.pattern, &problem.definitions()));
assert!(!uses_named_function(
&problem.pattern,
&problem.definitions(),
"text",
));
}

#[test]
fn test_uses_text_fn() {
let pattern_src = r#"
`me` => text(`$program`)
"#
.to_string();
let libs = BTreeMap::new();
let problem = src_to_problem_libs(
pattern_src.to_string(),
&libs,
TargetLanguage::default(),
None,
None,
None,
None,
)
.unwrap()
.problem;

println!("problem: {:?}", problem);

assert!(has_rewrite(&problem.pattern, &problem.definitions()));
assert!(uses_named_function(
&problem.pattern,
&problem.definitions(),
"text",
));
}
}
5 changes: 5 additions & 0 deletions crates/core/src/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,7 @@ pub struct PatternInfo {
pub source_file: String,
pub parsed_pattern: String,
pub valid: bool,
pub uses_ai: bool,
}

impl PatternInfo {
Expand All @@ -310,12 +311,16 @@ impl PatternInfo {
&grit_node_types,
))
.unwrap();

let uses_ai = crate::analysis::uses_ai(&compiled.pattern, &compiled.definitions());

Self {
messages: vec![],
variables: compiled.compiled_vars(),
source_file,
parsed_pattern,
valid: true,
uses_ai,
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions crates/core/src/ast_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::{
problem::MarzanoQueryContext,
};
use anyhow::{anyhow, Result};
use grit_pattern_matcher::pattern::PatternDefinition;

use grit_pattern_matcher::{
binding::Binding,
pattern::{
Expand Down Expand Up @@ -32,7 +32,7 @@ impl AstNodePattern<MarzanoQueryContext> for ASTNode {

fn children<'a>(
&'a self,
_definitions: &'a [PatternDefinition<MarzanoQueryContext>],
_definitions: &'a grit_pattern_matcher::context::StaticDefinitions<MarzanoQueryContext>,
) -> Vec<PatternOrPredicate<'a, MarzanoQueryContext>> {
self.args
.iter()
Expand Down
3 changes: 2 additions & 1 deletion crates/core/src/built_in_functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ impl BuiltIns {
built_ins: &BuiltIns,
index: usize,
lang: &impl Language,
name: &str,
) -> Result<CallBuiltIn<MarzanoQueryContext>> {
let params = &built_ins.0[index].params;
let mut pattern_params = Vec::with_capacity(args.len());
Expand All @@ -94,7 +95,7 @@ impl BuiltIns {
None => pattern_params.push(None),
}
}
Ok(CallBuiltIn::new(index, pattern_params))
Ok(CallBuiltIn::new(index, name, pattern_params))
}

/// Add an anonymous built-in, used for callbacks
Expand Down
1 change: 1 addition & 0 deletions crates/core/src/pattern_compiler/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,7 @@ impl PatternBuilder {
let predicate_match = Predicate::Match(Box::new(Match::new(
Container::FunctionCall(Box::new(CallBuiltIn::new(
index,
&format!("match_{}", index),
vec![Some(grit_pattern_matcher::pattern::Pattern::Variable(
match_var,
))],
Expand Down
1 change: 1 addition & 0 deletions crates/core/src/pattern_compiler/call_compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ impl NodeCompiler for CallCompiler {
context.compilation.built_ins,
index,
lang,
kind,
)?)))
} else if let Some(info) = context.compilation.function_definition_info.get(kind) {
let args = match_args_to_params(kind, args, &collect_params(&info.parameters), lang)?;
Expand Down
9 changes: 6 additions & 3 deletions crates/core/src/pattern_compiler/not_compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ use super::{
};
use crate::problem::MarzanoQueryContext;
use anyhow::{anyhow, Result};
use grit_pattern_matcher::pattern::{Not, Pattern, PatternOrPredicate, PrNot, Predicate};
use grit_pattern_matcher::{
context::StaticDefinitions,
pattern::{Not, Pattern, PatternOrPredicate, PrNot, Predicate},
};
use grit_util::AnalysisLogBuilder;
use marzano_util::node_with_source::NodeWithSource;

Expand All @@ -23,7 +26,7 @@ impl NodeCompiler for NotCompiler {
.ok_or_else(|| anyhow!("missing pattern of patternNot"))?;
let range = pattern.range();
let pattern = PatternCompiler::from_node(&pattern, context)?;
if pattern.iter(&[]).any(|p| {
if pattern.iter(&StaticDefinitions::default()).any(|p| {
matches!(
p,
PatternOrPredicate::Pattern(Pattern::Rewrite(_))
Expand Down Expand Up @@ -59,7 +62,7 @@ impl NodeCompiler for PrNotCompiler {
.ok_or_else(|| anyhow!("predicateNot missing predicate"))?;
let range = not.range();
let not = PredicateCompiler::from_node(&not, context)?;
if not.iter(&[]).any(|p| {
if not.iter(&StaticDefinitions::default()).any(|p| {
matches!(
p,
PatternOrPredicate::Pattern(Pattern::Rewrite(_))
Expand Down
10 changes: 9 additions & 1 deletion crates/core/src/problem.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use crate::{
use anyhow::{bail, Result};
use grit_pattern_matcher::{
constants::{GLOBAL_VARS_SCOPE_INDEX, NEW_FILES_INDEX},
context::QueryContext,
context::{QueryContext, StaticDefinitions},
file_owners::FileOwners,
pattern::{
FilePtr, FileRegistry, GritFunctionDefinition, Matcher, Pattern, PatternDefinition,
Expand Down Expand Up @@ -66,6 +66,14 @@ impl Problem {
pub fn compiled_vars(&self) -> Vec<VariableMatch> {
self.variables.compiled_vars(&self.tree.source)
}

pub fn definitions(&self) -> StaticDefinitions<'_, MarzanoQueryContext> {
StaticDefinitions::new(
&self.pattern_definitions,
&self.predicate_definitions,
&self.function_definitions,
)
}
}

enum FilePattern {
Expand Down
Loading

0 comments on commit 3de3c32

Please sign in to comment.