From af20a1e19caa3c9d1a884dd2fe220c65dfb12511 Mon Sep 17 00:00:00 2001 From: jlanson Date: Mon, 30 Sep 2024 15:55:19 -0400 Subject: [PATCH] feat: add type system to policy exprs --- hipcheck/src/policy_exprs/env.rs | 496 ++++++++++++++++++++++++----- hipcheck/src/policy_exprs/error.rs | 20 +- hipcheck/src/policy_exprs/expr.rs | 311 +++++++++++++++++- hipcheck/src/policy_exprs/mod.rs | 119 ++++++- hipcheck/src/policy_exprs/pass.rs | 121 ++++++- 5 files changed, 973 insertions(+), 94 deletions(-) diff --git a/hipcheck/src/policy_exprs/env.rs b/hipcheck/src/policy_exprs/env.rs index b5064a78..296a1b77 100644 --- a/hipcheck/src/policy_exprs/env.rs +++ b/hipcheck/src/policy_exprs/env.rs @@ -1,12 +1,21 @@ // SPDX-License-Identifier: Apache-2.0 use crate::policy_exprs::{ - pass::ExprMutator, Array as StructArray, Error, Expr, ExprVisitor, Function as StructFunction, - Ident, Lambda as StructLambda, Primitive, Result, F64, + expr::{ + ArrayType as ExprArrayType, FuncReturnType, Function, FunctionDef, FunctionType, Op, + OpInfo, PrimitiveType, ReturnableType, Type, TypeChecker, Typed, + }, + pass::ExprMutator, + Array as StructArray, Error, Expr, ExprVisitor, Function as StructFunction, Ident, + Lambda as StructLambda, Primitive, Result, F64, }; use itertools::Itertools as _; use jiff::{Span, Zoned}; -use std::{cmp::Ordering, collections::HashMap, ops::Not as _}; +use std::{ + cmp::{Ordering, PartialEq}, + collections::HashMap, + ops::Not as _, +}; use Expr::*; use Primitive::*; @@ -23,14 +32,331 @@ pub struct Env<'parent> { #[derive(Clone)] pub enum Binding { /// A function. - Fn(Op), + Fn(FunctionDef), /// A primitive value. Var(Primitive), } -/// Helper type for operation function pointer. -type Op = fn(&Env, &[Expr]) -> Result; +// Ensure that type of array elements is valid with a lambda +fn ty_check_higher_order_lambda( + l_ty: &FunctionType, + arr_ty: &ExprArrayType, +) -> Result { + if let Some(arr_elt_ty) = arr_ty { + // Copy the lambda function type, replace ident with arr_elt_ty + let mut try_l_ty = l_ty.clone(); + let first_arg = try_l_ty.arg_tys.get_mut(0).ok_or(Error::NotEnoughArgs { + name: "".to_owned(), + expected: 1, + given: 0, + })?; + *first_arg = Type::Primitive(*arr_elt_ty); + // If this returns error, means array type was incorrect for lambda + try_l_ty.get_return_type() + } else { + Ok(ReturnableType::Unknown) + } +} + +// Expects args to contain [lambda, array] +fn ty_filter(args: &[Type]) -> Result { + let arr_ty = expect_array_at(args, 1)?; + + let wrapped_l_ty = args.first().ok_or(Error::InternalError( + "we were supposed to have already checked that there are at least two arguments".to_owned(), + ))?; + let Type::Lambda(l_ty) = wrapped_l_ty else { + return Err(Error::BadFuncArgType { + name: "".to_owned(), + idx: 0, + expected: "a lambda".to_owned(), + got: wrapped_l_ty.clone(), + }); + }; + + let res_ty = ty_check_higher_order_lambda(l_ty, &arr_ty)?; + match res_ty { + ReturnableType::Primitive(PrimitiveType::Bool) | ReturnableType::Unknown => { + Ok(ReturnableType::Array(arr_ty)) + } + a => Err(Error::BadFuncArgType { + name: "".to_owned(), + idx: 0, + expected: "a bool-returning lambda".to_owned(), + got: Type::Lambda(l_ty.clone()), + }), + } +} + +// Expects args to contain [lambda, array] +fn ty_higher_order_bool_fn(args: &[Type]) -> Result { + let arr_ty = expect_array_at(args, 1)?; + + let wrapped_l_ty = args.first().ok_or(Error::InternalError( + "we were supposed to have already checked that there are at least two arguments".to_owned(), + ))?; + let Type::Lambda(l_ty) = wrapped_l_ty else { + return Err(Error::BadFuncArgType { + name: "".to_owned(), + idx: 0, + expected: "a lambda".to_owned(), + got: wrapped_l_ty.clone(), + }); + }; + + let res_ty = ty_check_higher_order_lambda(l_ty, &arr_ty)?; + match res_ty { + ReturnableType::Primitive(PrimitiveType::Bool) | ReturnableType::Unknown => { + Ok(ReturnableType::Primitive(PrimitiveType::Bool)) + } + a => Err(Error::BadFuncArgType { + name: "".to_owned(), + idx: 0, + expected: "a bool-returning lambda".to_owned(), + got: Type::Lambda(l_ty.clone()), + }), + } +} + +// Type of dynamic function is dependent on first arg +fn ty_inherit_first(args: &[Type]) -> Result { + args.first().ok_or( + Error::InternalError("type checking function expects one argument, was incorrectly applied to a function that takes none".to_owned()) + )?.try_into() +} + +fn ty_from_first_arr(args: &[Type]) -> Result { + let arr_ty = expect_array_at(args, 0)?; + Ok(match arr_ty { + None => ReturnableType::Unknown, + Some(p_ty) => ReturnableType::Primitive(p_ty), + }) +} + +fn expect_primitive_at(args: &[Type], idx: usize) -> Result> { + let arg = args + .get(idx) + .ok_or(Error::InternalError( + "we were supposed to have already checked that function had enough arguments" + .to_owned(), + ))? + .try_into()?; + + match arg { + ReturnableType::Primitive(p) => Ok(Some(p)), + ReturnableType::Array(a) => Err(Error::BadFuncArgType { + name: "".to_owned(), + idx, + expected: "a primitive type".to_owned(), + got: Type::Array(a), + }), + ReturnableType::Unknown => Ok(None), + } +} + +fn expect_array_at(args: &[Type], idx: usize) -> Result { + let arg = args + .get(idx) + .ok_or(Error::InternalError( + "we were supposed to have already checked that function had enough arguments" + .to_owned(), + ))? + .try_into()?; + + match arg { + ReturnableType::Array(a) => Ok(a), + ReturnableType::Unknown => Ok(None), + ReturnableType::Primitive(p) => Err(Error::BadFuncArgType { + name: "".to_owned(), + idx, + expected: "an array".to_owned(), + got: Type::Primitive(p), + }), + } +} + +fn ty_divz(args: &[Type]) -> Result { + let opt_ty_1 = expect_primitive_at(args, 0)?; + let opt_ty_2 = expect_primitive_at(args, 1)?; + use PrimitiveType::*; + use ReturnableType::*; + + let (bad, idx) = match (opt_ty_1, opt_ty_2) { + (None | Some(Int | Float), None | Some(Int | Float)) => return Ok(Float.into()), + (Some(x), None | Some(_)) => (x, 0), + (None, Some(x)) => (x, 1), + }; + + Err(Error::BadFuncArgType { + name: "".to_owned(), + idx, + expected: "an int or float".to_owned(), + got: Type::Primitive(bad), + }) +} + +fn ty_arithmetic_binary_ops(args: &[Type]) -> Result { + // ensure both ops result in primitive types or unknown + let opt_ty_1 = expect_primitive_at(args, 0)?; + let opt_ty_2 = expect_primitive_at(args, 1)?; + use PrimitiveType::*; + use ReturnableType::*; + + let (bad, idx) = match (opt_ty_1, opt_ty_2) { + (None, None) => return Ok(Unknown), + (None | Some(Int), None | Some(Int)) => return Ok(Primitive(Int)), + (None | Some(Int | Float), None | Some(Int | Float)) => return Ok(Primitive(Float)), + (None, Some(Span)) => return Ok(Unknown), + (Some(Span), None | Some(Span)) => return Ok(Primitive(DateTime)), + (Some(DateTime), None | Some(Span)) => return Ok(Primitive(DateTime)), + (Some(x), _) => (x, 0), + (_, Some(x)) => (x, 1), + }; + + Err(Error::BadFuncArgType { + name: "".to_owned(), + idx, + expected: "a float, int, span, or datetime".to_owned(), + got: Type::Primitive(bad), + }) +} + +fn ty_foreach(args: &[Type]) -> Result { + expect_array_at(args, 1)?; + let first_arg = args.first().ok_or(Error::InternalError( + "we were supposed to have already checked that there are at least two arguments".to_owned(), + ))?; + let fty = match first_arg { + Type::Lambda(f) => f, + other => { + return Err(Error::BadFuncArgType { + name: "foreach".to_owned(), + idx: 0, + expected: "lambda".to_owned(), + got: other.clone(), + }); + } + }; + fty.get_return_type() +} + +fn ty_comp(args: &[Type]) -> Result { + let resp = Ok(Bool.into()); + let opt_ty_1: Option = expect_primitive_at(args, 0)?; + let opt_ty_2: Option = expect_primitive_at(args, 1)?; + use PrimitiveType::*; + use ReturnableType::*; + let (bad, idx) = match (opt_ty_1, opt_ty_2) { + (None, None) => return Ok(Primitive(Bool)), + (None | Some(Int), None | Some(Int)) => return resp, + (None | Some(Int | Float), None | Some(Int | Float)) => return resp, + (None | Some(Span), None | Some(Span)) => return resp, + (None | Some(Bool), None | Some(Bool)) => return resp, + (None | Some(DateTime), None | Some(DateTime)) => return resp, + (Some(x), _) => (x, 0), + (_, Some(x)) => (x, 1), + }; + Err(Error::BadFuncArgType { + name: "".to_owned(), + idx, + expected: "a float, int, bool, span, or datetime".to_owned(), + got: Type::Primitive(bad), + }) +} + +fn ty_count(args: &[Type]) -> Result { + Ok(PrimitiveType::Int.into()) +} + +fn ty_avg(args: &[Type]) -> Result { + use PrimitiveType::*; + use ReturnableType::*; + let arr_ty = expect_array_at(args, 0)?; + match arr_ty { + None | Some(Int) | Some(Float) => Ok(Float.into()), + Some(x) => Err(Error::BadFuncArgType { + name: "".to_owned(), + idx: 0, + expected: "array of ints or floats".to_owned(), + got: Type::Array(Some(x)), + }), + } +} + +fn ty_duration(args: &[Type]) -> Result { + use PrimitiveType::*; + use ReturnableType::*; + let opt_ty_1 = expect_primitive_at(args, 0)?; + let opt_ty_2 = expect_primitive_at(args, 1)?; + match opt_ty_1 { + None | (Some(DateTime)) => (), + Some(got) => { + return Err(Error::BadFuncArgType { + name: "".to_owned(), + idx: 0, + expected: "a datetime".to_owned(), + got: Type::Primitive(got), + }); + } + } + match opt_ty_2 { + None | (Some(DateTime)) => (), + Some(got) => { + return Err(Error::BadFuncArgType { + name: "".to_owned(), + idx: 1, + expected: "a datetime".to_owned(), + got: Type::Primitive(got), + }); + } + } + Ok(PrimitiveType::Span.into()) +} + +fn ty_bool_unary(args: &[Type]) -> Result { + use PrimitiveType::*; + use ReturnableType::*; + match expect_primitive_at(args, 0)? { + None | (Some(Bool)) => Ok(PrimitiveType::Bool.into()), + Some(got) => Err(Error::BadFuncArgType { + name: "".to_owned(), + idx: 0, + expected: "a bool".to_owned(), + got: Type::Primitive(got), + }), + } +} + +fn ty_bool_binary(args: &[Type]) -> Result { + use PrimitiveType::*; + use ReturnableType::*; + let opt_ty_1 = expect_primitive_at(args, 0)?; + let opt_ty_2 = expect_primitive_at(args, 1)?; + match opt_ty_1 { + None | (Some(Bool)) => (), + Some(got) => { + return Err(Error::BadFuncArgType { + name: "".to_owned(), + idx: 0, + expected: "a bool".to_owned(), + got: Type::Primitive(got), + }); + } + } + match opt_ty_2 { + None | (Some(Bool)) => (), + Some(got) => { + return Err(Error::BadFuncArgType { + name: "".to_owned(), + idx: 1, + expected: "a bool".to_owned(), + got: Type::Primitive(got), + }); + } + } + Ok(PrimitiveType::Bool.into()) +} impl<'parent> Env<'parent> { /// Create an empty environment. @@ -43,48 +369,50 @@ impl<'parent> Env<'parent> { /// Create the standard environment. pub fn std() -> Self { + use FuncReturnType::*; + use PrimitiveType::*; let mut env = Env::empty(); // Comparison functions. - env.add_fn("gt", gt); - env.add_fn("lt", lt); - env.add_fn("gte", gte); - env.add_fn("lte", lte); - env.add_fn("eq", eq); - env.add_fn("neq", neq); + env.add_fn("gt", gt, 2, ty_comp); + env.add_fn("lt", lt, 2, ty_comp); + env.add_fn("gte", gte, 2, ty_comp); + env.add_fn("lte", lte, 2, ty_comp); + env.add_fn("eq", eq, 2, ty_comp); + env.add_fn("neq", neq, 2, ty_comp); // Math functions. - env.add_fn("add", add); - env.add_fn("sub", sub); - env.add_fn("divz", divz); + env.add_fn("add", add, 2, ty_arithmetic_binary_ops); + env.add_fn("sub", sub, 2, ty_arithmetic_binary_ops); + env.add_fn("divz", divz, 2, ty_divz); // Additional datetime math functions - env.add_fn("duration", duration); + env.add_fn("duration", duration, 2, ty_duration); // Logical functions. - env.add_fn("and", and); - env.add_fn("or", or); - env.add_fn("not", not); + env.add_fn("and", and, 2, ty_bool_binary); + env.add_fn("or", or, 2, ty_bool_binary); + env.add_fn("not", not, 1, ty_bool_unary); // Array math functions. - env.add_fn("max", max); - env.add_fn("min", min); - env.add_fn("avg", avg); - env.add_fn("median", median); - env.add_fn("count", count); + env.add_fn("max", max, 1, ty_from_first_arr); + env.add_fn("min", min, 1, ty_from_first_arr); + env.add_fn("avg", avg, 1, ty_avg); + env.add_fn("median", median, 1, ty_from_first_arr); + env.add_fn("count", count, 1, ty_count); // Array logic functions. - env.add_fn("all", all); - env.add_fn("nall", nall); - env.add_fn("some", some); - env.add_fn("none", none); + env.add_fn("all", all, 1, ty_higher_order_bool_fn); + env.add_fn("nall", nall, 1, ty_higher_order_bool_fn); + env.add_fn("some", some, 1, ty_higher_order_bool_fn); + env.add_fn("none", none, 1, ty_higher_order_bool_fn); // Array higher-order functions. - env.add_fn("filter", filter); - env.add_fn("foreach", foreach); + env.add_fn("filter", filter, 2, ty_filter); + env.add_fn("foreach", foreach, 2, ty_foreach); // Debugging functions. - env.add_fn("dbg", dbg); + env.add_fn("dbg", dbg, 1, ty_inherit_first); env } @@ -103,8 +431,22 @@ impl<'parent> Env<'parent> { } /// Add a function to the environment. - pub fn add_fn(&mut self, name: &str, op: Op) -> Option { - self.bindings.insert(name.to_owned(), Binding::Fn(op)) + pub fn add_fn( + &mut self, + name: &str, + op: Op, + expected_args: usize, + ty_checker: TypeChecker, + ) -> Option { + self.bindings.insert( + name.to_owned(), + Binding::Fn(FunctionDef { + name: name.to_owned(), + expected_args, + ty_checker, + op, + }), + ) } /// Get a binding from the environment, walking up the scopes. @@ -136,7 +478,7 @@ fn check_num_args(name: &str, args: &[Expr], expected: usize) -> Result<()> { } /// Partially evaluate a binary operation on primitives. -fn partially_evaluate(fn_name: &'static str, arg: Expr) -> Result { +pub fn partially_evaluate(env: &Env, fn_name: &str, arg: Expr) -> Result { let var_name = "x"; let var = Ident(String::from(var_name)); let func = Ident(String::from(fn_name)); @@ -144,8 +486,10 @@ fn partially_evaluate(fn_name: &'static str, arg: Expr) -> Result { // function lambda to make higher-order functions read better. // e.g. `(filter (lt 3) [])` would actually check if array elements are // greater than 3 if we put the placeholder var second - let op = StructFunction::new(func, vec![Primitive(Identifier(var.clone())), arg]).into(); - let lambda = StructLambda::new(var, Box::new(op)).into(); + let op = + StructFunction::new(func, vec![Primitive(Identifier(var.clone())), arg]).resolve(env)?; + let lambda: Expr = StructLambda::new(var, op).into(); + lambda.get_type()?; Ok(lambda) } @@ -159,7 +503,7 @@ where F: FnOnce(Primitive, Primitive) -> Result, { if args.len() == 1 { - return partially_evaluate(name, args[0].clone()); + return partially_evaluate(env, name, args[0].clone()); } check_num_args(name, args, 2)?; @@ -216,7 +560,7 @@ where /// Define a higher-order operation over arrays. fn higher_order_array_op(name: &'static str, env: &Env, args: &[Expr], op: F) -> Result where - F: FnOnce(ArrayType, Ident, Box) -> Result, + F: FnOnce(ArrayType, Ident, Function) -> Result, { check_num_args(name, args, 2)?; @@ -322,14 +666,14 @@ fn array_type(arr: &[Primitive]) -> Result { } /// Evaluate the lambda, injecting into the environment. -fn eval_lambda(env: &Env, ident: &Ident, val: Primitive, body: Expr) -> Result { +fn eval_lambda(env: &Env, ident: &Ident, val: Primitive, body: Function) -> Result { let mut child = env.child(); if child.add_var(&ident.0, val).is_some() { return Err(Error::AlreadyBound); } - child.visit_expr(body) + child.visit_function(body) } #[allow(clippy::bool_comparison)] @@ -710,35 +1054,35 @@ fn count(env: &Env, args: &[Expr]) -> Result { fn all(env: &Env, args: &[Expr]) -> Result { let name = "all"; - let op = |arr, ident: Ident, body: Box| { + let op = |arr, ident: Ident, body: Function| { let result = match arr { ArrayType::Int(ints) => ints .iter() - .map(|val| eval_lambda(env, &ident, Int(*val), (*body).clone())) + .map(|val| eval_lambda(env, &ident, Int(*val), body.clone())) .process_results(|mut iter| { iter.all(|expr| matches!(expr, Primitive(Bool(true)))) })?, ArrayType::Float(floats) => floats .iter() - .map(|val| eval_lambda(env, &ident, Float(*val), (*body).clone())) + .map(|val| eval_lambda(env, &ident, Float(*val), body.clone())) .process_results(|mut iter| { iter.all(|expr| matches!(expr, Primitive(Bool(true)))) })?, ArrayType::Bool(bools) => bools .iter() - .map(|val| eval_lambda(env, &ident, Bool(*val), (*body).clone())) + .map(|val| eval_lambda(env, &ident, Bool(*val), body.clone())) .process_results(|mut iter| { iter.all(|expr| matches!(expr, Primitive(Bool(true)))) })?, ArrayType::DateTime(dts) => dts .iter() - .map(|val| eval_lambda(env, &ident, DateTime(val.clone()), (*body).clone())) + .map(|val| eval_lambda(env, &ident, DateTime(val.clone()), body.clone())) .process_results(|mut iter| { iter.all(|expr| matches!(expr, Primitive(Bool(true)))) })?, ArrayType::Span(spans) => spans .iter() - .map(|val| eval_lambda(env, &ident, Span(*val), (*body).clone())) + .map(|val| eval_lambda(env, &ident, Span(*val), body.clone())) .process_results(|mut iter| { iter.all(|expr| matches!(expr, Primitive(Bool(true)))) })?, @@ -754,35 +1098,35 @@ fn all(env: &Env, args: &[Expr]) -> Result { fn nall(env: &Env, args: &[Expr]) -> Result { let name = "nall"; - let op = |arr, ident: Ident, body: Box| { + let op = |arr, ident: Ident, body: Function| { let result = match arr { ArrayType::Int(ints) => ints .iter() - .map(|val| eval_lambda(env, &ident, Int(*val), (*body).clone())) + .map(|val| eval_lambda(env, &ident, Int(*val), body.clone())) .process_results(|mut iter| { iter.all(|expr| matches!(expr, Primitive(Bool(true)))).not() })?, ArrayType::Float(floats) => floats .iter() - .map(|val| eval_lambda(env, &ident, Float(*val), (*body).clone())) + .map(|val| eval_lambda(env, &ident, Float(*val), body.clone())) .process_results(|mut iter| { iter.all(|expr| matches!(expr, Primitive(Bool(true)))).not() })?, ArrayType::Bool(bools) => bools .iter() - .map(|val| eval_lambda(env, &ident, Bool(*val), (*body).clone())) + .map(|val| eval_lambda(env, &ident, Bool(*val), body.clone())) .process_results(|mut iter| { iter.all(|expr| matches!(expr, Primitive(Bool(true)))).not() })?, ArrayType::DateTime(dts) => dts .iter() - .map(|val| eval_lambda(env, &ident, DateTime(val.clone()), (*body).clone())) + .map(|val| eval_lambda(env, &ident, DateTime(val.clone()), body.clone())) .process_results(|mut iter| { iter.all(|expr| matches!(expr, Primitive(Bool(true)))).not() })?, ArrayType::Span(spans) => spans .iter() - .map(|val| eval_lambda(env, &ident, Span(*val), (*body).clone())) + .map(|val| eval_lambda(env, &ident, Span(*val), body.clone())) .process_results(|mut iter| { iter.all(|expr| matches!(expr, Primitive(Bool(true)))).not() })?, @@ -798,35 +1142,35 @@ fn nall(env: &Env, args: &[Expr]) -> Result { fn some(env: &Env, args: &[Expr]) -> Result { let name = "some"; - let op = |arr, ident: Ident, body: Box| { + let op = |arr, ident: Ident, body: Function| { let result = match arr { ArrayType::Int(ints) => ints .iter() - .map(|val| eval_lambda(env, &ident, Int(*val), (*body).clone())) + .map(|val| eval_lambda(env, &ident, Int(*val), body.clone())) .process_results(|mut iter| { iter.any(|expr| matches!(expr, Primitive(Bool(true)))) })?, ArrayType::Float(floats) => floats .iter() - .map(|val| eval_lambda(env, &ident, Float(*val), (*body).clone())) + .map(|val| eval_lambda(env, &ident, Float(*val), body.clone())) .process_results(|mut iter| { iter.any(|expr| matches!(expr, Primitive(Bool(true)))) })?, ArrayType::Bool(bools) => bools .iter() - .map(|val| eval_lambda(env, &ident, Bool(*val), (*body).clone())) + .map(|val| eval_lambda(env, &ident, Bool(*val), body.clone())) .process_results(|mut iter| { iter.any(|expr| matches!(expr, Primitive(Bool(true)))) })?, ArrayType::DateTime(dts) => dts .iter() - .map(|val| eval_lambda(env, &ident, DateTime(val.clone()), (*body).clone())) + .map(|val| eval_lambda(env, &ident, DateTime(val.clone()), body.clone())) .process_results(|mut iter| { iter.any(|expr| matches!(expr, Primitive(Bool(true)))) })?, ArrayType::Span(spans) => spans .iter() - .map(|val| eval_lambda(env, &ident, Span(*val), (*body).clone())) + .map(|val| eval_lambda(env, &ident, Span(*val), body.clone())) .process_results(|mut iter| { iter.any(|expr| matches!(expr, Primitive(Bool(true)))) })?, @@ -842,35 +1186,35 @@ fn some(env: &Env, args: &[Expr]) -> Result { fn none(env: &Env, args: &[Expr]) -> Result { let name = "none"; - let op = |arr, ident: Ident, body: Box| { + let op = |arr, ident: Ident, body: Function| { let result = match arr { ArrayType::Int(ints) => ints .iter() - .map(|val| eval_lambda(env, &ident, Int(*val), (*body).clone())) + .map(|val| eval_lambda(env, &ident, Int(*val), body.clone())) .process_results(|mut iter| { iter.any(|expr| matches!(expr, Primitive(Bool(true)))).not() })?, ArrayType::Float(floats) => floats .iter() - .map(|val| eval_lambda(env, &ident, Float(*val), (*body).clone())) + .map(|val| eval_lambda(env, &ident, Float(*val), body.clone())) .process_results(|mut iter| { iter.any(|expr| matches!(expr, Primitive(Bool(true)))).not() })?, ArrayType::Bool(bools) => bools .iter() - .map(|val| eval_lambda(env, &ident, Bool(*val), (*body).clone())) + .map(|val| eval_lambda(env, &ident, Bool(*val), body.clone())) .process_results(|mut iter| { iter.any(|expr| matches!(expr, Primitive(Bool(true)))).not() })?, ArrayType::DateTime(dts) => dts .iter() - .map(|val| eval_lambda(env, &ident, DateTime(val.clone()), (*body).clone())) + .map(|val| eval_lambda(env, &ident, DateTime(val.clone()), body.clone())) .process_results(|mut iter| { iter.any(|expr| matches!(expr, Primitive(Bool(true)))).not() })?, ArrayType::Span(spans) => spans .iter() - .map(|val| eval_lambda(env, &ident, Span(*val), (*body).clone())) + .map(|val| eval_lambda(env, &ident, Span(*val), body.clone())) .process_results(|mut iter| { iter.any(|expr| matches!(expr, Primitive(Bool(true)))).not() })?, @@ -886,11 +1230,11 @@ fn none(env: &Env, args: &[Expr]) -> Result { fn filter(env: &Env, args: &[Expr]) -> Result { let name = "filter"; - let op = |arr, ident: Ident, body: Box| { + let op = |arr, ident: Ident, body: Function| { let arr = match arr { ArrayType::Int(ints) => ints .iter() - .map(|val| Ok((val, eval_lambda(env, &ident, Int(*val), (*body).clone())))) + .map(|val| Ok((val, eval_lambda(env, &ident, Int(*val), body.clone())))) .filter_map_ok(|(val, expr)| { if let Ok(Primitive(Bool(true))) = expr { Some(Primitive::Int(*val)) @@ -901,7 +1245,7 @@ fn filter(env: &Env, args: &[Expr]) -> Result { .collect::>>()?, ArrayType::Float(floats) => floats .iter() - .map(|val| Ok((val, eval_lambda(env, &ident, Float(*val), (*body).clone())))) + .map(|val| Ok((val, eval_lambda(env, &ident, Float(*val), body.clone())))) .filter_map_ok(|(val, expr)| { if let Ok(Primitive(Bool(true))) = expr { Some(Primitive::Float(*val)) @@ -912,7 +1256,7 @@ fn filter(env: &Env, args: &[Expr]) -> Result { .collect::>>()?, ArrayType::Bool(bools) => bools .iter() - .map(|val| Ok((val, eval_lambda(env, &ident, Bool(*val), (*body).clone())))) + .map(|val| Ok((val, eval_lambda(env, &ident, Bool(*val), body.clone())))) .filter_map_ok(|(val, expr)| { if let Ok(Primitive(Bool(true))) = expr { Some(Primitive::Bool(*val)) @@ -926,7 +1270,7 @@ fn filter(env: &Env, args: &[Expr]) -> Result { .map(|val| { Ok(( val, - eval_lambda(env, &ident, DateTime(val.clone()), (*body).clone()), + eval_lambda(env, &ident, DateTime(val.clone()), body.clone()), )) }) .filter_map_ok(|(val, expr)| { @@ -939,7 +1283,7 @@ fn filter(env: &Env, args: &[Expr]) -> Result { .collect::>>()?, ArrayType::Span(spans) => spans .iter() - .map(|val| Ok((val, eval_lambda(env, &ident, Span(*val), (*body).clone())))) + .map(|val| Ok((val, eval_lambda(env, &ident, Span(*val), body.clone())))) .filter_map_ok(|(val, expr)| { if let Ok(Primitive(Bool(true))) = expr { Some(Primitive::Span(*val)) @@ -960,11 +1304,11 @@ fn filter(env: &Env, args: &[Expr]) -> Result { fn foreach(env: &Env, args: &[Expr]) -> Result { let name = "foreach"; - let op = |arr, ident: Ident, body: Box| { + let op = |arr, ident: Ident, body: Function| { let arr = match arr { ArrayType::Int(ints) => ints .iter() - .map(|val| eval_lambda(env, &ident, Int(*val), (*body).clone())) + .map(|val| eval_lambda(env, &ident, Int(*val), body.clone())) .map(|expr| match expr { Ok(Primitive(inner)) => Ok(inner), Ok(_) => Err(Error::BadType(name)), @@ -973,7 +1317,7 @@ fn foreach(env: &Env, args: &[Expr]) -> Result { .collect::>>()?, ArrayType::Float(floats) => floats .iter() - .map(|val| eval_lambda(env, &ident, Float(*val), (*body).clone())) + .map(|val| eval_lambda(env, &ident, Float(*val), body.clone())) .map(|expr| match expr { Ok(Primitive(inner)) => Ok(inner), Ok(_) => Err(Error::BadType(name)), @@ -982,7 +1326,7 @@ fn foreach(env: &Env, args: &[Expr]) -> Result { .collect::>>()?, ArrayType::Bool(bools) => bools .iter() - .map(|val| eval_lambda(env, &ident, Bool(*val), (*body).clone())) + .map(|val| eval_lambda(env, &ident, Bool(*val), body.clone())) .map(|expr| match expr { Ok(Primitive(inner)) => Ok(inner), Ok(_) => Err(Error::BadType(name)), @@ -991,7 +1335,7 @@ fn foreach(env: &Env, args: &[Expr]) -> Result { .collect::>>()?, ArrayType::DateTime(dts) => dts .iter() - .map(|val| eval_lambda(env, &ident, DateTime(val.clone()), (*body).clone())) + .map(|val| eval_lambda(env, &ident, DateTime(val.clone()), body.clone())) .map(|expr| match expr { Ok(Primitive(inner)) => Ok(inner), Ok(_) => Err(Error::BadType(name)), @@ -1000,7 +1344,7 @@ fn foreach(env: &Env, args: &[Expr]) -> Result { .collect::>>()?, ArrayType::Span(spans) => spans .iter() - .map(|val| eval_lambda(env, &ident, Span(*val), (*body).clone())) + .map(|val| eval_lambda(env, &ident, Span(*val), body.clone())) .map(|expr| match expr { Ok(Primitive(inner)) => Ok(inner), Ok(_) => Err(Error::BadType(name)), diff --git a/hipcheck/src/policy_exprs/error.rs b/hipcheck/src/policy_exprs/error.rs index 6469cfd3..024bb2c1 100644 --- a/hipcheck/src/policy_exprs/error.rs +++ b/hipcheck/src/policy_exprs/error.rs @@ -1,6 +1,9 @@ // SPDX-License-Identifier: Apache-2.0 -use crate::policy_exprs::{Expr, Ident, LexingError}; +use crate::policy_exprs::{ + expr::{PrimitiveType, Type}, + Expr, Ident, LexingError, +}; use jiff::Error as JError; use nom::{error::ErrorKind, Needed}; use ordered_float::FloatIsNan; @@ -72,6 +75,21 @@ pub enum Error { #[error("called '{0}' with mismatched types")] BadType(&'static str), + #[error("call to '{name}' with '{got:?}' as argument {idx}, expected {expected}")] + BadFuncArgType { + name: String, + idx: usize, + expected: String, + got: Type, + }, + + #[error("array of {expected:?}s contains a {got:?} at idx {idx}")] + BadArrayElt { + idx: usize, + expected: PrimitiveType, + got: PrimitiveType, + }, + #[error("no max value found in array")] NoMax, diff --git a/hipcheck/src/policy_exprs/expr.rs b/hipcheck/src/policy_exprs/expr.rs index 019c0510..66301c82 100644 --- a/hipcheck/src/policy_exprs/expr.rs +++ b/hipcheck/src/policy_exprs/expr.rs @@ -15,7 +15,13 @@ use nom::{ Finish as _, IResult, }; use ordered_float::NotNan; -use std::{fmt::Display, ops::Deref}; +use std::{ + cmp::Ordering, + fmt::Display, + mem::{discriminant, Discriminant}, + ops::Deref, + sync::LazyLock, +}; #[cfg(test)] use jiff::civil::Date; @@ -55,15 +61,90 @@ impl From for Expr { } } +/// Helper type for operation function pointer. +pub type Op = fn(&Env, &[Expr]) -> Result; + +#[derive(Clone, PartialEq, Debug, Eq)] +pub struct OpInfo { + pub fn_ty: FuncReturnType, + pub expected_args: usize, + pub op: Op, +} + +pub type TypeChecker = fn(&[Type]) -> Result; + +#[derive(Clone, PartialEq, Debug, Eq)] +pub struct FunctionDef { + pub name: String, + pub expected_args: usize, + pub ty_checker: TypeChecker, + pub op: Op, +} +impl FunctionDef { + pub fn type_check(&self, args: &[Type]) -> Result { + match args.len().cmp(&self.expected_args) { + Ordering::Less => { + return Err(Error::NotEnoughArgs { + name: self.name.clone(), + expected: self.expected_args, + given: args.len(), + }); + } + Ordering::Greater => { + return Err(Error::TooManyArgs { + name: self.name.clone(), + expected: self.expected_args, + given: args.len(), + }); + } + _ => (), + } + let mut res = (self.ty_checker)(args); + // There's probably a better way to augment err with name + if let Err(Error::BadFuncArgType { name, .. }) = &mut res { + if name.is_empty() { + *name = self.name.clone(); + } + }; + res + } + pub fn execute(&self, env: &Env, args: &[Expr]) -> Result { + let types = args + .iter() + .map(|a| a.get_type()) + .collect::>>()?; + self.type_check(types.as_slice()); + (self.op)(env, args) + } +} + /// A `deke` function to evaluate. #[derive(Debug, PartialEq, Eq, Clone)] pub struct Function { pub ident: Ident, pub args: Vec, + pub opt_def: Option, } impl Function { pub fn new(ident: Ident, args: Vec) -> Self { - Function { ident, args } + let opt_def = None; + Function { + ident, + args, + opt_def, + } + } + pub fn resolve(&self, env: &Env) -> Result { + let Some(Binding::Fn(op_info)) = env.get(&self.ident.0) else { + return Err(Error::UnknownFunction(self.ident.0.clone())); + }; + let ident = self.ident.clone(); + let args = self.args.clone(); + Ok(Function { + ident, + args, + opt_def: Some(op_info), + }) } } impl From for Expr { @@ -71,15 +152,20 @@ impl From for Expr { Expr::Function(value) } } +impl From for Type { + fn from(value: FunctionType) -> Self { + Type::Function(value) + } +} /// Stores the name of the input variable, followed by the lambda body. #[derive(Debug, PartialEq, Eq, Clone)] pub struct Lambda { pub arg: Ident, - pub body: Box, + pub body: Function, } impl Lambda { - pub fn new(arg: Ident, body: Box) -> Self { + pub fn new(arg: Ident, body: Function) -> Self { Lambda { arg, body } } } @@ -132,6 +218,208 @@ impl From for Expr { } } +// TYPING + +impl Primitive { + pub fn get_primitive_type(&self) -> PrimitiveType { + use PrimitiveType::*; + match self { + Primitive::Identifier(_) => Ident, + Primitive::Int(_) => Int, + Primitive::Float(_) => Float, + Primitive::Bool(_) => Bool, + Primitive::DateTime(_) => DateTime, + Primitive::Span(_) => Span, + } + } +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum PrimitiveType { + Ident, + Int, + Float, + Bool, + DateTime, + Span, +} + +pub type ArrayType = Option; + +// A limited set of types that we allow a function to return +#[derive(Debug, Clone, PartialEq, Eq, Copy)] +pub enum ReturnableType { + Primitive(PrimitiveType), + Array(ArrayType), + Unknown, +} + +impl From for ReturnableType { + fn from(value: PrimitiveType) -> ReturnableType { + ReturnableType::Primitive(value) + } +} + +// We allow overloaded functions, such that the returned type is dependent on +// the input operand types. This enum encapsulates both static and dynamically +// determined return types. +#[derive(Debug, Clone, PartialEq, Eq, Copy)] +pub enum FuncReturnType { + Dynamic(fn(&[Type]) -> Result), + Static(ReturnableType), +} + +// A function signature is the combination of the return type and the arg types +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct FunctionType { + pub def: FunctionDef, + pub arg_tys: Vec, +} + +impl FunctionType { + pub fn get_return_type(&self) -> Result { + self.def.type_check(&self.arg_tys) + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum Type { + Primitive(PrimitiveType), + Function(FunctionType), + Lambda(FunctionType), + Array(ArrayType), + Unknown, +} + +impl Type { + pub fn get_return_type(&self) -> Result { + self.try_into() + } +} + +impl TryFrom<&Type> for ReturnableType { + type Error = crate::policy_exprs::Error; + fn try_from(value: &Type) -> Result { + Ok(match value { + Type::Function(fn_ty) | Type::Lambda(fn_ty) => fn_ty.get_return_type()?, + Type::Array(arr_ty) => ReturnableType::Array(*arr_ty), + Type::Primitive(PrimitiveType::Ident) => ReturnableType::Unknown, + Type::Primitive(p_ty) => ReturnableType::Primitive(*p_ty), + Type::Unknown => ReturnableType::Unknown, + }) + } +} + +impl Display for PrimitiveType { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "{:?}", self) + } +} + +pub trait Typed { + fn get_type(&self) -> Result; +} + +impl Typed for Primitive { + fn get_type(&self) -> Result { + Ok(Type::Primitive(self.get_primitive_type())) + } +} + +impl Typed for Array { + // Treat first found elt type as the de-facto type of the array. Any subsequent elts that + // disagree are considered errors + fn get_type(&self) -> Result { + let mut ty: Option = None; + + for (idx, elt) in self.elts.iter().enumerate() { + let curr_ty = elt.get_primitive_type(); + + if let Some(expected_ty) = ty { + if expected_ty != curr_ty { + return Err(Error::BadArrayElt { + idx, + expected: expected_ty, + got: curr_ty, + }); + } + } else { + ty = Some(elt.get_primitive_type()); + } + } + + Ok(Type::Array(ty)) + } +} + +impl Typed for Function { + fn get_type(&self) -> Result { + use FuncReturnType::*; + + // Can't get a type if we haven't resolved the function + let Some(def) = self.opt_def.clone() else { + return Err(Error::UnknownFunction(self.ident.0.clone())); + }; + + // Get types of each argument + let arg_tys: Vec = self + .args + .iter() + .map(Typed::get_type) + .collect::>>()?; + + let fn_type = FunctionType { def, arg_tys }; + + // If we are off by one, treat as a lambda + if fn_type.arg_tys.len() == fn_type.def.expected_args - 1 { + Ok(Type::Lambda(fn_type)) + } else { + Ok(fn_type.into()) + } + } +} + +impl Typed for Lambda { + // @Todo - Lambda should be a FunctionType that takes 1 argument and + // contains an interior reference to the function it wraps. + // To get its return type, we should combine Unknown with the + // other typed args to the function and evaluate. + fn get_type(&self) -> Result { + let fty = match self.body.get_type()? { + Type::Function(f) => f, + other => { + return Err(Error::InternalError(format!("Body of a lambda expr should be a function with a placeholder var, got {other:?}"))); + } + }; + + // we need a handle to the function to get a type + Ok(Type::Lambda(fty)) + } +} + +impl Typed for JsonPointer { + fn get_type(&self) -> Result { + if let Some(val) = self.value.as_ref() { + val.get_type() + } else { + Ok(Type::Unknown) + } + } +} + +impl Typed for Expr { + fn get_type(&self) -> Result { + use Expr::*; + match self { + Primitive(p) => p.get_type(), + Array(a) => a.get_type(), + Function(f) => f.get_type(), + Lambda(l) => l.get_type(), + JsonPointer(j) => j.get_type(), + } + } +} + /// A variable or function identifier. #[derive(Debug, Clone, PartialEq, Eq)] pub struct Ident(pub String); @@ -143,6 +431,7 @@ pub struct JsonPointer { pub pointer: String, pub value: Option>, } + impl From for Expr { fn from(value: JsonPointer) -> Self { Expr::JsonPointer(value) @@ -163,11 +452,8 @@ impl Display for Expr { array.elts.iter().map(ToString::to_string).join(" ") ) } - Expr::Function(func) => { - let args = func.args.iter().map(ToString::to_string).join(" "); - write!(f, "({} {})", func.ident, args) - } - Expr::Lambda(l) => write!(f, "(lambda ({}) {}", l.arg, l.body), + Expr::Function(func) => func.fmt(f), + Expr::Lambda(l) => write!(f, "(lambda ({}) {})", l.arg, l.body), Expr::JsonPointer(pointer) => write!(f, "${}", pointer.pointer), } } @@ -203,6 +489,13 @@ impl Primitive { } } +impl Display for Function { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let args = self.args.iter().map(ToString::to_string).join(" "); + write!(f, "({} {})", self.ident, args) + } +} + impl Display for Ident { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.0) diff --git a/hipcheck/src/policy_exprs/mod.rs b/hipcheck/src/policy_exprs/mod.rs index 1a9c21b4..fc9c7fec 100644 --- a/hipcheck/src/policy_exprs/mod.rs +++ b/hipcheck/src/policy_exprs/mod.rs @@ -14,8 +14,11 @@ use crate::policy_exprs::env::Env; pub(crate) use crate::policy_exprs::{bridge::Tokens, expr::F64}; pub use crate::policy_exprs::{ error::{Error, Result}, - expr::{Array, Expr, Function, Ident, JsonPointer, Lambda}, - pass::{ExprMutator, ExprVisitor}, + expr::{ + Array, Expr, Function, Ident, JsonPointer, Lambda, PrimitiveType, ReturnableType, Type, + Typed, + }, + pass::{ExprMutator, ExprVisitor, FunctionResolver, TypeChecker, TypeFixer}, token::LexingError, }; use env::Binding; @@ -57,17 +60,37 @@ impl ExprMutator for Env<'_> { Ok(prim.resolve(self)?.into()) } fn visit_function(&self, f: Function) -> Result { + let mut f = f; + // first evaluate all the children + f.args = f + .args + .into_iter() + .map(|a| self.visit_expr(a)) + .collect::>>()?; let binding = self .get(&f.ident) .ok_or_else(|| Error::UnknownFunction(f.ident.deref().to_owned()))?; - if let Binding::Fn(op) = binding { - (op)(self, &f.args) + if let Binding::Fn(op_info) = binding { + // Doesn't use `execute` because currently allows Functions that haven't been changed + // to Lambdas + (op_info.op)(self, &f.args) } else { Err(Error::FoundVarExpectedFunc(f.ident.deref().to_owned())) } } - fn visit_lambda(&self, l: Lambda) -> Result { - Ok((*l.body).clone()) + fn visit_lambda(&self, mut l: Lambda) -> Result { + // Eagerly evaluate the arguments to the lambda but not the func itself + // Visit args, but ignore lambda ident because not yet bound + l.body.args = l + .body + .args + .drain(..) + .map(|a| match a { + Expr::Primitive(Primitive::Identifier(_)) => Ok(a), + b => self.visit_expr(b), + }) + .collect::>>()?; + Ok(l.into()) } fn visit_json_pointer(&self, jp: JsonPointer) -> Result { let expr = &jp.value; @@ -215,6 +238,11 @@ mod tests { let program = "(eq 3 (count (filter (gt 8.0) (foreach (sub 1.0) [1.0 2.0 10.0 20.0 30.0]))))"; let context = Value::Null; + let expr = parse(&program).unwrap(); + println!("EXPR: {:?}", &expr); + let expr = FunctionResolver::std().run(expr).unwrap(); + let expr = TypeFixer::std().run(expr).unwrap(); + println!("RESOLVER RES: {:?}", expr); let result = Executor::std().parse_and_eval(program, &context).unwrap(); assert_eq!(result, Primitive::Bool(true).into()); } @@ -267,4 +295,83 @@ mod tests { .unwrap(); assert_eq!(expected, result2); } + + #[test] + fn type_lambda() { + let program = "(gt #t)"; + let expr = parse(&program).unwrap(); + let expr = FunctionResolver::std().run(expr).unwrap(); + let expr = TypeFixer::std().run(expr).unwrap(); + let res_ty = TypeChecker::default().run(&expr); + let Ok(Type::Lambda(l_ty)) = res_ty else { + assert!(false); + return; + }; + let ret_ty = l_ty.get_return_type(); + assert_eq!(ret_ty, Ok(ReturnableType::Primitive(PrimitiveType::Bool))); + } + + #[test] + fn type_filter_bad_lambda_array() { + // Should fail because can't compare ints and bools + let program = "(filter (gt #t) [1 2])"; + let expr = parse(&program).unwrap(); + let expr = FunctionResolver::std().run(expr).unwrap(); + let expr = TypeFixer::std().run(expr).unwrap(); + let res_ty = TypeChecker::default().run(&expr); + assert!(matches!( + res_ty, + Err(Error::BadFuncArgType { + idx: 0, + got: Type::Primitive(PrimitiveType::Int), + .. + }) + )); + } + + #[test] + fn type_array_mixed_types() { + // Should fail because array elts must have one primitive type + let program = "(count [#t 2])"; + let mut expr = parse(&program).unwrap(); + expr = FunctionResolver::std().run(expr).unwrap(); + let res_ty = TypeChecker::default().run(&expr); + assert_eq!( + res_ty, + Err(Error::BadArrayElt { + idx: 1, + expected: PrimitiveType::Bool, + got: PrimitiveType::Int + }) + ); + } + + #[test] + fn type_propagate_unknown() { + // Type for array should be unknown because we can't know ident type + let program = "(max [])"; + let mut expr = parse(&program).unwrap(); + expr = FunctionResolver::std().run(expr).unwrap(); + let res_ty = TypeChecker::default().run(&expr); + let Ok(Type::Function(f_ty)) = res_ty else { + assert!(false); + return; + }; + assert_eq!(f_ty.get_return_type(), Ok(ReturnableType::Unknown)); + } + + #[test] + fn type_not() { + let program = "(not $)"; + let mut expr = parse(&program).unwrap(); + expr = FunctionResolver::std().run(expr).unwrap(); + let res_ty = TypeChecker::default().run(&expr); + println!("RESTY: {res_ty:?}"); + let Ok(Type::Function(f_ty)) = res_ty else { + assert!(false); + return; + }; + let ret_ty = f_ty.get_return_type(); + assert_eq!(ret_ty, Ok(ReturnableType::Primitive(PrimitiveType::Bool))); + } } diff --git a/hipcheck/src/policy_exprs/pass.rs b/hipcheck/src/policy_exprs/pass.rs index 0c156e3c..16aa8630 100644 --- a/hipcheck/src/policy_exprs/pass.rs +++ b/hipcheck/src/policy_exprs/pass.rs @@ -1,7 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 use crate::policy_exprs::{ - env::Env, + env::{partially_evaluate, Env}, error::{Error, Result}, expr::*, }; @@ -30,9 +30,11 @@ pub trait ExprMutator { fn visit_primitive(&self, prim: Primitive) -> Result { Ok(prim.into()) } + fn visit_array(&self, arr: Array) -> Result { Ok(arr.into()) } + fn visit_function(&self, func: Function) -> Result { let mut func = func; func.args = func @@ -42,14 +44,21 @@ pub trait ExprMutator { .collect::>>()?; Ok(func.into()) } + fn visit_lambda(&self, lamb: Lambda) -> Result { let mut lamb = lamb; - lamb.body = Box::new(self.visit_expr(*lamb.body.clone())?); + lamb.body = match self.visit_function(lamb.body)? { + Expr::Function(f) => f, + // if the impl of `visit_function` returned a non-function, just return that + other => return Ok(other), + }; Ok(lamb.into()) } + fn visit_json_pointer(&self, jp: JsonPointer) -> Result { Ok(jp.into()) } + fn visit_expr(&self, expr: Expr) -> Result { match expr { Expr::Primitive(a) => self.visit_primitive(a), @@ -59,7 +68,115 @@ pub trait ExprMutator { Expr::JsonPointer(a) => self.visit_json_pointer(a), } } + fn run(&self, expr: Expr) -> Result { self.visit_expr(expr) } } + +pub struct FunctionResolver { + env: Env<'static>, +} + +impl FunctionResolver { + pub fn std() -> Self { + FunctionResolver { env: Env::std() } + } +} + +impl ExprMutator for FunctionResolver { + fn visit_function(&self, func: Function) -> Result { + let mut func = func.resolve(&self.env)?; + func.args = func + .args + .drain(..) + .map(|a| self.visit_expr(a)) + .collect::>>()?; + Ok(Expr::Function(func)) + } + + fn visit_lambda(&self, mut func: Lambda) -> Result { + let new_body = self.visit_function(func.body)?; + func.body = match new_body { + Expr::Function(f) => f, + other => { + return Err(Error::InternalError(format!( + "FunctionResolver's `visit_function` impl should always return a function" + ))); + } + }; + Ok(Expr::Lambda(func)) + } +} + +#[derive(Default)] +pub struct TypeChecker {} + +impl ExprVisitor> for TypeChecker { + fn visit_primitive(&self, prim: &Primitive) -> Result { + prim.get_type() + } + + fn visit_array(&self, arr: &Array) -> Result { + arr.get_type() + } + + fn visit_function(&self, func: &Function) -> Result { + func.args + .iter() + .map(|a| self.visit_expr(a)) + .collect::>>()?; + + let Type::Function(ft) = func.get_type()? else { + return Err(Error::InternalError( + "expression must have been run through TypeFixer pass first".to_owned(), + )); + }; + // Check that the arguments to the function are correct + ft.get_return_type()?; + Ok(ft.into()) + } + + fn visit_lambda(&self, lamb: &Lambda) -> Result { + self.visit_function(&lamb.body)?; + lamb.get_type() + } + + fn visit_json_pointer(&self, jp: &JsonPointer) -> Result { + jp.get_type() + } +} + +pub struct TypeFixer { + env: Env<'static>, +} + +impl TypeFixer { + pub fn std() -> Self { + TypeFixer { env: Env::std() } + } +} + +impl ExprMutator for TypeFixer { + fn visit_function(&self, mut func: Function) -> Result { + // @FollowUp - should the FunctionResolver be combined into this? + func.args = func + .args + .drain(..) + .map(|a| self.visit_expr(a)) + .collect::>>()?; + let fn_ty = func.get_type()?; + // At this point we know it has info + match fn_ty { + Type::Function(ft) => Ok(func.into()), + Type::Lambda(lt) => { + // Have to feed the new expr through the current pass again + // for any additional transformations + let res = partially_evaluate(&self.env, &func.ident.0, func.args.remove(0))?; + self.visit_expr(res) + } + + _ => unreachable!(), + } + } +}