From 7779be0bc4bf5a253113cf93dd781e5b95284835 Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Wed, 27 Apr 2022 10:25:24 -0400 Subject: [PATCH 01/14] Start refactor --- src/frontend/SignatureMismatch.ml | 2 +- src/frontend/Std_library_utils.ml | 38 + src/frontend/Typechecker.ml | 3363 ++++++++++++++------------- src/frontend/Typechecker.mli | 55 +- src/middle/Stan_math_signatures.ml | 15 +- src/middle/Stan_math_signatures.mli | 3 - 6 files changed, 1781 insertions(+), 1695 deletions(-) create mode 100644 src/frontend/Std_library_utils.ml diff --git a/src/frontend/SignatureMismatch.ml b/src/frontend/SignatureMismatch.ml index 7454d627fe..f204920ac3 100644 --- a/src/frontend/SignatureMismatch.ml +++ b/src/frontend/SignatureMismatch.ml @@ -398,7 +398,7 @@ let pp_math_lib_assignmentoperator_sigs ppf (lt, op) = | errors, _ -> Some (errors, true) in let pp_sigs ppf (signatures, omitted) = Fmt.pf ppf "@[%a%a@]" - (Fmt.list ~sep:Fmt.cut Stan_math_signatures.pp_math_sig) + (Fmt.list ~sep:Fmt.cut Std_library_utils.pp_math_sig) signatures (if omitted then Fmt.pf else Fmt.nop) "@ (Additional signatures omitted)" in diff --git a/src/frontend/Std_library_utils.ml b/src/frontend/Std_library_utils.ml new file mode 100644 index 0000000000..4b0628e9d8 --- /dev/null +++ b/src/frontend/Std_library_utils.ml @@ -0,0 +1,38 @@ +(** General functions and signatures for a Standard Library *) + +open Middle + +(* Types for the module representing the standard library *) +type fun_arg = UnsizedType.autodifftype * UnsizedType.t + +type signature = + UnsizedType.returntype * fun_arg list * Common.Helpers.mem_pattern + +type variadic_checker = + is_cond_dist:bool + -> Location_span.t + -> Environment.originblock + -> Environment.t + -> Ast.identifier + -> Ast.typed_expression list + -> Ast.typed_expression + +let pp_math_sig ppf (rt, args, mem_pattern) = + UnsizedType.pp ppf (UFun (args, rt, FnPlain, mem_pattern)) + +let pp_math_sigs ppf sigs = (Fmt.list ~sep:Fmt.cut pp_math_sig) ppf sigs +let pretty_print_math_sigs = Fmt.str "@[@,%a@]" pp_math_sigs + +module type Library = sig + val stan_math_signatures : (string, signature list) Hashtbl.t + (** Mapping from names to signature(s) of functions *) + + val distribution_families : string list + + val is_stan_math_function_name : string -> bool + (** Equivalent to [Hashtbl.mem stan_math_signatures s]*) + + val is_not_overloadable : string -> bool + val is_variadic_function_name : string -> bool + val operator_to_function_names : Operator.t -> string list +end diff --git a/src/frontend/Typechecker.ml b/src/frontend/Typechecker.ml index ab40113c23..ccea526f94 100644 --- a/src/frontend/Typechecker.ml +++ b/src/frontend/Typechecker.ml @@ -56,10 +56,10 @@ let context block = ; in_udf_dist_def= false ; loop_depth= 0 } -let calculate_autodifftype cf origin ut = +let calculate_autodifftype current_block origin ut = match origin with | Env.(Param | TParam | Model | Functions) - when not (UnsizedType.is_int_type ut || cf.current_block = GQuant) -> + when not (UnsizedType.is_int_type ut || current_block = Env.GQuant) -> UnsizedType.AutoDiffable | _ -> DataOnly @@ -83,1693 +83,1742 @@ let reserved_keywords = ; "get_lp"; "print"; "reject"; "typedef"; "struct"; "var"; "export"; "extern" ; "static"; "auto" ] -let verify_identifier id : unit = - if id.name = !model_name then - Semantic_error.ident_is_model_name id.id_loc id.name |> error - else if - String.is_suffix id.name ~suffix:"__" - || List.mem reserved_keywords id.name ~equal:String.equal - then Semantic_error.ident_is_keyword id.id_loc id.name |> error - -let distribution_name_variants name = - if name = "multiply_log" || name = "binomial_coefficient_log" then [name] - else - (* this will have some duplicates, but preserves order better *) - match Utils.split_distribution_suffix name with - | Some (stem, "lpmf") | Some (stem, "lpdf") | Some (stem, "log") -> - [name; stem ^ "_lpmf"; stem ^ "_lpdf"; stem ^ "_log"] - | Some (stem, "lcdf") | Some (stem, "cdf_log") -> - [name; stem ^ "_lcdf"; stem ^ "_cdf_log"] - | Some (stem, "lccdf") | Some (stem, "ccdf_log") -> - [name; stem ^ "_lccdf"; stem ^ "_ccdf_log"] - | _ -> [name] - -(** verify that the variable being declared is previous unused. +module type Typechecker = sig + val check_program_exn : untyped_program -> typed_program * Warnings.t list + (** + Type check a full Stan program. + Can raise [Errors.SemanticError] + *) + + val check_program : + untyped_program + -> (typed_program * Warnings.t list, Semantic_error.t) result + (** + The safe version of [check_program_exn]. This catches + all [Errors.SemanticError] exceptions and converts them + into a [Result.t] + *) + + val operator_stan_math_return_type : + Middle.Operator.t + -> (Middle.UnsizedType.autodifftype * Middle.UnsizedType.t) list + -> (Middle.UnsizedType.returntype * Promotion.t list) option + + val stan_math_return_type : + string + -> (Middle.UnsizedType.autodifftype * Middle.UnsizedType.t) list + -> Middle.UnsizedType.returntype option +end + +module Typecheck (StdLibrary : Std_library_utils.Library) : Typechecker = struct + let verify_identifier id : unit = + if id.name = !model_name then + Semantic_error.ident_is_model_name id.id_loc id.name |> error + else if + String.is_suffix id.name ~suffix:"__" + || List.mem reserved_keywords id.name ~equal:String.equal + then Semantic_error.ident_is_keyword id.id_loc id.name |> error + + let distribution_name_variants name = + if name = "multiply_log" || name = "binomial_coefficient_log" then [name] + else + (* this will have some duplicates, but preserves order better *) + match Utils.split_distribution_suffix name with + | Some (stem, "lpmf") | Some (stem, "lpdf") | Some (stem, "log") -> + [name; stem ^ "_lpmf"; stem ^ "_lpdf"; stem ^ "_log"] + | Some (stem, "lcdf") | Some (stem, "cdf_log") -> + [name; stem ^ "_lcdf"; stem ^ "_cdf_log"] + | Some (stem, "lccdf") | Some (stem, "ccdf_log") -> + [name; stem ^ "_lccdf"; stem ^ "_ccdf_log"] + | _ -> [name] + + (** verify that the variable being declared is previous unused. allowed to shadow StanLib *) -let verify_name_fresh_var loc tenv name = - if Utils.is_unnormalized_distribution name then - Semantic_error.ident_has_unnormalized_suffix loc name |> error - else if - List.exists (Env.find tenv name) ~f:(function - | {kind= `StanMath; _} -> - false (* user variables can shadow library names *) - | _ -> true ) - then Semantic_error.ident_in_use loc name |> error - -(** verify that the variable being declared is previous unused. *) -let verify_name_fresh_udf loc tenv name = - if - (* variadic functions are currently not in math sigs and aren't - overloadable due to their separate typechecking *) - Stan_math_signatures.is_reduce_sum_fn name - || Stan_math_signatures.is_variadic_ode_fn name - || Stan_math_signatures.is_variadic_dae_fn name - then Semantic_error.ident_is_stanmath_name loc name |> error - else if Utils.is_unnormalized_distribution name then - Semantic_error.udf_is_unnormalized_fn loc name |> error - else if - (* if a variable is already defined with this name - - not really possible as all functions are defined before data, - but future-proofing is good *) - List.exists - ~f:(function {kind= `Variable _; _} -> true | _ -> false) - (Env.find tenv name) - then Semantic_error.ident_in_use loc name |> error - -(** Checks that a variable/function name: + let verify_name_fresh_var loc tenv name = + if Utils.is_unnormalized_distribution name then + Semantic_error.ident_has_unnormalized_suffix loc name |> error + else if + List.exists (Env.find tenv name) ~f:(function + | {kind= `StanMath; _} -> + false (* user variables can shadow library names *) + | _ -> true ) + then Semantic_error.ident_in_use loc name |> error + + (** verify that the variable being declared is previous unused. *) + let verify_name_fresh_udf loc tenv name = + if + (* variadic functions are currently not in math sigs and aren't + overloadable due to their separate typechecking *) + StdLibrary.is_not_overloadable name + then Semantic_error.ident_is_stanmath_name loc name |> error + else if Utils.is_unnormalized_distribution name then + Semantic_error.udf_is_unnormalized_fn loc name |> error + else if + (* if a variable is already defined with this name + - not really possible as all functions are defined before data, + but future-proofing is good *) + List.exists + ~f:(function {kind= `Variable _; _} -> true | _ -> false) + (Env.find tenv name) + then Semantic_error.ident_in_use loc name |> error + + (** Checks that a variable/function name: - a function/identifier does not have the _lupdf/_lupmf suffix - is not already in use (for now) *) -let verify_name_fresh tenv id ~is_udf = - let f = - if is_udf then verify_name_fresh_udf id.id_loc tenv - else verify_name_fresh_var id.id_loc tenv in - List.iter ~f (distribution_name_variants id.name) - -let is_of_compatible_return_type rt1 srt2 = - UnsizedType.( - match (rt1, srt2) with - | Void, NoReturnType - |Void, Incomplete Void - |Void, Complete Void - |Void, AnyReturnType -> - true - | ReturnType UReal, Complete (ReturnType UInt) -> true - | ReturnType UComplex, Complete (ReturnType UReal) -> true - | ReturnType UComplex, Complete (ReturnType UInt) -> true - | ReturnType rt1, Complete (ReturnType rt2) -> rt1 = rt2 - | ReturnType _, AnyReturnType -> true - | _ -> false) - -(* -- Expressions ------------------------------------------------- *) -let check_ternary_if loc pe te fe = - let promote expr type_ ad_level = - if - (not (UnsizedType.equal expr.emeta.type_ type_)) - || UnsizedType.compare_autodifftype expr.emeta.ad_level ad_level <> 0 - then - { expr= Promotion (expr, UnsizedType.internal_scalar type_, ad_level) - ; emeta= {expr.emeta with type_; ad_level} } - else expr in - match - ( pe.emeta.type_ - , UnsizedType.common_type (te.emeta.type_, fe.emeta.type_) - , expr_ad_lub [pe; te; fe] ) - with - | UInt, Some type_, ad_level when not (UnsizedType.is_fun_type type_) -> - mk_typed_expression - ~expr: - (TernaryIf (pe, promote te type_ ad_level, promote fe type_ ad_level)) - ~ad_level ~type_ ~loc - | _, _, _ -> - Semantic_error.illtyped_ternary_if loc pe.emeta.type_ te.emeta.type_ - fe.emeta.type_ - |> error + let verify_name_fresh tenv id ~is_udf = + let f = + if is_udf then verify_name_fresh_udf id.id_loc tenv + else verify_name_fresh_var id.id_loc tenv in + List.iter ~f (distribution_name_variants id.name) + + let is_of_compatible_return_type rt1 srt2 = + UnsizedType.( + match (rt1, srt2) with + | Void, NoReturnType + |Void, Incomplete Void + |Void, Complete Void + |Void, AnyReturnType -> + true + | ReturnType UReal, Complete (ReturnType UInt) -> true + | ReturnType UComplex, Complete (ReturnType UReal) -> true + | ReturnType UComplex, Complete (ReturnType UInt) -> true + | ReturnType rt1, Complete (ReturnType rt2) -> rt1 = rt2 + | ReturnType _, AnyReturnType -> true + | _ -> false) + + (* -- Expressions ------------------------------------------------- *) + let check_ternary_if loc pe te fe = + let promote expr type_ ad_level = + if + (not (UnsizedType.equal expr.emeta.type_ type_)) + || UnsizedType.compare_autodifftype expr.emeta.ad_level ad_level <> 0 + then + { expr= Promotion (expr, UnsizedType.internal_scalar type_, ad_level) + ; emeta= {expr.emeta with type_; ad_level} } + else expr in + match + ( pe.emeta.type_ + , UnsizedType.common_type (te.emeta.type_, fe.emeta.type_) + , expr_ad_lub [pe; te; fe] ) + with + | UInt, Some type_, ad_level when not (UnsizedType.is_fun_type type_) -> + mk_typed_expression + ~expr: + (TernaryIf (pe, promote te type_ ad_level, promote fe type_ ad_level) + ) + ~ad_level ~type_ ~loc + | _, _, _ -> + Semantic_error.illtyped_ternary_if loc pe.emeta.type_ te.emeta.type_ + fe.emeta.type_ + |> error -let match_to_rt_option = function - | SignatureMismatch.UniqueMatch (rt, _, _) -> Some rt - | _ -> None - -let stan_math_return_type name arg_tys = - match name with - | x when Stan_math_signatures.is_reduce_sum_fn x -> - Some (UnsizedType.ReturnType UReal) - | x when Stan_math_signatures.is_variadic_ode_fn x -> - Some (UnsizedType.ReturnType (UArray UVector)) - | x when Stan_math_signatures.is_variadic_dae_fn x -> - Some (UnsizedType.ReturnType (UArray UVector)) - | _ -> - SignatureMismatch.matching_stanlib_function name arg_tys - |> match_to_rt_option - -let operator_stan_math_return_type op arg_tys = - match (op, arg_tys) with - | Operator.IntDivide, [(_, UnsizedType.UInt); (_, UInt)] -> - Some (UnsizedType.(ReturnType UInt), [Promotion.NoPromotion; NoPromotion]) - | IntDivide, _ -> None - | _ -> - Stan_math_signatures.operator_to_stan_math_fns op - |> List.filter_map ~f:(fun name -> - SignatureMismatch.matching_stanlib_function name arg_tys - |> function - | SignatureMismatch.UniqueMatch (rt, _, p) -> Some (rt, p) - | _ -> None ) - |> List.hd - -let assignmentoperator_stan_math_return_type assop arg_tys = - ( match assop with - | Operator.Divide -> - SignatureMismatch.matching_stanlib_function "divide" arg_tys - |> match_to_rt_option - | Plus | Minus | Times | EltTimes | EltDivide -> - operator_stan_math_return_type assop arg_tys |> Option.map ~f:fst - | _ -> None ) - |> Option.bind ~f:(function - | ReturnType rtype - when rtype = snd (List.hd_exn arg_tys) - && not - ( (assop = Operator.EltTimes || assop = Operator.EltDivide) - && UnsizedType.is_scalar_type rtype ) -> - Some UnsizedType.Void - | _ -> None ) - -let check_binop loc op le re = - let rt = [le; re] |> get_arg_types |> operator_stan_math_return_type op in - match rt with - | Some (ReturnType type_, [p1; p2]) -> - mk_typed_expression - ~expr:(BinOp (Promotion.promote le p1, op, Promotion.promote re p2)) - ~ad_level:(expr_ad_lub [le; re]) - ~type_ ~loc - | _ -> - Semantic_error.illtyped_binary_op loc op le.emeta.type_ re.emeta.type_ - |> error + let match_to_rt_option = function + | SignatureMismatch.UniqueMatch (rt, _, _) -> Some rt + | _ -> None -let check_prefixop loc op te = - let rt = operator_stan_math_return_type op [arg_type te] in - match rt with - | Some (ReturnType type_, _) -> - mk_typed_expression - ~expr:(PrefixOp (op, te)) - ~ad_level:(expr_ad_lub [te]) - ~type_ ~loc - | _ -> Semantic_error.illtyped_prefix_op loc op te.emeta.type_ |> error - -let check_postfixop loc op te = - let rt = operator_stan_math_return_type op [arg_type te] in - match rt with - | Some (ReturnType type_, _) -> - mk_typed_expression - ~expr:(PostfixOp (te, op)) - ~ad_level:(expr_ad_lub [te]) - ~type_ ~loc - | _ -> Semantic_error.illtyped_postfix_op loc op te.emeta.type_ |> error - -let check_id cf loc tenv id = - match Env.find tenv (Utils.stdlib_distribution_name id.name) with - | [] -> - Semantic_error.ident_not_in_scope loc id.name - (Env.nearest_ident tenv id.name) - |> error - | {kind= `StanMath; _} :: _ -> - ( calculate_autodifftype cf MathLibrary UMathLibraryFunction - , UnsizedType.UMathLibraryFunction ) - | {kind= `Variable {origin= Param | TParam | GQuant; _}; _} :: _ - when cf.in_toplevel_decl -> - Semantic_error.non_data_variable_size_decl loc |> error - | _ :: _ - when Utils.is_unnormalized_distribution id.name - && not - ( (cf.in_fun_def && (cf.in_udf_dist_def || cf.in_lp_fun_def)) - || cf.current_block = Model ) -> - Semantic_error.invalid_unnormalized_fn loc |> error - | {kind= `Variable {origin; _}; type_} :: _ -> - (calculate_autodifftype cf origin type_, type_) - | { kind= `UserDefined | `UserDeclared _ - ; type_= UFun (args, rt, FnLpdf _, mem_pattern) } - :: _ -> - let type_ = - UnsizedType.UFun - (args, rt, Fun_kind.suffix_from_name id.name, mem_pattern) in - (calculate_autodifftype cf Functions type_, type_) - | {kind= `UserDefined | `UserDeclared _; type_} :: _ -> - (calculate_autodifftype cf Functions type_, type_) - -let check_variable cf loc tenv id = - let ad_level, type_ = check_id cf loc tenv id in - mk_typed_expression ~expr:(Variable id) ~ad_level ~type_ ~loc - -let get_consistent_types ad_level type_ es = - let ad = - UnsizedType.lub_ad_type - (ad_level :: List.map ~f:(fun e -> e.emeta.ad_level) es) in - let f state e = - match state with - | Error e -> Error e - | Ok ty -> ( - match UnsizedType.common_type (ty, e.emeta.type_) with - | Some ty -> Ok ty - | None -> Error (ty, e.emeta) ) in - List.fold ~init:(Ok type_) ~f es - |> Result.map ~f:(fun ty -> - let promotions = - List.map (get_arg_types es) - ~f:(Promotion.get_type_promotion_exn (ad, ty)) in - (ad, ty, promotions) ) - -let check_array_expr loc es = - match es with - | [] -> Semantic_error.empty_array loc |> error - | {emeta= {ad_level; type_; _}; _} :: _ -> ( - match get_consistent_types ad_level type_ es with - | Error (ty, meta) -> - Semantic_error.mismatched_array_types meta.loc ty meta.type_ |> error - | Ok (ad_level, type_, promotions) -> - let type_ = UnsizedType.UArray type_ in - mk_typed_expression - ~expr:(ArrayExpr (Promotion.promote_list es promotions)) - ~ad_level ~type_ ~loc ) - -let check_rowvector loc es = - match es with - | {emeta= {ad_level; type_= UnsizedType.URowVector; _}; _} :: _ -> ( - match get_consistent_types ad_level URowVector es with - | Ok (ad_level, typ, promotions) -> + let stan_math_return_type name arg_tys = + match name with + | x when StdLibrary.is_variadic_function_name x -> Some (failwith "TODO") + | _ -> + SignatureMismatch.matching_stanlib_function name arg_tys + |> match_to_rt_option + + let operator_stan_math_return_type op arg_tys = + match (op, arg_tys) with + | Operator.IntDivide, [(_, UnsizedType.UInt); (_, UInt)] -> + Some + (UnsizedType.(ReturnType UInt), [Promotion.NoPromotion; NoPromotion]) + | IntDivide, _ -> None + | _ -> + StdLibrary.operator_to_function_names op + |> List.filter_map ~f:(fun name -> + SignatureMismatch.matching_stanlib_function name arg_tys + |> function + | SignatureMismatch.UniqueMatch (rt, _, p) -> Some (rt, p) + | _ -> None ) + |> List.hd + + let assignmentoperator_stan_math_return_type assop arg_tys = + ( match assop with + | Operator.Divide -> + SignatureMismatch.matching_stanlib_function "divide" arg_tys + |> match_to_rt_option + | Plus | Minus | Times | EltTimes | EltDivide -> + operator_stan_math_return_type assop arg_tys |> Option.map ~f:fst + | _ -> None ) + |> Option.bind ~f:(function + | ReturnType rtype + when rtype = snd (List.hd_exn arg_tys) + && not + ( (assop = Operator.EltTimes || assop = Operator.EltDivide) + && UnsizedType.is_scalar_type rtype ) -> + Some UnsizedType.Void + | _ -> None ) + + let check_binop loc op le re = + let rt = [le; re] |> get_arg_types |> operator_stan_math_return_type op in + match rt with + | Some (ReturnType type_, [p1; p2]) -> mk_typed_expression - ~expr:(RowVectorExpr (Promotion.promote_list es promotions)) - ~ad_level - ~type_:(if typ = UComplex then UComplexMatrix else UMatrix) - ~loc - | Error (_, meta) -> - Semantic_error.invalid_matrix_types meta.loc meta.type_ |> error ) - | {emeta= {ad_level; type_= UnsizedType.UComplexRowVector; _}; _} :: _ -> ( - match get_consistent_types ad_level UComplexRowVector es with - | Ok (ad_level, _, promotions) -> + ~expr:(BinOp (Promotion.promote le p1, op, Promotion.promote re p2)) + ~ad_level:(expr_ad_lub [le; re]) + ~type_ ~loc + | _ -> + Semantic_error.illtyped_binary_op loc op le.emeta.type_ re.emeta.type_ + |> error + + let check_prefixop loc op te = + let rt = operator_stan_math_return_type op [arg_type te] in + match rt with + | Some (ReturnType type_, _) -> mk_typed_expression - ~expr:(RowVectorExpr (Promotion.promote_list es promotions)) - ~ad_level ~type_:UComplexMatrix ~loc - | Error (_, meta) -> - Semantic_error.invalid_matrix_types meta.loc meta.type_ |> error ) - | _ -> ( - match get_consistent_types DataOnly UReal es with - | Ok (ad_level, typ, promotions) -> + ~expr:(PrefixOp (op, te)) + ~ad_level:(expr_ad_lub [te]) + ~type_ ~loc + | _ -> Semantic_error.illtyped_prefix_op loc op te.emeta.type_ |> error + + let check_postfixop loc op te = + let rt = operator_stan_math_return_type op [arg_type te] in + match rt with + | Some (ReturnType type_, _) -> mk_typed_expression - ~expr:(RowVectorExpr (Promotion.promote_list es promotions)) - ~ad_level - ~type_:(if typ = UComplex then UComplexRowVector else URowVector) - ~loc - | Error (_, meta) -> - Semantic_error.invalid_row_vector_types meta.loc meta.type_ |> error ) - -(* index checking *) - -let indexing_type idx = - match idx with - | Single {emeta= {type_= UnsizedType.UInt; _}; _} -> `Single - | _ -> `Multi - -let is_multiindex i = - match indexing_type i with `Single -> false | `Multi -> true - -let inferred_unsizedtype_of_indexed ~loc ut indices = - let rec aux type_ idcs = - let vec, rowvec, scalar = - if UnsizedType.is_complex_type type_ then - UnsizedType.(UComplexVector, UComplexRowVector, UComplex) - else (UVector, URowVector, UReal) in - match (type_, idcs) with - | _, [] -> type_ - | UnsizedType.UArray type_, `Single :: tl -> aux type_ tl - | UArray type_, `Multi :: tl -> aux type_ tl |> UnsizedType.UArray - | (UVector | URowVector | UComplexRowVector | UComplexVector), [`Single] - |(UMatrix | UComplexMatrix), [`Single; `Single] -> - scalar - | ( ( UVector | URowVector | UMatrix | UComplexVector | UComplexMatrix - | UComplexRowVector ) - , [`Multi] ) - |(UMatrix | UComplexMatrix), [`Multi; `Multi] -> - type_ - | (UMatrix | UComplexMatrix), ([`Single] | [`Single; `Multi]) -> rowvec - | (UMatrix | UComplexMatrix), [`Multi; `Single] -> vec - | (UMatrix | UComplexMatrix), _ :: _ :: _ :: _ - |(UVector | URowVector | UComplexRowVector | UComplexVector), _ :: _ :: _ - |(UInt | UReal | UComplex | UFun _ | UMathLibraryFunction), _ :: _ -> - Semantic_error.not_indexable loc ut (List.length indices) |> error in - aux ut (List.map ~f:indexing_type indices) - -let inferred_ad_type_of_indexed at uindices = - UnsizedType.lub_ad_type - ( at - :: List.map - ~f:(function - | All -> UnsizedType.DataOnly - | Single ue1 | Upfrom ue1 | Downfrom ue1 -> - UnsizedType.lub_ad_type [at; ue1.emeta.ad_level] - | Between (ue1, ue2) -> - UnsizedType.lub_ad_type - [at; ue1.emeta.ad_level; ue2.emeta.ad_level] ) - uindices ) - -(* function checking *) - -let verify_conddist_name loc id = - if - List.exists - ~f:(fun x -> String.is_suffix id.name ~suffix:x) - Utils.conditioning_suffices - then () - else Semantic_error.conditional_notation_not_allowed loc |> error - -let verify_fn_conditioning loc id = - if - List.exists - ~f:(fun suffix -> String.is_suffix id.name ~suffix) - Utils.conditioning_suffices - && not (String.is_suffix id.name ~suffix:"_cdf") - then Semantic_error.conditioning_required loc |> error - -(** `Target+=` can only be used in model and functions + ~expr:(PostfixOp (te, op)) + ~ad_level:(expr_ad_lub [te]) + ~type_ ~loc + | _ -> Semantic_error.illtyped_postfix_op loc op te.emeta.type_ |> error + + let check_id cf loc tenv id = + match Env.find tenv (Utils.stdlib_distribution_name id.name) with + | [] -> + Semantic_error.ident_not_in_scope loc id.name + (Env.nearest_ident tenv id.name) + |> error + | {kind= `StanMath; _} :: _ -> + ( calculate_autodifftype cf.current_block MathLibrary + UMathLibraryFunction + , UnsizedType.UMathLibraryFunction ) + | {kind= `Variable {origin= Param | TParam | GQuant; _}; _} :: _ + when cf.in_toplevel_decl -> + Semantic_error.non_data_variable_size_decl loc |> error + | _ :: _ + when Utils.is_unnormalized_distribution id.name + && not + ( (cf.in_fun_def && (cf.in_udf_dist_def || cf.in_lp_fun_def)) + || cf.current_block = Model ) -> + Semantic_error.invalid_unnormalized_fn loc |> error + | {kind= `Variable {origin; _}; type_} :: _ -> + (calculate_autodifftype cf.current_block origin type_, type_) + | { kind= `UserDefined | `UserDeclared _ + ; type_= UFun (args, rt, FnLpdf _, mem_pattern) } + :: _ -> + let type_ = + UnsizedType.UFun + (args, rt, Fun_kind.suffix_from_name id.name, mem_pattern) in + (calculate_autodifftype cf.current_block Functions type_, type_) + | {kind= `UserDefined | `UserDeclared _; type_} :: _ -> + (calculate_autodifftype cf.current_block Functions type_, type_) + + let check_variable cf loc tenv id = + let ad_level, type_ = check_id cf loc tenv id in + mk_typed_expression ~expr:(Variable id) ~ad_level ~type_ ~loc + + let get_consistent_types ad_level type_ es = + let ad = + UnsizedType.lub_ad_type + (ad_level :: List.map ~f:(fun e -> e.emeta.ad_level) es) in + let f state e = + match state with + | Error e -> Error e + | Ok ty -> ( + match UnsizedType.common_type (ty, e.emeta.type_) with + | Some ty -> Ok ty + | None -> Error (ty, e.emeta) ) in + List.fold ~init:(Ok type_) ~f es + |> Result.map ~f:(fun ty -> + let promotions = + List.map (get_arg_types es) + ~f:(Promotion.get_type_promotion_exn (ad, ty)) in + (ad, ty, promotions) ) + + let check_array_expr loc es = + match es with + | [] -> Semantic_error.empty_array loc |> error + | {emeta= {ad_level; type_; _}; _} :: _ -> ( + match get_consistent_types ad_level type_ es with + | Error (ty, meta) -> + Semantic_error.mismatched_array_types meta.loc ty meta.type_ |> error + | Ok (ad_level, type_, promotions) -> + let type_ = UnsizedType.UArray type_ in + mk_typed_expression + ~expr:(ArrayExpr (Promotion.promote_list es promotions)) + ~ad_level ~type_ ~loc ) + + let check_rowvector loc es = + match es with + | {emeta= {ad_level; type_= UnsizedType.URowVector; _}; _} :: _ -> ( + match get_consistent_types ad_level URowVector es with + | Ok (ad_level, typ, promotions) -> + mk_typed_expression + ~expr:(RowVectorExpr (Promotion.promote_list es promotions)) + ~ad_level + ~type_:(if typ = UComplex then UComplexMatrix else UMatrix) + ~loc + | Error (_, meta) -> + Semantic_error.invalid_matrix_types meta.loc meta.type_ |> error ) + | {emeta= {ad_level; type_= UnsizedType.UComplexRowVector; _}; _} :: _ -> ( + match get_consistent_types ad_level UComplexRowVector es with + | Ok (ad_level, _, promotions) -> + mk_typed_expression + ~expr:(RowVectorExpr (Promotion.promote_list es promotions)) + ~ad_level ~type_:UComplexMatrix ~loc + | Error (_, meta) -> + Semantic_error.invalid_matrix_types meta.loc meta.type_ |> error ) + | _ -> ( + match get_consistent_types DataOnly UReal es with + | Ok (ad_level, typ, promotions) -> + mk_typed_expression + ~expr:(RowVectorExpr (Promotion.promote_list es promotions)) + ~ad_level + ~type_:(if typ = UComplex then UComplexRowVector else URowVector) + ~loc + | Error (_, meta) -> + Semantic_error.invalid_row_vector_types meta.loc meta.type_ |> error ) + + (* index checking *) + + let indexing_type idx = + match idx with + | Single {emeta= {type_= UnsizedType.UInt; _}; _} -> `Single + | _ -> `Multi + + let is_multiindex i = + match indexing_type i with `Single -> false | `Multi -> true + + let inferred_unsizedtype_of_indexed ~loc ut indices = + let rec aux type_ idcs = + let vec, rowvec, scalar = + if UnsizedType.is_complex_type type_ then + UnsizedType.(UComplexVector, UComplexRowVector, UComplex) + else (UVector, URowVector, UReal) in + match (type_, idcs) with + | _, [] -> type_ + | UnsizedType.UArray type_, `Single :: tl -> aux type_ tl + | UArray type_, `Multi :: tl -> aux type_ tl |> UnsizedType.UArray + | (UVector | URowVector | UComplexRowVector | UComplexVector), [`Single] + |(UMatrix | UComplexMatrix), [`Single; `Single] -> + scalar + | ( ( UVector | URowVector | UMatrix | UComplexVector | UComplexMatrix + | UComplexRowVector ) + , [`Multi] ) + |(UMatrix | UComplexMatrix), [`Multi; `Multi] -> + type_ + | (UMatrix | UComplexMatrix), ([`Single] | [`Single; `Multi]) -> rowvec + | (UMatrix | UComplexMatrix), [`Multi; `Single] -> vec + | (UMatrix | UComplexMatrix), _ :: _ :: _ :: _ + |(UVector | URowVector | UComplexRowVector | UComplexVector), _ :: _ :: _ + |(UInt | UReal | UComplex | UFun _ | UMathLibraryFunction), _ :: _ -> + Semantic_error.not_indexable loc ut (List.length indices) |> error + in + aux ut (List.map ~f:indexing_type indices) + + let inferred_ad_type_of_indexed at uindices = + UnsizedType.lub_ad_type + ( at + :: List.map + ~f:(function + | All -> UnsizedType.DataOnly + | Single ue1 | Upfrom ue1 | Downfrom ue1 -> + UnsizedType.lub_ad_type [at; ue1.emeta.ad_level] + | Between (ue1, ue2) -> + UnsizedType.lub_ad_type + [at; ue1.emeta.ad_level; ue2.emeta.ad_level] ) + uindices ) + + (* function checking *) + + let verify_conddist_name loc id = + if + List.exists + ~f:(fun x -> String.is_suffix id.name ~suffix:x) + Utils.conditioning_suffices + then () + else Semantic_error.conditional_notation_not_allowed loc |> error + + let verify_fn_conditioning loc id = + if + List.exists + ~f:(fun suffix -> String.is_suffix id.name ~suffix) + Utils.conditioning_suffices + && not (String.is_suffix id.name ~suffix:"_cdf") + then Semantic_error.conditioning_required loc |> error + + (** `Target+=` can only be used in model and functions with right suffix (same for tilde etc) *) -let verify_fn_target_plus_equals cf loc id = - if - String.is_suffix id.name ~suffix:"_lp" - && not - ( cf.in_lp_fun_def || cf.current_block = Model - || cf.current_block = TParam ) - then Semantic_error.target_plusequals_outisde_model_or_logprob loc |> error - -(** Rng functions cannot be used in Tp or Model and only + let verify_fn_target_plus_equals cf loc id = + if + String.is_suffix id.name ~suffix:"_lp" + && not + ( cf.in_lp_fun_def || cf.current_block = Model + || cf.current_block = TParam ) + then Semantic_error.target_plusequals_outisde_model_or_logprob loc |> error + + (** Rng functions cannot be used in Tp or Model and only in function defs with the right suffix *) -let verify_fn_rng cf loc id = - if String.is_suffix id.name ~suffix:"_rng" && cf.in_toplevel_decl then - Semantic_error.invalid_decl_rng_fn loc |> error - else if - String.is_suffix id.name ~suffix:"_rng" - && ( (cf.in_fun_def && not cf.in_rng_fun_def) - || cf.current_block = TParam || cf.current_block = Model ) - then Semantic_error.invalid_rng_fn loc |> error - -(** unnormalized _lpdf/_lpmf functions can only be used in _lpdf/_lpmf/_lp udfs + let verify_fn_rng cf loc id = + if String.is_suffix id.name ~suffix:"_rng" && cf.in_toplevel_decl then + Semantic_error.invalid_decl_rng_fn loc |> error + else if + String.is_suffix id.name ~suffix:"_rng" + && ( (cf.in_fun_def && not cf.in_rng_fun_def) + || cf.current_block = TParam || cf.current_block = Model ) + then Semantic_error.invalid_rng_fn loc |> error + + (** unnormalized _lpdf/_lpmf functions can only be used in _lpdf/_lpmf/_lp udfs or the model block *) -let verify_unnormalized cf loc id = - if - Utils.is_unnormalized_distribution id.name - && not ((cf.in_fun_def && cf.in_udf_dist_def) || cf.current_block = Model) - then Semantic_error.invalid_unnormalized_fn loc |> error - -let mk_fun_app ~is_cond_dist (x, y, z) = - if is_cond_dist then CondDistApp (x, y, z) else FunApp (x, y, z) - -let check_normal_fn ~is_cond_dist loc tenv id es = - match Env.find tenv (Utils.normalized_name id.name) with - | {kind= `Variable _; _} :: _ - (* variables can sometimes shadow stanlib functions, so we have to check this *) - when not - (Stan_math_signatures.is_stan_math_function_name - (Utils.normalized_name id.name) ) -> - Semantic_error.returning_fn_expected_nonfn_found loc id.name |> error - | [] -> - ( match Utils.split_distribution_suffix id.name with - | Some (prefix, suffix) -> ( - let known_families = - List.map - ~f:(fun (_, y, _, _) -> y) - Stan_math_signatures.distributions in - let is_known_family s = - List.mem known_families s ~equal:String.equal in - match suffix with - | ("lpmf" | "lumpf") when Env.mem tenv (prefix ^ "_lpdf") -> - Semantic_error.returning_fn_expected_wrong_dist_suffix_found loc - (prefix, suffix) - | ("lpdf" | "lumdf") when Env.mem tenv (prefix ^ "_lpmf") -> - Semantic_error.returning_fn_expected_wrong_dist_suffix_found loc - (prefix, suffix) - | _ -> - if - is_known_family prefix - && List.mem ~equal:String.equal - Utils.cumulative_distribution_suffices_w_rng suffix - then - Semantic_error - .returning_fn_expected_undeclared_dist_suffix_found loc + let verify_unnormalized cf loc id = + if + Utils.is_unnormalized_distribution id.name + && not ((cf.in_fun_def && cf.in_udf_dist_def) || cf.current_block = Model) + then Semantic_error.invalid_unnormalized_fn loc |> error + + let mk_fun_app ~is_cond_dist (x, y, z) = + if is_cond_dist then CondDistApp (x, y, z) else FunApp (x, y, z) + + let check_normal_fn ~is_cond_dist loc tenv id es = + match Env.find tenv (Utils.normalized_name id.name) with + | {kind= `Variable _; _} :: _ + (* variables can sometimes shadow stanlib functions, so we have to check this *) + when not + (StdLibrary.is_stan_math_function_name + (Utils.normalized_name id.name) ) -> + Semantic_error.returning_fn_expected_nonfn_found loc id.name |> error + | [] -> + ( match Utils.split_distribution_suffix id.name with + | Some (prefix, suffix) -> ( + let known_families = StdLibrary.distribution_families in + let is_known_family s = + List.mem known_families s ~equal:String.equal in + match suffix with + | ("lpmf" | "lumpf") when Env.mem tenv (prefix ^ "_lpdf") -> + Semantic_error.returning_fn_expected_wrong_dist_suffix_found loc (prefix, suffix) - else - Semantic_error.returning_fn_expected_undeclaredident_found loc - id.name - (Env.nearest_ident tenv id.name) ) - | None -> - Semantic_error.returning_fn_expected_undeclaredident_found loc id.name - (Env.nearest_ident tenv id.name) ) - |> error - | _ (* a function *) -> ( - (* NB: At present, [SignatureMismatch.matching_function] cannot handle overloaded function types. - This is not needed until UDFs can be higher-order, as it is special cased for - variadic functions - *) - match - SignatureMismatch.matching_function tenv id.name (get_arg_types es) - with - | UniqueMatch (Void, _, _) -> - Semantic_error.returning_fn_expected_nonreturning_found loc id.name - |> error - | UniqueMatch (ReturnType ut, fnk, promotions) -> - mk_typed_expression - ~expr: - (mk_fun_app ~is_cond_dist - ( fnk (Fun_kind.suffix_from_name id.name) - , id - , Promotion.promote_list es promotions ) ) - ~ad_level:(expr_ad_lub es) ~type_:ut ~loc - | AmbiguousMatch sigs -> - Semantic_error.ambiguous_function_promotion loc id.name - (Some (List.map ~f:type_of_expr_typed es)) - sigs + | ("lpdf" | "lumdf") when Env.mem tenv (prefix ^ "_lpmf") -> + Semantic_error.returning_fn_expected_wrong_dist_suffix_found loc + (prefix, suffix) + | _ -> + if + is_known_family prefix + && List.mem ~equal:String.equal + Utils.cumulative_distribution_suffices_w_rng suffix + then + Semantic_error + .returning_fn_expected_undeclared_dist_suffix_found loc + (prefix, suffix) + else + Semantic_error.returning_fn_expected_undeclaredident_found loc + id.name + (Env.nearest_ident tenv id.name) ) + | None -> + Semantic_error.returning_fn_expected_undeclaredident_found loc + id.name + (Env.nearest_ident tenv id.name) ) |> error - | SignatureErrors (l, b) -> - es - |> List.map ~f:(fun e -> e.emeta.type_) - |> Semantic_error.illtyped_fn_app loc id.name (l, b) - |> error ) + | _ (* a function *) -> ( + (* NB: At present, [SignatureMismatch.matching_function] cannot handle overloaded function types. + This is not needed until UDFs can be higher-order, as it is special cased for + variadic functions + *) + match + SignatureMismatch.matching_function tenv id.name (get_arg_types es) + with + | UniqueMatch (Void, _, _) -> + Semantic_error.returning_fn_expected_nonreturning_found loc id.name + |> error + | UniqueMatch (ReturnType ut, fnk, promotions) -> + mk_typed_expression + ~expr: + (mk_fun_app ~is_cond_dist + ( fnk (Fun_kind.suffix_from_name id.name) + , id + , Promotion.promote_list es promotions ) ) + ~ad_level:(expr_ad_lub es) ~type_:ut ~loc + | AmbiguousMatch sigs -> + Semantic_error.ambiguous_function_promotion loc id.name + (Some (List.map ~f:type_of_expr_typed es)) + sigs + |> error + | SignatureErrors (l, b) -> + es + |> List.map ~f:(fun e -> e.emeta.type_) + |> Semantic_error.illtyped_fn_app loc id.name (l, b) + |> error ) -(** Given a constraint function [matches], find any signature which exists + (** Given a constraint function [matches], find any signature which exists Returns the first [Ok] if any exist, or else [Error] *) -let find_matching_first_order_fn tenv matches fname = - let candidates = - Utils.stdlib_distribution_name fname.name - |> Env.find tenv |> List.map ~f:matches in - let ok, errs = List.partition_map candidates ~f:Result.to_either in - match SignatureMismatch.unique_minimum_promotion ok with - | Ok a -> SignatureMismatch.UniqueMatch a - | Error (Some promotions) -> - List.filter_map promotions ~f:(function - | UnsizedType.UFun (args, rt, _, _) -> Some (rt, args) - | _ -> None ) - |> AmbiguousMatch - | Error None -> SignatureMismatch.SignatureErrors (List.hd_exn errs) - -let make_function_variable cf loc id = function - | UnsizedType.UFun (args, rt, FnLpdf _, mem_pattern) -> - let type_ = - UnsizedType.UFun - (args, rt, Fun_kind.suffix_from_name id.name, mem_pattern) in - mk_typed_expression ~expr:(Variable id) - ~ad_level:(calculate_autodifftype cf Functions type_) - ~type_ ~loc - | UnsizedType.UFun _ as type_ -> - mk_typed_expression ~expr:(Variable id) - ~ad_level:(calculate_autodifftype cf Functions type_) - ~type_ ~loc - | type_ -> - Common.FatalError.fatal_error_msg - [%message - "Attempting to create function variable out of " - (type_ : UnsizedType.t)] - -let rec check_fn ~is_cond_dist loc cf tenv id (tes : Ast.typed_expression list) - = - if Stan_math_signatures.is_reduce_sum_fn id.name then - check_reduce_sum ~is_cond_dist loc cf tenv id tes - else if Stan_math_signatures.is_variadic_ode_fn id.name then - check_variadic_ode ~is_cond_dist loc cf tenv id tes - else if Stan_math_signatures.is_variadic_dae_fn id.name then - check_variadic_dae ~is_cond_dist loc cf tenv id tes - else check_normal_fn ~is_cond_dist loc tenv id tes - -and check_reduce_sum ~is_cond_dist loc cf tenv id tes = - let basic_mismatch () = - let mandatory_args = - UnsizedType.[(AutoDiffable, UArray UReal); (AutoDiffable, UInt)] in - let mandatory_fun_args = - UnsizedType. - [(AutoDiffable, UArray UReal); (DataOnly, UInt); (DataOnly, UInt)] in - SignatureMismatch.check_variadic_args true mandatory_args mandatory_fun_args - UReal (get_arg_types tes) in - let fail () = - let expected_args, err = - basic_mismatch () |> Result.error |> Option.value_exn in - Semantic_error.illtyped_reduce_sum_generic loc id.name - (List.map ~f:type_of_expr_typed tes) - expected_args err - |> error in - let matching remaining_es fn = - match fn with - | Env. - { type_= - UnsizedType.UFun - (((_, sliced_arg_fun_type) as sliced_arg_fun) :: _, _, _, _) as - ftype - ; _ } - when List.mem Stan_math_signatures.reduce_sum_slice_types - sliced_arg_fun_type ~equal:( = ) -> - let mandatory_args = [sliced_arg_fun; (AutoDiffable, UInt)] in - let mandatory_fun_args = - [sliced_arg_fun; (DataOnly, UInt); (DataOnly, UInt)] in - let arg_types = - (calculate_autodifftype cf Functions ftype, ftype) - :: get_arg_types remaining_es in - SignatureMismatch.check_variadic_args true mandatory_args - mandatory_fun_args UReal arg_types - | _ -> basic_mismatch () in - match tes with - | {expr= Variable fname; _} :: remaining_es -> ( - match find_matching_first_order_fn tenv (matching remaining_es) fname with - | SignatureMismatch.UniqueMatch (ftype, promotions) -> - (* a valid signature exists *) - let tes = make_function_variable cf loc fname ftype :: remaining_es in - mk_typed_expression - ~expr: - (mk_fun_app ~is_cond_dist - (StanLib FnPlain, id, Promotion.promote_list tes promotions) ) - ~ad_level:(expr_ad_lub tes) ~type_:UnsizedType.UReal ~loc - | AmbiguousMatch ps -> - Semantic_error.ambiguous_function_promotion loc fname.name None ps - |> error - | SignatureErrors (expected_args, err) -> - Semantic_error.illtyped_reduce_sum loc id.name - (List.map ~f:type_of_expr_typed tes) - expected_args err - |> error ) - | _ -> fail () - -and check_variadic_ode ~is_cond_dist loc cf tenv id tes = - let optional_tol_mandatory_args = - if Stan_math_signatures.variadic_ode_adjoint_fn = id.name then - Stan_math_signatures.variadic_ode_adjoint_ctl_tol_arg_types - else if Stan_math_signatures.is_variadic_ode_nonadjoint_tol_fn id.name then - Stan_math_signatures.variadic_ode_tol_arg_types - else [] in - let mandatory_arg_types = - Stan_math_signatures.variadic_ode_mandatory_arg_types - @ optional_tol_mandatory_args in - let fail () = - let expected_args, err = - SignatureMismatch.check_variadic_args false mandatory_arg_types - Stan_math_signatures.variadic_ode_mandatory_fun_args - Stan_math_signatures.variadic_ode_fun_return_type (get_arg_types tes) - |> Result.error |> Option.value_exn in - Semantic_error.illtyped_variadic_ode loc id.name - (List.map ~f:type_of_expr_typed tes) - expected_args err - |> error in - let matching remaining_es Env.{type_= ftype; _} = - let arg_types = - (calculate_autodifftype cf Functions ftype, ftype) - :: get_arg_types remaining_es in - SignatureMismatch.check_variadic_args false mandatory_arg_types - Stan_math_signatures.variadic_ode_mandatory_fun_args - Stan_math_signatures.variadic_ode_fun_return_type arg_types in - match tes with - | {expr= Variable fname; _} :: remaining_es -> ( - match find_matching_first_order_fn tenv (matching remaining_es) fname with - | SignatureMismatch.UniqueMatch (ftype, promotions) -> - let tes = make_function_variable cf loc fname ftype :: remaining_es in - mk_typed_expression - ~expr: - (mk_fun_app ~is_cond_dist - (StanLib FnPlain, id, Promotion.promote_list tes promotions) ) - ~ad_level:(expr_ad_lub tes) - ~type_:Stan_math_signatures.variadic_ode_return_type ~loc - | AmbiguousMatch ps -> - Semantic_error.ambiguous_function_promotion loc fname.name None ps - |> error - | SignatureErrors (expected_args, err) -> - Semantic_error.illtyped_variadic_ode loc id.name - (List.map ~f:type_of_expr_typed tes) - expected_args err - |> error ) - | _ -> fail () - -and check_variadic_dae ~is_cond_dist loc cf tenv id tes = - let optional_tol_mandatory_args = - if Stan_math_signatures.is_variadic_dae_tol_fn id.name then - Stan_math_signatures.variadic_dae_tol_arg_types - else [] in - let mandatory_arg_types = - Stan_math_signatures.variadic_dae_mandatory_arg_types - @ optional_tol_mandatory_args in - let fail () = - let expected_args, err = - SignatureMismatch.check_variadic_args false mandatory_arg_types - Stan_math_signatures.variadic_dae_mandatory_fun_args - Stan_math_signatures.variadic_dae_fun_return_type (get_arg_types tes) - |> Result.error |> Option.value_exn in - Semantic_error.illtyped_variadic_dae loc id.name - (List.map ~f:type_of_expr_typed tes) - expected_args err - |> error in - let matching remaining_es Env.{type_= ftype; _} = - let arg_types = - (calculate_autodifftype cf Functions ftype, ftype) - :: get_arg_types remaining_es in - SignatureMismatch.check_variadic_args false mandatory_arg_types - Stan_math_signatures.variadic_dae_mandatory_fun_args - Stan_math_signatures.variadic_dae_fun_return_type arg_types in - match tes with - | {expr= Variable fname; _} :: remaining_es -> ( - match find_matching_first_order_fn tenv (matching remaining_es) fname with - | SignatureMismatch.UniqueMatch (ftype, promotions) -> - let tes = make_function_variable cf loc fname ftype :: remaining_es in - mk_typed_expression - ~expr: - (mk_fun_app ~is_cond_dist - (StanLib FnPlain, id, Promotion.promote_list tes promotions) ) - ~ad_level:(expr_ad_lub tes) - ~type_:Stan_math_signatures.variadic_dae_return_type ~loc - | AmbiguousMatch ps -> - Semantic_error.ambiguous_function_promotion loc fname.name None ps - |> error - | SignatureErrors (expected_args, err) -> - Semantic_error.illtyped_variadic_dae loc id.name - (List.map ~f:type_of_expr_typed tes) - expected_args err - |> error ) - | _ -> fail () - -and check_funapp loc cf tenv ~is_cond_dist id (es : Ast.typed_expression list) = - let name_check = - if is_cond_dist then verify_conddist_name else verify_fn_conditioning in - let res = check_fn ~is_cond_dist loc cf tenv id es in - verify_identifier id ; - name_check loc id ; - verify_fn_target_plus_equals cf loc id ; - verify_fn_rng cf loc id ; - verify_unnormalized cf loc id ; - res - -and check_indexed loc cf tenv e indices = - let tindices = List.map ~f:(check_index cf tenv) indices in - let te = check_expression cf tenv e in - let ad_level = inferred_ad_type_of_indexed te.emeta.ad_level tindices in - let type_ = inferred_unsizedtype_of_indexed ~loc te.emeta.type_ tindices in - mk_typed_expression ~expr:(Indexed (te, tindices)) ~ad_level ~type_ ~loc - -and check_index cf tenv = function - | All -> All - (* Check that indexes have int (container) type *) - | Single e -> - let te = check_expression cf tenv e in - if has_int_type te || has_int_array_type te then Single te - else - Semantic_error.int_intarray_or_range_expected te.emeta.loc - te.emeta.type_ - |> error - | Upfrom e -> check_expression_of_int_type cf tenv e "Range bound" |> Upfrom - | Downfrom e -> - check_expression_of_int_type cf tenv e "Range bound" |> Downfrom - | Between (e1, e2) -> - let le = check_expression_of_int_type cf tenv e1 "Range bound" in - let ue = check_expression_of_int_type cf tenv e2 "Range bound" in - Between (le, ue) - -and check_expression cf tenv ({emeta; expr} : Ast.untyped_expression) : - Ast.typed_expression = - let loc = emeta.loc in - let ce = check_expression cf tenv in - match expr with - | TernaryIf (e1, e2, e3) -> - let pe = ce e1 in - let te = ce e2 in - let fe = ce e3 in - check_ternary_if loc pe te fe - | BinOp (e1, op, e2) -> - let le = ce e1 in - let re = ce e2 in - let binop_type_warnings x y = - match (x.emeta.type_, y.emeta.type_, op) with - | UInt, UInt, Divide -> - let hint ppf () = - match (x.expr, y.expr) with - | IntNumeral x, _ -> - Fmt.pf ppf "%s.0 / %a" x Pretty_printing.pp_typed_expression y - | _, Ast.IntNumeral y -> - Fmt.pf ppf "%a / %s.0" Pretty_printing.pp_typed_expression x y - | _ -> - Fmt.pf ppf "%a * 1.0 / %a" Pretty_printing.pp_typed_expression - x Pretty_printing.pp_typed_expression y in - let s = - Fmt.str - "@[@[Found int division:@]@ @[%a@]@,\ - @[%a@]@ @[%a@]@,\ - @[%a@]@]" - Pretty_printing.pp_expression {expr; emeta} Fmt.text - "Values will be rounded towards zero. If rounding is not \ - desired you can write the division as" - hint () Fmt.text - "If rounding is intended please use the integer division \ - operator %/%." in - add_warning x.emeta.loc s - | (UArray UMatrix | UMatrix), (UInt | UReal), Pow -> - let s = - Fmt.str - "@[@[Found matrix^scalar:@]@ @[%a@]@,\ - @[%a@]@ @[%a@]@]" Pretty_printing.pp_expression - {expr; emeta} Fmt.text - "matrix ^ number is interpreted as element-wise \ - exponentiation. If this is intended, you can silence this \ - warning by using elementwise operator .^" - Fmt.text - "If you intended matrix exponentiation, use the function \ - matrix_power(matrix,int) instead." in - add_warning x.emeta.loc s - | _ -> () in - binop_type_warnings le re ; check_binop loc op le re - | PrefixOp (op, e) -> ce e |> check_prefixop loc op - | PostfixOp (e, op) -> ce e |> check_postfixop loc op - | Variable id -> - verify_identifier id ; - check_variable cf loc tenv id - | IntNumeral s -> ( - match float_of_string_opt s with - | Some i when i < 2_147_483_648.0 -> - mk_typed_expression ~expr:(IntNumeral s) ~ad_level:DataOnly ~type_:UInt - ~loc - | _ -> Semantic_error.bad_int_literal loc |> error ) - | RealNumeral s -> - mk_typed_expression ~expr:(RealNumeral s) ~ad_level:DataOnly ~type_:UReal - ~loc - | ImagNumeral s -> - mk_typed_expression ~expr:(ImagNumeral s) ~ad_level:DataOnly - ~type_:UComplex ~loc - | GetLP -> - (* Target+= can only be used in model and functions with right suffix (same for tilde etc) *) - if - not - ( cf.in_lp_fun_def || cf.current_block = Model - || cf.current_block = TParam ) - then - Semantic_error.target_plusequals_outisde_model_or_logprob loc |> error - else - mk_typed_expression ~expr:GetLP - ~ad_level:(calculate_autodifftype cf cf.current_block UReal) - ~type_:UReal ~loc - | GetTarget -> - (* Target+= can only be used in model and functions with right suffix (same for tilde etc) *) - if - not - ( cf.in_lp_fun_def || cf.current_block = Model - || cf.current_block = TParam ) - then - Semantic_error.target_plusequals_outisde_model_or_logprob loc |> error - else - mk_typed_expression ~expr:GetTarget - ~ad_level:(calculate_autodifftype cf cf.current_block UReal) + (* let find_matching_first_order_fn tenv matches fname = + let candidates = + Utils.stdlib_distribution_name fname.name + |> Env.find tenv |> List.map ~f:matches in + let ok, errs = List.partition_map candidates ~f:Result.to_either in + match SignatureMismatch.unique_minimum_promotion ok with + | Ok a -> SignatureMismatch.UniqueMatch a + | Error (Some promotions) -> + List.filter_map promotions ~f:(function + | UnsizedType.UFun (args, rt, _, _) -> Some (rt, args) + | _ -> None ) + |> AmbiguousMatch + | Error None -> SignatureMismatch.SignatureErrors (List.hd_exn errs) *) + + (* let make_function_variable current_block loc id = function + | UnsizedType.UFun (args, rt, FnLpdf _, mem_pattern) -> + let type_ = + UnsizedType.UFun + (args, rt, Fun_kind.suffix_from_name id.name, mem_pattern) in + mk_typed_expression ~expr:(Variable id) + ~ad_level:(calculate_autodifftype current_block Functions type_) + ~type_ ~loc + | UnsizedType.UFun _ as type_ -> + mk_typed_expression ~expr:(Variable id) + ~ad_level:(calculate_autodifftype current_block Functions type_) + ~type_ ~loc + | type_ -> + Common.FatalError.fatal_error_msg + [%message + "Attempting to create function variable out of " + (type_ : UnsizedType.t)] *) + + let rec check_fn ~is_cond_dist loc cf tenv id (tes : Ast.typed_expression list) + = + if StdLibrary.is_variadic_function_name id.name then ( + Stdlib.ignore cf ; + failwith "TODO" + (* if StdLibrary.is_reduce_sum_fn id.name then + check_reduce_sum ~is_cond_dist loc cf tenv id tes + else if StdLibrary.is_variadic_ode_fn id.name then + check_variadic_ode ~is_cond_dist loc cf tenv id tes + else if StdLibrary.is_variadic_dae_fn id.name then + check_variadic_dae ~is_cond_dist loc cf tenv id tes *) ) + else check_normal_fn ~is_cond_dist loc tenv id tes + + (* and check_reduce_sum ~is_cond_dist loc cf tenv id tes = + let basic_mismatch () = + let mandatory_args = + UnsizedType.[(AutoDiffable, UArray UReal); (AutoDiffable, UInt)] in + let mandatory_fun_args = + UnsizedType. + [(AutoDiffable, UArray UReal); (DataOnly, UInt); (DataOnly, UInt)] + in + SignatureMismatch.check_variadic_args true mandatory_args + mandatory_fun_args UReal (get_arg_types tes) in + let fail () = + let expected_args, err = + basic_mismatch () |> Result.error |> Option.value_exn in + Semantic_error.illtyped_reduce_sum_generic loc id.name + (List.map ~f:type_of_expr_typed tes) + expected_args err + |> error in + let matching remaining_es fn = + match fn with + | Env. + { type_= + UnsizedType.UFun + (((_, sliced_arg_fun_type) as sliced_arg_fun) :: _, _, _, _) as + ftype + ; _ } + when List.mem StdLibrary.reduce_sum_slice_types sliced_arg_fun_type + ~equal:( = ) -> + let mandatory_args = [sliced_arg_fun; (AutoDiffable, UInt)] in + let mandatory_fun_args = + [sliced_arg_fun; (DataOnly, UInt); (DataOnly, UInt)] in + let arg_types = + (calculate_autodifftype cf.current_block Functions ftype, ftype) + :: get_arg_types remaining_es in + SignatureMismatch.check_variadic_args true mandatory_args + mandatory_fun_args UReal arg_types + | _ -> basic_mismatch () in + match tes with + | {expr= Variable fname; _} :: remaining_es -> ( + match find_matching_first_order_fn tenv (matching remaining_es) fname with + | SignatureMismatch.UniqueMatch (ftype, promotions) -> + (* a valid signature exists *) + let tes = make_function_variable cf loc fname ftype :: remaining_es in + mk_typed_expression + ~expr: + (mk_fun_app ~is_cond_dist + (StanLib FnPlain, id, Promotion.promote_list tes promotions) ) + ~ad_level:(expr_ad_lub tes) ~type_:UnsizedType.UReal ~loc + | AmbiguousMatch ps -> + Semantic_error.ambiguous_function_promotion loc fname.name None ps + |> error + | SignatureErrors (expected_args, err) -> + Semantic_error.illtyped_reduce_sum loc id.name + (List.map ~f:type_of_expr_typed tes) + expected_args err + |> error ) + | _ -> fail () + + and check_variadic_ode ~is_cond_dist loc cf tenv id tes = + let optional_tol_mandatory_args = + if StdLibrary.variadic_ode_adjoint_fn = id.name then + StdLibrary.variadic_ode_adjoint_ctl_tol_arg_types + else if StdLibrary.is_variadic_ode_nonadjoint_tol_fn id.name then + StdLibrary.variadic_ode_tol_arg_types + else [] in + let mandatory_arg_types = + StdLibrary.variadic_ode_mandatory_arg_types @ optional_tol_mandatory_args + in + let fail () = + let expected_args, err = + SignatureMismatch.check_variadic_args false mandatory_arg_types + StdLibrary.variadic_ode_mandatory_fun_args + StdLibrary.variadic_ode_fun_return_type (get_arg_types tes) + |> Result.error |> Option.value_exn in + Semantic_error.illtyped_variadic_ode loc id.name + (List.map ~f:type_of_expr_typed tes) + expected_args err + |> error in + let matching remaining_es Env.{type_= ftype; _} = + let arg_types = + (calculate_autodifftype cf.current_block Functions ftype, ftype) + :: get_arg_types remaining_es in + SignatureMismatch.check_variadic_args false mandatory_arg_types + StdLibrary.variadic_ode_mandatory_fun_args + StdLibrary.variadic_ode_fun_return_type arg_types in + match tes with + | {expr= Variable fname; _} :: remaining_es -> ( + match find_matching_first_order_fn tenv (matching remaining_es) fname with + | SignatureMismatch.UniqueMatch (ftype, promotions) -> + let tes = make_function_variable cf loc fname ftype :: remaining_es in + mk_typed_expression + ~expr: + (mk_fun_app ~is_cond_dist + (StanLib FnPlain, id, Promotion.promote_list tes promotions) ) + ~ad_level:(expr_ad_lub tes) + ~type_:StdLibrary.variadic_ode_return_type ~loc + | AmbiguousMatch ps -> + Semantic_error.ambiguous_function_promotion loc fname.name None ps + |> error + | SignatureErrors (expected_args, err) -> + Semantic_error.illtyped_variadic_ode loc id.name + (List.map ~f:type_of_expr_typed tes) + expected_args err + |> error ) + | _ -> fail () + + and check_variadic_dae ~is_cond_dist loc cf tenv id tes = + let optional_tol_mandatory_args = + if StdLibrary.is_variadic_dae_tol_fn id.name then + StdLibrary.variadic_dae_tol_arg_types + else [] in + let mandatory_arg_types = + StdLibrary.variadic_dae_mandatory_arg_types @ optional_tol_mandatory_args + in + let fail () = + let expected_args, err = + SignatureMismatch.check_variadic_args false mandatory_arg_types + StdLibrary.variadic_dae_mandatory_fun_args + StdLibrary.variadic_dae_fun_return_type (get_arg_types tes) + |> Result.error |> Option.value_exn in + Semantic_error.illtyped_variadic_dae loc id.name + (List.map ~f:type_of_expr_typed tes) + expected_args err + |> error in + let matching remaining_es Env.{type_= ftype; _} = + let arg_types = + (calculate_autodifftype cf.current_block Functions ftype, ftype) + :: get_arg_types remaining_es in + SignatureMismatch.check_variadic_args false mandatory_arg_types + StdLibrary.variadic_dae_mandatory_fun_args + StdLibrary.variadic_dae_fun_return_type arg_types in + match tes with + | {expr= Variable fname; _} :: remaining_es -> ( + match find_matching_first_order_fn tenv (matching remaining_es) fname with + | SignatureMismatch.UniqueMatch (ftype, promotions) -> + let tes = make_function_variable cf loc fname ftype :: remaining_es in + mk_typed_expression + ~expr: + (mk_fun_app ~is_cond_dist + (StanLib FnPlain, id, Promotion.promote_list tes promotions) ) + ~ad_level:(expr_ad_lub tes) + ~type_:StdLibrary.variadic_dae_return_type ~loc + | AmbiguousMatch ps -> + Semantic_error.ambiguous_function_promotion loc fname.name None ps + |> error + | SignatureErrors (expected_args, err) -> + Semantic_error.illtyped_variadic_dae loc id.name + (List.map ~f:type_of_expr_typed tes) + expected_args err + |> error ) + | _ -> fail () *) + + and check_funapp loc cf tenv ~is_cond_dist id (es : Ast.typed_expression list) + = + let name_check = + if is_cond_dist then verify_conddist_name else verify_fn_conditioning + in + let res = check_fn ~is_cond_dist loc cf tenv id es in + verify_identifier id ; + name_check loc id ; + verify_fn_target_plus_equals cf loc id ; + verify_fn_rng cf loc id ; + verify_unnormalized cf loc id ; + res + + and check_indexed loc cf tenv e indices = + let tindices = List.map ~f:(check_index cf tenv) indices in + let te = check_expression cf tenv e in + let ad_level = inferred_ad_type_of_indexed te.emeta.ad_level tindices in + let type_ = inferred_unsizedtype_of_indexed ~loc te.emeta.type_ tindices in + mk_typed_expression ~expr:(Indexed (te, tindices)) ~ad_level ~type_ ~loc + + and check_index cf tenv = function + | All -> All + (* Check that indexes have int (container) type *) + | Single e -> + let te = check_expression cf tenv e in + if has_int_type te || has_int_array_type te then Single te + else + Semantic_error.int_intarray_or_range_expected te.emeta.loc + te.emeta.type_ + |> error + | Upfrom e -> check_expression_of_int_type cf tenv e "Range bound" |> Upfrom + | Downfrom e -> + check_expression_of_int_type cf tenv e "Range bound" |> Downfrom + | Between (e1, e2) -> + let le = check_expression_of_int_type cf tenv e1 "Range bound" in + let ue = check_expression_of_int_type cf tenv e2 "Range bound" in + Between (le, ue) + + and check_expression cf tenv ({emeta; expr} : Ast.untyped_expression) : + Ast.typed_expression = + let loc = emeta.loc in + let ce = check_expression cf tenv in + match expr with + | TernaryIf (e1, e2, e3) -> + let pe = ce e1 in + let te = ce e2 in + let fe = ce e3 in + check_ternary_if loc pe te fe + | BinOp (e1, op, e2) -> + let le = ce e1 in + let re = ce e2 in + let binop_type_warnings x y = + match (x.emeta.type_, y.emeta.type_, op) with + | UInt, UInt, Divide -> + let hint ppf () = + match (x.expr, y.expr) with + | IntNumeral x, _ -> + Fmt.pf ppf "%s.0 / %a" x Pretty_printing.pp_typed_expression + y + | _, Ast.IntNumeral y -> + Fmt.pf ppf "%a / %s.0" Pretty_printing.pp_typed_expression x + y + | _ -> + Fmt.pf ppf "%a * 1.0 / %a" + Pretty_printing.pp_typed_expression x + Pretty_printing.pp_typed_expression y in + let s = + Fmt.str + "@[@[Found int division:@]@ @[%a@]@,\ + @[%a@]@ @[%a@]@,\ + @[%a@]@]" + Pretty_printing.pp_expression {expr; emeta} Fmt.text + "Values will be rounded towards zero. If rounding is not \ + desired you can write the division as" + hint () Fmt.text + "If rounding is intended please use the integer division \ + operator %/%." in + add_warning x.emeta.loc s + | (UArray UMatrix | UMatrix), (UInt | UReal), Pow -> + let s = + Fmt.str + "@[@[Found matrix^scalar:@]@ @[%a@]@,\ + @[%a@]@ @[%a@]@]" Pretty_printing.pp_expression + {expr; emeta} Fmt.text + "matrix ^ number is interpreted as element-wise \ + exponentiation. If this is intended, you can silence this \ + warning by using elementwise operator .^" + Fmt.text + "If you intended matrix exponentiation, use the function \ + matrix_power(matrix,int) instead." in + add_warning x.emeta.loc s + | _ -> () in + binop_type_warnings le re ; check_binop loc op le re + | PrefixOp (op, e) -> ce e |> check_prefixop loc op + | PostfixOp (e, op) -> ce e |> check_postfixop loc op + | Variable id -> + verify_identifier id ; + check_variable cf loc tenv id + | IntNumeral s -> ( + match float_of_string_opt s with + | Some i when i < 2_147_483_648.0 -> + mk_typed_expression ~expr:(IntNumeral s) ~ad_level:DataOnly + ~type_:UInt ~loc + | _ -> Semantic_error.bad_int_literal loc |> error ) + | RealNumeral s -> + mk_typed_expression ~expr:(RealNumeral s) ~ad_level:DataOnly ~type_:UReal ~loc - | ArrayExpr es -> es |> List.map ~f:ce |> check_array_expr loc - | RowVectorExpr es -> es |> List.map ~f:ce |> check_rowvector loc - | Paren e -> - let te = ce e in - mk_typed_expression ~expr:(Paren te) ~ad_level:te.emeta.ad_level - ~type_:te.emeta.type_ ~loc - | Indexed (e, indices) -> check_indexed loc cf tenv e indices - | FunApp ((), id, es) -> - es |> List.map ~f:ce |> check_funapp loc cf tenv ~is_cond_dist:false id - | CondDistApp ((), id, es) -> - es |> List.map ~f:ce |> check_funapp loc cf tenv ~is_cond_dist:true id - | Promotion (e, _, _) -> - (* Should never happen: promotions are produced during typechecking *) - Common.FatalError.fatal_error_msg - [%message "Promotion in untyped AST" (e : Ast.untyped_expression)] - -and check_expression_of_int_type cf tenv e name = - let te = check_expression cf tenv e in - if has_int_type te then te - else Semantic_error.int_expected te.emeta.loc name te.emeta.type_ |> error - -let check_expression_of_int_or_real_type cf tenv e name = - let te = check_expression cf tenv e in - if has_int_or_real_type te then te - else - Semantic_error.int_or_real_expected te.emeta.loc name te.emeta.type_ - |> error - -let check_expression_of_scalar_or_type cf tenv t e name = - let te = check_expression cf tenv e in - if UnsizedType.is_scalar_type te.emeta.type_ || te.emeta.type_ = t then te - else - Semantic_error.scalar_or_type_expected te.emeta.loc name t te.emeta.type_ - |> error - -(* -- Statements ------------------------------------------------- *) -(* non returning functions *) -let verify_nrfn_target loc cf id = - if - String.is_suffix id.name ~suffix:"_lp" - && not - ( cf.in_lp_fun_def || cf.current_block = Model - || cf.current_block = TParam ) - then Semantic_error.target_plusequals_outisde_model_or_logprob loc |> error - -let check_nrfn loc tenv id es = - match Env.find tenv id.name with - | {kind= `Variable _; _} :: _ - (* variables can shadow stanlib functions, so we have to check this *) - when not (Stan_math_signatures.is_stan_math_function_name id.name) -> - Semantic_error.nonreturning_fn_expected_nonfn_found loc id.name |> error - | [] -> - Semantic_error.nonreturning_fn_expected_undeclaredident_found loc id.name - (Env.nearest_ident tenv id.name) + | ImagNumeral s -> + mk_typed_expression ~expr:(ImagNumeral s) ~ad_level:DataOnly + ~type_:UComplex ~loc + | GetLP -> + (* Target+= can only be used in model and functions with right suffix (same for tilde etc) *) + if + not + ( cf.in_lp_fun_def || cf.current_block = Model + || cf.current_block = TParam ) + then + Semantic_error.target_plusequals_outisde_model_or_logprob loc |> error + else + mk_typed_expression ~expr:GetLP + ~ad_level: + (calculate_autodifftype cf.current_block cf.current_block UReal) + ~type_:UReal ~loc + | GetTarget -> + (* Target+= can only be used in model and functions with right suffix (same for tilde etc) *) + if + not + ( cf.in_lp_fun_def || cf.current_block = Model + || cf.current_block = TParam ) + then + Semantic_error.target_plusequals_outisde_model_or_logprob loc |> error + else + mk_typed_expression ~expr:GetTarget + ~ad_level: + (calculate_autodifftype cf.current_block cf.current_block UReal) + ~type_:UReal ~loc + | ArrayExpr es -> es |> List.map ~f:ce |> check_array_expr loc + | RowVectorExpr es -> es |> List.map ~f:ce |> check_rowvector loc + | Paren e -> + let te = ce e in + mk_typed_expression ~expr:(Paren te) ~ad_level:te.emeta.ad_level + ~type_:te.emeta.type_ ~loc + | Indexed (e, indices) -> check_indexed loc cf tenv e indices + | FunApp ((), id, es) -> + es |> List.map ~f:ce |> check_funapp loc cf tenv ~is_cond_dist:false id + | CondDistApp ((), id, es) -> + es |> List.map ~f:ce |> check_funapp loc cf tenv ~is_cond_dist:true id + | Promotion (e, _, _) -> + (* Should never happen: promotions are produced during typechecking *) + Common.FatalError.fatal_error_msg + [%message "Promotion in untyped AST" (e : Ast.untyped_expression)] + + and check_expression_of_int_type cf tenv e name = + let te = check_expression cf tenv e in + if has_int_type te then te + else Semantic_error.int_expected te.emeta.loc name te.emeta.type_ |> error + + let check_expression_of_int_or_real_type cf tenv e name = + let te = check_expression cf tenv e in + if has_int_or_real_type te then te + else + Semantic_error.int_or_real_expected te.emeta.loc name te.emeta.type_ |> error - | _ (* a function *) -> ( - match - SignatureMismatch.matching_function tenv id.name (get_arg_types es) - with - | UniqueMatch (Void, fnk, promotions) -> - mk_typed_statement - ~stmt: - (NRFunApp - ( fnk (Fun_kind.suffix_from_name id.name) - , id - , Promotion.promote_list es promotions ) ) - ~return_type:NoReturnType ~loc - | UniqueMatch (ReturnType _, _, _) -> - Semantic_error.nonreturning_fn_expected_returning_found loc id.name - |> error - | AmbiguousMatch sigs -> - Semantic_error.ambiguous_function_promotion loc id.name - (Some (List.map ~f:type_of_expr_typed es)) - sigs - |> error - | SignatureErrors (l, b) -> - es - |> List.map ~f:type_of_expr_typed - |> Semantic_error.illtyped_fn_app loc id.name (l, b) - |> error ) - -let check_nr_fn_app loc cf tenv id es = - let tes = List.map ~f:(check_expression cf tenv) es in - verify_identifier id ; - verify_nrfn_target loc cf id ; - check_nrfn loc tenv id tes - -(* assignments *) -let verify_assignment_read_only loc is_readonly id = - if is_readonly then - Semantic_error.cannot_assign_to_read_only loc id.name |> error - -(* Variables from previous blocks are read-only. - In particular, data and parameters never assigned to -*) -let verify_assignment_global loc cf block is_global id = - if (not is_global) || block = cf.current_block then () - else Semantic_error.cannot_assign_to_global loc id.name |> error -(* Until function types are added to the user language, we - disallow assignments to function values -*) -let verify_assignment_non_function loc ut id = - match ut with - | UnsizedType.UFun _ | UMathLibraryFunction -> - Semantic_error.cannot_assign_function loc ut id.name |> error - | _ -> () - -let check_assignment_operator loc assop lhs rhs = - let err op = - Semantic_error.illtyped_assignment loc op lhs.lmeta.type_ rhs.emeta.type_ - in - match assop with - | Assign | ArrowAssign -> ( - match - SignatureMismatch.check_of_same_type_mod_conv lhs.lmeta.type_ - rhs.emeta.type_ - with - | Ok p -> Promotion.promote rhs p - | Error _ -> err Operator.Equals |> error ) - | OperatorAssign op -> ( - let args = List.map ~f:arg_type [Ast.expr_of_lvalue lhs; rhs] in - let return_type = assignmentoperator_stan_math_return_type op args in - match return_type with Some Void -> rhs | _ -> err op |> error ) - -let check_lvalue cf tenv = function - | {lval= LVariable id; lmeta= ({loc} : located_meta)} -> - verify_identifier id ; - let ad_level, type_ = check_id cf loc tenv id in - {lval= LVariable id; lmeta= {ad_level; type_; loc}} - | {lval= LIndexed (lval, idcs); lmeta= {loc}} -> - let rec check_inner = function - | {lval= LVariable id; lmeta= ({loc} : located_meta)} -> - verify_identifier id ; - let ad_level, type_ = check_id cf loc tenv id in - let var = {lval= LVariable id; lmeta= {ad_level; type_; loc}} in - (var, var, []) - | {lval= LIndexed (lval, idcs); lmeta= {loc}} -> - let lval, var, flat = check_inner lval in - let idcs = List.map ~f:(check_index cf tenv) idcs in - let ad_level = - inferred_ad_type_of_indexed lval.lmeta.ad_level idcs in - let type_ = - inferred_unsizedtype_of_indexed ~loc lval.lmeta.type_ idcs in - ( {lval= LIndexed (lval, idcs); lmeta= {ad_level; type_; loc}} - , var - , flat @ idcs ) in - let lval, var, flat = check_inner lval in - let idcs = List.map ~f:(check_index cf tenv) idcs in - let ad_level = inferred_ad_type_of_indexed lval.lmeta.ad_level idcs in - let type_ = inferred_unsizedtype_of_indexed ~loc lval.lmeta.type_ idcs in - if List.exists ~f:is_multiindex flat then ( - add_warning loc - "Nested multi-indexing on the left hand side of assignment does not \ - behave the same as nested indexing in expressions. This is \ - considered a bug and will be disallowed in Stan 2.32.0. The \ - indexing can be automatically fixed using the canonicalize flag for \ - stanc." ; - let lvalue_rvalue_types_differ = - try - let flat_type = - inferred_unsizedtype_of_indexed ~loc var.lmeta.type_ (flat @ idcs) - in - let rec can_assign = function - | UnsizedType.(UArray t1, UArray t2) -> can_assign (t1, t2) - | UVector, URowVector | URowVector, UVector -> false - | t1, t2 -> UnsizedType.compare t1 t2 <> 0 in - can_assign (flat_type, type_) - with Errors.SemanticError _ -> true in - if lvalue_rvalue_types_differ then - Semantic_error.cannot_assign_to_multiindex loc |> error ) ; - {lval= LIndexed (lval, idcs); lmeta= {ad_level; type_; loc}} - -let check_assignment loc cf tenv assign_lhs assign_op assign_rhs = - let assign_id = Ast.id_of_lvalue assign_lhs in - let lhs = check_lvalue cf tenv assign_lhs in - let rhs = check_expression cf tenv assign_rhs in - let block, global, readonly = - let var = Env.find tenv assign_id.name in - match var with - | {kind= `Variable {origin; global; readonly}; _} :: _ -> - (origin, global, readonly) - | {kind= `StanMath; _} :: _ -> (MathLibrary, true, false) - | {kind= `UserDefined | `UserDeclared _; _} :: _ -> (Functions, true, false) - | _ -> - Semantic_error.ident_not_in_scope loc assign_id.name - (Env.nearest_ident tenv assign_id.name) - |> error in - verify_assignment_global loc cf block global assign_id ; - verify_assignment_read_only loc readonly assign_id ; - verify_assignment_non_function loc rhs.emeta.type_ assign_id ; - let rhs' = check_assignment_operator loc assign_op lhs rhs in - mk_typed_statement ~return_type:NoReturnType ~loc - ~stmt:(Assignment {assign_lhs= lhs; assign_op; assign_rhs= rhs'}) - -(* target plus-equals / increment log-prob *) - -let verify_target_pe_expr_type loc e = - if UnsizedType.is_fun_type e.emeta.type_ then - Semantic_error.int_or_real_container_expected loc e.emeta.type_ |> error - -let verify_target_pe_usage loc cf = - if cf.in_lp_fun_def || cf.current_block = Model then () - else Semantic_error.target_plusequals_outisde_model_or_logprob loc |> error - -let check_target_pe loc cf tenv e = - let te = check_expression cf tenv e in - verify_target_pe_usage loc cf ; - verify_target_pe_expr_type loc te ; - mk_typed_statement ~stmt:(TargetPE te) ~return_type:NoReturnType ~loc - -let check_incr_logprob loc cf tenv e = - let te = check_expression cf tenv e in - verify_target_pe_usage loc cf ; - verify_target_pe_expr_type loc te ; - mk_typed_statement ~stmt:(IncrementLogProb te) ~return_type:NoReturnType ~loc - -(* tilde/sampling notation*) -let verify_sampling_pdf_pmf id = - if - String.( - is_suffix id.name ~suffix:"_lpdf" - || is_suffix id.name ~suffix:"_lpmf" - || is_suffix id.name ~suffix:"_lupdf" - || is_suffix id.name ~suffix:"_lupmf") - then Semantic_error.invalid_sampling_pdf_or_pmf id.id_loc |> error - -let verify_sampling_cdf_ccdf loc id = - if - String.( - is_suffix id.name ~suffix:"_cdf" || is_suffix id.name ~suffix:"_ccdf") - then Semantic_error.invalid_sampling_cdf_or_ccdf loc id.name |> error - -(* Target+= can only be used in model and functions with right suffix (same for tilde etc) *) -let verify_valid_sampling_pos loc cf = - if cf.in_lp_fun_def || cf.current_block = Model then () - else Semantic_error.target_plusequals_outisde_model_or_logprob loc |> error - -let verify_sampling_distribution loc tenv id arguments = - let name = id.name in - let argumenttypes = List.map ~f:arg_type arguments in - let name_w_suffix_sampling_dist suffix = - SignatureMismatch.matching_function tenv (name ^ suffix) argumenttypes in - let sampling_dists = - List.map ~f:name_w_suffix_sampling_dist Utils.distribution_suffices in - let is_sampling_dist_defined = - List.exists - ~f:(function UniqueMatch (ReturnType UReal, _, _) -> true | _ -> false) - sampling_dists - && name <> "binomial_coefficient" - && name <> "multiply" in - if is_sampling_dist_defined then () - else - match - List.max_elt sampling_dists - ~compare:SignatureMismatch.compare_match_results - with - | None | Some (UniqueMatch _) | Some (SignatureErrors ([], _)) -> - (* Either non-existant or a very odd case, - output the old non-informative error *) - Semantic_error.invalid_sampling_no_such_dist loc name |> error - | Some (AmbiguousMatch sigs) -> - Semantic_error.ambiguous_function_promotion loc id.name - (Some (List.map ~f:type_of_expr_typed arguments)) - sigs - |> error - | Some (SignatureErrors (l, b)) -> - arguments - |> List.map ~f:(fun e -> e.emeta.type_) - |> Semantic_error.illtyped_fn_app loc id.name (l, b) + let check_expression_of_scalar_or_type cf tenv t e name = + let te = check_expression cf tenv e in + if UnsizedType.is_scalar_type te.emeta.type_ || te.emeta.type_ = t then te + else + Semantic_error.scalar_or_type_expected te.emeta.loc name t te.emeta.type_ + |> error + + (* -- Statements ------------------------------------------------- *) + (* non returning functions *) + let verify_nrfn_target loc cf id = + if + String.is_suffix id.name ~suffix:"_lp" + && not + ( cf.in_lp_fun_def || cf.current_block = Model + || cf.current_block = TParam ) + then Semantic_error.target_plusequals_outisde_model_or_logprob loc |> error + + let check_nrfn loc tenv id es = + match Env.find tenv id.name with + | {kind= `Variable _; _} :: _ + (* variables can shadow stanlib functions, so we have to check this *) + when not (StdLibrary.is_stan_math_function_name id.name) -> + Semantic_error.nonreturning_fn_expected_nonfn_found loc id.name |> error + | [] -> + Semantic_error.nonreturning_fn_expected_undeclaredident_found loc + id.name + (Env.nearest_ident tenv id.name) |> error + | _ (* a function *) -> ( + match + SignatureMismatch.matching_function tenv id.name (get_arg_types es) + with + | UniqueMatch (Void, fnk, promotions) -> + mk_typed_statement + ~stmt: + (NRFunApp + ( fnk (Fun_kind.suffix_from_name id.name) + , id + , Promotion.promote_list es promotions ) ) + ~return_type:NoReturnType ~loc + | UniqueMatch (ReturnType _, _, _) -> + Semantic_error.nonreturning_fn_expected_returning_found loc id.name + |> error + | AmbiguousMatch sigs -> + Semantic_error.ambiguous_function_promotion loc id.name + (Some (List.map ~f:type_of_expr_typed es)) + sigs + |> error + | SignatureErrors (l, b) -> + es + |> List.map ~f:type_of_expr_typed + |> Semantic_error.illtyped_fn_app loc id.name (l, b) + |> error ) -let is_cumulative_density_defined tenv id arguments = - let name = id.name in - let argumenttypes = List.map ~f:arg_type arguments in - let valid_arg_types_for_suffix suffix = - match - SignatureMismatch.matching_function tenv (name ^ suffix) argumenttypes - with - | UniqueMatch (ReturnType UReal, _, _) -> true - | _ -> false in - (valid_arg_types_for_suffix "_lcdf" || valid_arg_types_for_suffix "_cdf_log") - && ( valid_arg_types_for_suffix "_lccdf" - || valid_arg_types_for_suffix "_ccdf_log" ) - -let verify_can_truncate_distribution loc (arg : typed_expression) = function - | NoTruncate -> () - | _ -> - if UnsizedType.is_scalar_type arg.emeta.type_ then () - else Semantic_error.multivariate_truncation loc |> error - -let verify_sampling_cdf_defined loc tenv id truncation args = - let check e = is_cumulative_density_defined tenv id (e :: args) in - match truncation with - | NoTruncate -> () - | (TruncateUpFrom e | TruncateDownFrom e) when check e -> () - | TruncateBetween (e1, e2) when check e1 && check e2 -> () - | _ -> Semantic_error.invalid_truncation_cdf_or_ccdf loc |> error - -let check_truncation cf tenv truncation = - let check e = - check_expression_of_int_or_real_type cf tenv e "Truncation bound" in - match truncation with - | NoTruncate -> NoTruncate - | TruncateUpFrom e -> check e |> TruncateUpFrom - | TruncateDownFrom e -> check e |> TruncateDownFrom - | TruncateBetween (e1, e2) -> (check e1, check e2) |> TruncateBetween - -let check_tilde loc cf tenv distribution truncation arg args = - let te = check_expression cf tenv arg in - let tes = List.map ~f:(check_expression cf tenv) args in - let ttrunc = check_truncation cf tenv truncation in - verify_identifier distribution ; - verify_sampling_pdf_pmf distribution ; - verify_valid_sampling_pos loc cf ; - verify_sampling_cdf_ccdf loc distribution ; - verify_sampling_distribution loc tenv distribution (te :: tes) ; - verify_sampling_cdf_defined loc tenv distribution ttrunc tes ; - verify_can_truncate_distribution loc te ttrunc ; - let stmt = Tilde {arg= te; distribution; args= tes; truncation= ttrunc} in - mk_typed_statement ~stmt ~loc ~return_type:NoReturnType - -(* Break and continue only occur in loops. *) -let check_break loc cf = - if cf.loop_depth = 0 then Semantic_error.break_outside_loop loc |> error - else mk_typed_statement ~stmt:Break ~return_type:NoReturnType ~loc - -let check_continue loc cf = - if cf.loop_depth = 0 then Semantic_error.continue_outside_loop loc |> error - else mk_typed_statement ~stmt:Continue ~return_type:NoReturnType ~loc - -let check_return loc cf tenv e = - if not cf.in_returning_fun_def then - Semantic_error.expression_return_outside_returning_fn loc |> error - else - let te = check_expression cf tenv e in - mk_typed_statement ~stmt:(Return te) - ~return_type:(Complete (ReturnType te.emeta.type_)) ~loc - -let check_returnvoid loc cf = - if (not cf.in_fun_def) || cf.in_returning_fun_def then - Semantic_error.void_ouside_nonreturning_fn loc |> error - else mk_typed_statement ~stmt:ReturnVoid ~return_type:(Complete Void) ~loc - -let check_printable cf tenv = function - | PString s -> PString s - (* Print/reject expressions cannot be of function type. *) - | PExpr e -> ( - let te = check_expression cf tenv e in - match te.emeta.type_ with - | UFun _ | UMathLibraryFunction -> - Semantic_error.not_printable te.emeta.loc |> error - | _ -> PExpr te ) - -let check_print loc cf tenv ps = - let tps = List.map ~f:(check_printable cf tenv) ps in - mk_typed_statement ~stmt:(Print tps) ~return_type:NoReturnType ~loc - -let check_reject loc cf tenv ps = - let tps = List.map ~f:(check_printable cf tenv) ps in - mk_typed_statement ~stmt:(Reject tps) ~return_type:AnyReturnType ~loc - -let check_skip loc = - mk_typed_statement ~stmt:Skip ~return_type:NoReturnType ~loc - -let rec stmt_is_escape {stmt; _} = - match stmt with - | Break | Continue | Reject _ | Return _ | ReturnVoid -> true - | _ -> false - -and list_until_escape xs = - let rec aux accu = function - | [next; next'] when stmt_is_escape next' -> List.rev (next' :: next :: accu) - | next :: next' :: unreachable :: _ when stmt_is_escape next' -> - add_warning unreachable.smeta.loc - "Unreachable statement (following a reject, break, continue, or \ - return) found, is this intended?" ; - List.rev (next' :: next :: accu) - | next :: rest -> aux (next :: accu) rest - | [] -> List.rev accu in - aux [] xs - -let returntype_leastupperbound loc rt1 rt2 = - match (rt1, rt2) with - | UnsizedType.ReturnType UReal, UnsizedType.ReturnType UInt - |ReturnType UInt, ReturnType UReal -> - UnsizedType.ReturnType UReal - | _, _ when rt1 = rt2 -> rt2 - | _ -> Semantic_error.mismatched_return_types loc rt1 rt2 |> error - -let try_compute_block_statement_returntype loc srt1 srt2 = - match (srt1, srt2) with - | Complete rt1, Complete rt2 | Incomplete rt1, Complete rt2 -> - Complete (returntype_leastupperbound loc rt1 rt2) - | Incomplete rt1, Incomplete rt2 | Complete rt1, Incomplete rt2 -> - Incomplete (returntype_leastupperbound loc rt1 rt2) - | NoReturnType, NoReturnType -> NoReturnType - | AnyReturnType, Incomplete rt - |Complete rt, NoReturnType - |NoReturnType, Incomplete rt - |Incomplete rt, NoReturnType -> - Incomplete rt - | NoReturnType, Complete rt - |Complete rt, AnyReturnType - |Incomplete rt, AnyReturnType - |AnyReturnType, Complete rt -> - Complete rt - | AnyReturnType, NoReturnType - |NoReturnType, AnyReturnType - |AnyReturnType, AnyReturnType -> - AnyReturnType - -let try_compute_ifthenelse_statement_returntype loc srt1 srt2 = - match (srt1, srt2) with - | Complete rt1, Complete rt2 -> - returntype_leastupperbound loc rt1 rt2 |> Complete - | Incomplete rt1, Incomplete rt2 - |Complete rt1, Incomplete rt2 - |Incomplete rt1, Complete rt2 -> - returntype_leastupperbound loc rt1 rt2 |> Incomplete - | AnyReturnType, NoReturnType - |NoReturnType, AnyReturnType - |NoReturnType, NoReturnType -> - NoReturnType - | AnyReturnType, Incomplete rt - |Incomplete rt, AnyReturnType - |Complete rt, NoReturnType - |NoReturnType, Complete rt - |NoReturnType, Incomplete rt - |Incomplete rt, NoReturnType -> - Incomplete rt - | Complete rt, AnyReturnType | AnyReturnType, Complete rt -> Complete rt - | AnyReturnType, AnyReturnType -> AnyReturnType - -(* statements which contain statements, and therefore need to be mutually recursive - with check_statement -*) -let rec check_if_then_else loc cf tenv pred_e s_true s_false_opt = - (* we don't need these nested type environments *) - let _, ts_true = check_statement cf tenv s_true in - let ts_false_opt = - s_false_opt |> Option.map ~f:(check_statement cf tenv) |> Option.map ~f:snd - in - let te = - check_expression_of_int_or_real_type cf tenv pred_e - "Condition in conditional" in - let stmt = IfThenElse (te, ts_true, ts_false_opt) in - let srt1 = ts_true.smeta.return_type in - let srt2 = - ts_false_opt - |> Option.map ~f:(fun s -> s.smeta.return_type) - |> Option.value ~default:NoReturnType in - let return_type = try_compute_ifthenelse_statement_returntype loc srt1 srt2 in - mk_typed_statement ~stmt ~return_type ~loc - -and check_while loc cf tenv cond_e loop_body = - let _, ts = - check_statement {cf with loop_depth= cf.loop_depth + 1} tenv loop_body - and te = - check_expression_of_int_or_real_type cf tenv cond_e - "Condition in while-loop" in - mk_typed_statement - ~stmt:(While (te, ts)) - ~return_type:ts.smeta.return_type ~loc - -and check_for loc cf tenv loop_var lower_bound_e upper_bound_e loop_body = - let te1 = - check_expression_of_int_type cf tenv lower_bound_e "Lower bound of for-loop" - and te2 = - check_expression_of_int_type cf tenv upper_bound_e "Upper bound of for-loop" - in - verify_identifier loop_var ; - let ts = check_loop_body cf tenv loop_var UnsizedType.UInt loop_body in - mk_typed_statement - ~stmt: - (For - { loop_variable= loop_var - ; lower_bound= te1 - ; upper_bound= te2 - ; loop_body= ts } ) - ~return_type:ts.smeta.return_type ~loc - -and check_foreach_loop_identifier_type loc ty = - match ty with - | UnsizedType.UArray ut -> ut - | UVector | URowVector | UMatrix -> UnsizedType.UReal - | _ -> Semantic_error.array_vector_rowvector_matrix_expected loc ty |> error - -and check_foreach loc cf tenv loop_var foreach_e loop_body = - let te = check_expression cf tenv foreach_e in - verify_identifier loop_var ; - let loop_var_ty = - check_foreach_loop_identifier_type te.emeta.loc te.emeta.type_ in - let ts = check_loop_body cf tenv loop_var loop_var_ty loop_body in - mk_typed_statement - ~stmt:(ForEach (loop_var, te, ts)) - ~return_type:ts.smeta.return_type ~loc - -and check_loop_body cf tenv loop_var loop_var_ty loop_body = - verify_name_fresh tenv loop_var ~is_udf:false ; - (* Add to type environment as readonly. - Check that function args and loop identifiers are not modified in - function. (passed by const ref) + let check_nr_fn_app loc cf tenv id es = + let tes = List.map ~f:(check_expression cf tenv) es in + verify_identifier id ; + verify_nrfn_target loc cf id ; + check_nrfn loc tenv id tes + + (* assignments *) + let verify_assignment_read_only loc is_readonly id = + if is_readonly then + Semantic_error.cannot_assign_to_read_only loc id.name |> error + + (* Variables from previous blocks are read-only. + In particular, data and parameters never assigned to *) - let tenv = - Env.add tenv loop_var.name loop_var_ty - (`Variable {origin= cf.current_block; global= false; readonly= true}) - in - snd (check_statement {cf with loop_depth= cf.loop_depth + 1} tenv loop_body) - -and check_block loc cf tenv stmts = - let _, checked_stmts = - List.fold_map stmts ~init:tenv ~f:(check_statement cf) in - let return_type = - checked_stmts |> list_until_escape - |> List.map ~f:(fun s -> s.smeta.return_type) - |> List.fold ~init:NoReturnType - ~f:(try_compute_block_statement_returntype loc) in - mk_typed_statement ~stmt:(Block checked_stmts) ~return_type ~loc - -and check_profile loc cf tenv name stmts = - let _, checked_stmts = - List.fold_map stmts ~init:tenv ~f:(check_statement cf) in - let return_type = - checked_stmts |> list_until_escape - |> List.map ~f:(fun s -> s.smeta.return_type) - |> List.fold ~init:NoReturnType - ~f:(try_compute_block_statement_returntype loc) in - mk_typed_statement ~stmt:(Profile (name, checked_stmts)) ~return_type ~loc - -(* variable declarations *) -and verify_valid_transformation_for_type loc is_global sized_ty trans = - let is_real {emeta; _} = emeta.type_ = UReal in - let is_real_transformation = - match trans with - | Transformation.Lower e -> is_real e - | Upper e -> is_real e - | LowerUpper (e1, e2) -> is_real e1 || is_real e2 - | _ -> false in - if is_global && sized_ty = SizedType.SInt && is_real_transformation then - Semantic_error.non_int_bounds loc |> error ; - let is_transformation = - match trans with Transformation.Identity -> false | _ -> true in - if is_global && SizedType.(contains_complex sized_ty) && is_transformation - then Semantic_error.complex_transform loc |> error - -and verify_transformed_param_ty loc cf is_global unsized_ty = - if - is_global - && (cf.current_block = Param || cf.current_block = TParam) - && UnsizedType.is_int_type unsized_ty - then Semantic_error.transformed_params_int loc |> error - -and check_sizedtype cf tenv sizedty = - let check e msg = check_expression_of_int_type cf tenv e msg in - match sizedty with - | SizedType.SInt -> SizedType.SInt - | SReal -> SReal - | SComplex -> SComplex - | SVector (mem_pattern, e) -> - let te = check e "Vector sizes" in - SVector (mem_pattern, te) - | SRowVector (mem_pattern, e) -> - let te = check e "Row vector sizes" in - SRowVector (mem_pattern, te) - | SMatrix (mem_pattern, e1, e2) -> - let te1 = check e1 "Matrix row size" in - let te2 = check e2 "Matrix column size" in - SMatrix (mem_pattern, te1, te2) - | SComplexVector e -> - let te = check e "complex vector sizes" in - SComplexVector te - | SComplexRowVector e -> - let te = check e "complex row vector sizes" in - SComplexRowVector te - | SComplexMatrix (e1, e2) -> - let te1 = check e1 "Complex matrix row size" in - let te2 = check e2 "Complex matrix column size" in - SComplexMatrix (te1, te2) - | SArray (st, e) -> - let tst = check_sizedtype cf tenv st in - let te = check e "Array sizes" in - SArray (tst, te) - -and check_var_decl_initial_value loc cf tenv id init_val_opt = - match init_val_opt with - | Some e -> ( - let lhs = check_lvalue cf tenv {lval= LVariable id; lmeta= {loc}} in - let rhs = check_expression cf tenv e in + let verify_assignment_global loc cf block is_global id = + if (not is_global) || block = cf.current_block then () + else Semantic_error.cannot_assign_to_global loc id.name |> error + + (* Until function types are added to the user language, we + disallow assignments to function values + *) + let verify_assignment_non_function loc ut id = + match ut with + | UnsizedType.UFun _ | UMathLibraryFunction -> + Semantic_error.cannot_assign_function loc ut id.name |> error + | _ -> () + + let check_assignment_operator loc assop lhs rhs = + let err op = + Semantic_error.illtyped_assignment loc op lhs.lmeta.type_ rhs.emeta.type_ + in + match assop with + | Assign | ArrowAssign -> ( match SignatureMismatch.check_of_same_type_mod_conv lhs.lmeta.type_ rhs.emeta.type_ with - | Ok p -> Some (Promotion.promote rhs p) - | Error _ -> - Semantic_error.illtyped_assignment loc Equals lhs.lmeta.type_ - rhs.emeta.type_ - |> error ) - | None -> None - -and check_transformation cf tenv ut trans = - let check e msg = check_expression_of_scalar_or_type cf tenv ut e msg in - match trans with - | Transformation.Identity -> Transformation.Identity - | Lower e -> check e "Lower bound" |> Lower - | Upper e -> check e "Upper bound" |> Upper - | LowerUpper (e1, e2) -> - (check e1 "Lower bound", check e2 "Upper bound") |> LowerUpper - | Offset e -> check e "Offset" |> Offset - | Multiplier e -> check e "Multiplier" |> Multiplier - | OffsetMultiplier (e1, e2) -> - (check e1 "Offset", check e2 "Multiplier") |> OffsetMultiplier - | Ordered -> Ordered - | PositiveOrdered -> PositiveOrdered - | Simplex -> Simplex - | UnitVector -> UnitVector - | CholeskyCorr -> CholeskyCorr - | CholeskyCov -> CholeskyCov - | Correlation -> Correlation - | Covariance -> Covariance - -and check_var_decl loc cf tenv sized_ty trans id init is_global = - let checked_type = - check_sizedtype {cf with in_toplevel_decl= is_global} tenv sized_ty in - let unsized_type = SizedType.to_unsized checked_type in - let checked_trans = check_transformation cf tenv unsized_type trans in - verify_identifier id ; - verify_name_fresh tenv id ~is_udf:false ; - let tenv = - Env.add tenv id.name unsized_type - (`Variable {origin= cf.current_block; global= is_global; readonly= false}) - in - let tinit = check_var_decl_initial_value loc cf tenv id init in - verify_valid_transformation_for_type loc is_global checked_type checked_trans ; - verify_transformed_param_ty loc cf is_global unsized_type ; - let stmt = - VarDecl - { decl_type= Sized checked_type - ; transformation= checked_trans - ; identifier= id - ; initial_value= tinit - ; is_global } in - (tenv, mk_typed_statement ~stmt ~loc ~return_type:NoReturnType) - -(* function definitions *) -and exists_matching_fn_declared tenv id arg_tys rt = - let options = - List.concat_map ~f:(Env.find tenv) (distribution_name_variants id.name) - in - let f = function - | Env.{kind= `UserDeclared _; type_= UFun (listedtypes, rt', _, _)} - when arg_tys = listedtypes && rt = rt' -> - true - | _ -> false in - List.exists ~f options - -and verify_unique_signature tenv loc id arg_tys rt = - let existing = - List.concat_map ~f:(Env.find tenv) (distribution_name_variants id.name) - in - let same_args = function - | Env.{type_= UFun (listedtypes, _, _, _); _} - when List.map ~f:snd arg_tys = List.map ~f:snd listedtypes -> - true - | _ -> false in - match List.filter existing ~f:same_args with - | [] -> () - | {type_= UFun (_, rt', _, _); _} :: _ when rt <> rt' -> - Semantic_error.fn_overload_rt_only loc id.name rt rt' |> error - | {kind; _} :: _ -> - Semantic_error.fn_decl_redefined loc id.name - ~stan_math:(kind = `StanMath) - (UnsizedType.UFun (arg_tys, rt, Fun_kind.suffix_from_name id.name, AoS)) - |> error + | Ok p -> Promotion.promote rhs p + | Error _ -> err Operator.Equals |> error ) + | OperatorAssign op -> ( + let args = List.map ~f:arg_type [Ast.expr_of_lvalue lhs; rhs] in + let return_type = assignmentoperator_stan_math_return_type op args in + match return_type with Some Void -> rhs | _ -> err op |> error ) + + let check_lvalue cf tenv = function + | {lval= LVariable id; lmeta= ({loc} : located_meta)} -> + verify_identifier id ; + let ad_level, type_ = check_id cf loc tenv id in + {lval= LVariable id; lmeta= {ad_level; type_; loc}} + | {lval= LIndexed (lval, idcs); lmeta= {loc}} -> + let rec check_inner = function + | {lval= LVariable id; lmeta= ({loc} : located_meta)} -> + verify_identifier id ; + let ad_level, type_ = check_id cf loc tenv id in + let var = {lval= LVariable id; lmeta= {ad_level; type_; loc}} in + (var, var, []) + | {lval= LIndexed (lval, idcs); lmeta= {loc}} -> + let lval, var, flat = check_inner lval in + let idcs = List.map ~f:(check_index cf tenv) idcs in + let ad_level = + inferred_ad_type_of_indexed lval.lmeta.ad_level idcs in + let type_ = + inferred_unsizedtype_of_indexed ~loc lval.lmeta.type_ idcs in + ( {lval= LIndexed (lval, idcs); lmeta= {ad_level; type_; loc}} + , var + , flat @ idcs ) in + let lval, var, flat = check_inner lval in + let idcs = List.map ~f:(check_index cf tenv) idcs in + let ad_level = inferred_ad_type_of_indexed lval.lmeta.ad_level idcs in + let type_ = inferred_unsizedtype_of_indexed ~loc lval.lmeta.type_ idcs in + if List.exists ~f:is_multiindex flat then ( + add_warning loc + "Nested multi-indexing on the left hand side of assignment does \ + not behave the same as nested indexing in expressions. This is \ + considered a bug and will be disallowed in Stan 2.32.0. The \ + indexing can be automatically fixed using the canonicalize flag \ + for stanc." ; + let lvalue_rvalue_types_differ = + try + let flat_type = + inferred_unsizedtype_of_indexed ~loc var.lmeta.type_ + (flat @ idcs) in + let rec can_assign = function + | UnsizedType.(UArray t1, UArray t2) -> can_assign (t1, t2) + | UVector, URowVector | URowVector, UVector -> false + | t1, t2 -> UnsizedType.compare t1 t2 <> 0 in + can_assign (flat_type, type_) + with Errors.SemanticError _ -> true in + if lvalue_rvalue_types_differ then + Semantic_error.cannot_assign_to_multiindex loc |> error ) ; + {lval= LIndexed (lval, idcs); lmeta= {ad_level; type_; loc}} + + let check_assignment loc cf tenv assign_lhs assign_op assign_rhs = + let assign_id = Ast.id_of_lvalue assign_lhs in + let lhs = check_lvalue cf tenv assign_lhs in + let rhs = check_expression cf tenv assign_rhs in + let block, global, readonly = + let var = Env.find tenv assign_id.name in + match var with + | {kind= `Variable {origin; global; readonly}; _} :: _ -> + (origin, global, readonly) + | {kind= `StanMath; _} :: _ -> (MathLibrary, true, false) + | {kind= `UserDefined | `UserDeclared _; _} :: _ -> + (Functions, true, false) + | _ -> + Semantic_error.ident_not_in_scope loc assign_id.name + (Env.nearest_ident tenv assign_id.name) + |> error in + verify_assignment_global loc cf block global assign_id ; + verify_assignment_read_only loc readonly assign_id ; + verify_assignment_non_function loc rhs.emeta.type_ assign_id ; + let rhs' = check_assignment_operator loc assign_op lhs rhs in + mk_typed_statement ~return_type:NoReturnType ~loc + ~stmt:(Assignment {assign_lhs= lhs; assign_op; assign_rhs= rhs'}) + + (* target plus-equals / increment log-prob *) + + let verify_target_pe_expr_type loc e = + if UnsizedType.is_fun_type e.emeta.type_ then + Semantic_error.int_or_real_container_expected loc e.emeta.type_ |> error + + let verify_target_pe_usage loc cf = + if cf.in_lp_fun_def || cf.current_block = Model then () + else Semantic_error.target_plusequals_outisde_model_or_logprob loc |> error + + let check_target_pe loc cf tenv e = + let te = check_expression cf tenv e in + verify_target_pe_usage loc cf ; + verify_target_pe_expr_type loc te ; + mk_typed_statement ~stmt:(TargetPE te) ~return_type:NoReturnType ~loc -and verify_fundef_overloaded loc tenv id arg_tys rt = - if exists_matching_fn_declared tenv id arg_tys rt then - (* this is the definition to an existing forward declaration *) - () - else - (* this should be an overload with a unique signature *) - verify_unique_signature tenv loc id arg_tys rt ; - verify_name_fresh tenv id ~is_udf:true - -and get_fn_decl_or_defn loc tenv id arg_tys rt body = - match body with - | {stmt= Skip; _} -> - if exists_matching_fn_declared tenv id arg_tys rt then - Semantic_error.fn_decl_without_def loc |> error - else `UserDeclared id.id_loc - | _ -> `UserDefined - -and verify_fundef_dist_rt loc id return_ty = - let is_dist = - List.exists - ~f:(fun x -> String.is_suffix id.name ~suffix:x) - Utils.conditioning_suffices_w_log in - if is_dist then - match return_ty with - | UnsizedType.ReturnType UReal -> () - | _ -> Semantic_error.non_real_prob_fn_def loc |> error - -and verify_pdf_fundef_first_arg_ty loc id arg_tys = - if String.is_suffix id.name ~suffix:"_lpdf" then - let rt = List.hd arg_tys |> Option.map ~f:snd in - match rt with - | Some rt when not (UnsizedType.is_int_type rt) -> () - | _ -> Semantic_error.prob_density_non_real_variate loc rt |> error + let check_incr_logprob loc cf tenv e = + let te = check_expression cf tenv e in + verify_target_pe_usage loc cf ; + verify_target_pe_expr_type loc te ; + mk_typed_statement ~stmt:(IncrementLogProb te) ~return_type:NoReturnType + ~loc -and verify_pmf_fundef_first_arg_ty loc id arg_tys = - if String.is_suffix id.name ~suffix:"_lpmf" then - let rt = List.hd arg_tys |> Option.map ~f:snd in - match rt with - | Some rt when UnsizedType.is_int_type rt -> () - | _ -> Semantic_error.prob_mass_non_int_variate loc rt |> error - -and verify_fundef_distinct_arg_ids loc arg_names = - let dup_exists l = - List.find_a_dup ~compare:String.compare l |> Option.is_some in - if dup_exists arg_names then Semantic_error.duplicate_arg_names loc |> error - -and verify_fundef_return_tys loc return_type body = - if - body.stmt = Skip - || is_of_compatible_return_type return_type body.smeta.return_type - then () - else Semantic_error.incompatible_return_types loc |> error - -and add_function tenv name type_ defined = - (* if we're providing a definition, we remove prior declarations - to simplify the environment *) - if defined = `UserDefined then - let existing_defns = Env.find tenv name in - let defns = - List.filter - ~f:(function - | Env.{kind= `UserDeclared _; type_= type'} when type' = type_ -> - false - | _ -> true ) - existing_defns in - let new_fn = Env.{kind= `UserDefined; type_} in - Env.set_raw tenv name (new_fn :: defns) - else Env.add tenv name type_ defined - -and check_fundef loc cf tenv return_ty id args body = - List.iter args ~f:(fun (_, _, id) -> verify_identifier id) ; - verify_identifier id ; - let arg_types = List.map ~f:(fun (w, y, _) -> (w, y)) args in - let arg_identifiers = List.map ~f:(fun (_, _, z) -> z) args in - let arg_names = List.map ~f:(fun x -> x.name) arg_identifiers in - verify_fundef_overloaded loc tenv id arg_types return_ty ; - let defined = get_fn_decl_or_defn loc tenv id arg_types return_ty body in - verify_fundef_dist_rt loc id return_ty ; - verify_pdf_fundef_first_arg_ty loc id arg_types ; - verify_pmf_fundef_first_arg_ty loc id arg_types ; - let tenv = - add_function tenv id.name - (UFun (arg_types, return_ty, Fun_kind.suffix_from_name id.name, AoS)) - defined in - List.iter - ~f:(fun id -> verify_name_fresh tenv id ~is_udf:false) - arg_identifiers ; - verify_fundef_distinct_arg_ids loc arg_names ; - (* We treat DataOnly arguments as if they are data and AutoDiffable arguments - as if they are parameters, for the purposes of type checking. + (* tilde/sampling notation*) + let verify_sampling_pdf_pmf id = + if + String.( + is_suffix id.name ~suffix:"_lpdf" + || is_suffix id.name ~suffix:"_lpmf" + || is_suffix id.name ~suffix:"_lupdf" + || is_suffix id.name ~suffix:"_lupmf") + then Semantic_error.invalid_sampling_pdf_or_pmf id.id_loc |> error + + let verify_sampling_cdf_ccdf loc id = + if + String.( + is_suffix id.name ~suffix:"_cdf" || is_suffix id.name ~suffix:"_ccdf") + then Semantic_error.invalid_sampling_cdf_or_ccdf loc id.name |> error + + (* Target+= can only be used in model and functions with right suffix (same for tilde etc) *) + let verify_valid_sampling_pos loc cf = + if cf.in_lp_fun_def || cf.current_block = Model then () + else Semantic_error.target_plusequals_outisde_model_or_logprob loc |> error + + let verify_sampling_distribution loc tenv id arguments = + let name = id.name in + let argumenttypes = List.map ~f:arg_type arguments in + let name_w_suffix_sampling_dist suffix = + SignatureMismatch.matching_function tenv (name ^ suffix) argumenttypes + in + let sampling_dists = + List.map ~f:name_w_suffix_sampling_dist Utils.distribution_suffices in + let is_sampling_dist_defined = + List.exists + ~f:(function UniqueMatch (ReturnType UReal, _, _) -> true | _ -> false) + sampling_dists + && name <> "binomial_coefficient" + && name <> "multiply" in + if is_sampling_dist_defined then () + else + match + List.max_elt sampling_dists + ~compare:SignatureMismatch.compare_match_results + with + | None | Some (UniqueMatch _) | Some (SignatureErrors ([], _)) -> + (* Either non-existant or a very odd case, + output the old non-informative error *) + Semantic_error.invalid_sampling_no_such_dist loc name |> error + | Some (AmbiguousMatch sigs) -> + Semantic_error.ambiguous_function_promotion loc id.name + (Some (List.map ~f:type_of_expr_typed arguments)) + sigs + |> error + | Some (SignatureErrors (l, b)) -> + arguments + |> List.map ~f:(fun e -> e.emeta.type_) + |> Semantic_error.illtyped_fn_app loc id.name (l, b) + |> error + + let is_cumulative_density_defined tenv id arguments = + let name = id.name in + let argumenttypes = List.map ~f:arg_type arguments in + let valid_arg_types_for_suffix suffix = + match + SignatureMismatch.matching_function tenv (name ^ suffix) argumenttypes + with + | UniqueMatch (ReturnType UReal, _, _) -> true + | _ -> false in + (valid_arg_types_for_suffix "_lcdf" || valid_arg_types_for_suffix "_cdf_log") + && ( valid_arg_types_for_suffix "_lccdf" + || valid_arg_types_for_suffix "_ccdf_log" ) + + let verify_can_truncate_distribution loc (arg : typed_expression) = function + | NoTruncate -> () + | _ -> + if UnsizedType.is_scalar_type arg.emeta.type_ then () + else Semantic_error.multivariate_truncation loc |> error + + let verify_sampling_cdf_defined loc tenv id truncation args = + let check e = is_cumulative_density_defined tenv id (e :: args) in + match truncation with + | NoTruncate -> () + | (TruncateUpFrom e | TruncateDownFrom e) when check e -> () + | TruncateBetween (e1, e2) when check e1 && check e2 -> () + | _ -> Semantic_error.invalid_truncation_cdf_or_ccdf loc |> error + + let check_truncation cf tenv truncation = + let check e = + check_expression_of_int_or_real_type cf tenv e "Truncation bound" in + match truncation with + | NoTruncate -> NoTruncate + | TruncateUpFrom e -> check e |> TruncateUpFrom + | TruncateDownFrom e -> check e |> TruncateDownFrom + | TruncateBetween (e1, e2) -> (check e1, check e2) |> TruncateBetween + + let check_tilde loc cf tenv distribution truncation arg args = + let te = check_expression cf tenv arg in + let tes = List.map ~f:(check_expression cf tenv) args in + let ttrunc = check_truncation cf tenv truncation in + verify_identifier distribution ; + verify_sampling_pdf_pmf distribution ; + verify_valid_sampling_pos loc cf ; + verify_sampling_cdf_ccdf loc distribution ; + verify_sampling_distribution loc tenv distribution (te :: tes) ; + verify_sampling_cdf_defined loc tenv distribution ttrunc tes ; + verify_can_truncate_distribution loc te ttrunc ; + let stmt = Tilde {arg= te; distribution; args= tes; truncation= ttrunc} in + mk_typed_statement ~stmt ~loc ~return_type:NoReturnType + + (* Break and continue only occur in loops. *) + let check_break loc cf = + if cf.loop_depth = 0 then Semantic_error.break_outside_loop loc |> error + else mk_typed_statement ~stmt:Break ~return_type:NoReturnType ~loc + + let check_continue loc cf = + if cf.loop_depth = 0 then Semantic_error.continue_outside_loop loc |> error + else mk_typed_statement ~stmt:Continue ~return_type:NoReturnType ~loc + + let check_return loc cf tenv e = + if not cf.in_returning_fun_def then + Semantic_error.expression_return_outside_returning_fn loc |> error + else + let te = check_expression cf tenv e in + mk_typed_statement ~stmt:(Return te) + ~return_type:(Complete (ReturnType te.emeta.type_)) ~loc + + let check_returnvoid loc cf = + if (not cf.in_fun_def) || cf.in_returning_fun_def then + Semantic_error.void_ouside_nonreturning_fn loc |> error + else mk_typed_statement ~stmt:ReturnVoid ~return_type:(Complete Void) ~loc + + let check_printable cf tenv = function + | PString s -> PString s + (* Print/reject expressions cannot be of function type. *) + | PExpr e -> ( + let te = check_expression cf tenv e in + match te.emeta.type_ with + | UFun _ | UMathLibraryFunction -> + Semantic_error.not_printable te.emeta.loc |> error + | _ -> PExpr te ) + + let check_print loc cf tenv ps = + let tps = List.map ~f:(check_printable cf tenv) ps in + mk_typed_statement ~stmt:(Print tps) ~return_type:NoReturnType ~loc + + let check_reject loc cf tenv ps = + let tps = List.map ~f:(check_printable cf tenv) ps in + mk_typed_statement ~stmt:(Reject tps) ~return_type:AnyReturnType ~loc + + let check_skip loc = + mk_typed_statement ~stmt:Skip ~return_type:NoReturnType ~loc + + let rec stmt_is_escape {stmt; _} = + match stmt with + | Break | Continue | Reject _ | Return _ | ReturnVoid -> true + | _ -> false + + and list_until_escape xs = + let rec aux accu = function + | [next; next'] when stmt_is_escape next' -> + List.rev (next' :: next :: accu) + | next :: next' :: unreachable :: _ when stmt_is_escape next' -> + add_warning unreachable.smeta.loc + "Unreachable statement (following a reject, break, continue, or \ + return) found, is this intended?" ; + List.rev (next' :: next :: accu) + | next :: rest -> aux (next :: accu) rest + | [] -> List.rev accu in + aux [] xs + + let returntype_leastupperbound loc rt1 rt2 = + match (rt1, rt2) with + | UnsizedType.ReturnType UReal, UnsizedType.ReturnType UInt + |ReturnType UInt, ReturnType UReal -> + UnsizedType.ReturnType UReal + | _, _ when rt1 = rt2 -> rt2 + | _ -> Semantic_error.mismatched_return_types loc rt1 rt2 |> error + + let try_compute_block_statement_returntype loc srt1 srt2 = + match (srt1, srt2) with + | Complete rt1, Complete rt2 | Incomplete rt1, Complete rt2 -> + Complete (returntype_leastupperbound loc rt1 rt2) + | Incomplete rt1, Incomplete rt2 | Complete rt1, Incomplete rt2 -> + Incomplete (returntype_leastupperbound loc rt1 rt2) + | NoReturnType, NoReturnType -> NoReturnType + | AnyReturnType, Incomplete rt + |Complete rt, NoReturnType + |NoReturnType, Incomplete rt + |Incomplete rt, NoReturnType -> + Incomplete rt + | NoReturnType, Complete rt + |Complete rt, AnyReturnType + |Incomplete rt, AnyReturnType + |AnyReturnType, Complete rt -> + Complete rt + | AnyReturnType, NoReturnType + |NoReturnType, AnyReturnType + |AnyReturnType, AnyReturnType -> + AnyReturnType + + let try_compute_ifthenelse_statement_returntype loc srt1 srt2 = + match (srt1, srt2) with + | Complete rt1, Complete rt2 -> + returntype_leastupperbound loc rt1 rt2 |> Complete + | Incomplete rt1, Incomplete rt2 + |Complete rt1, Incomplete rt2 + |Incomplete rt1, Complete rt2 -> + returntype_leastupperbound loc rt1 rt2 |> Incomplete + | AnyReturnType, NoReturnType + |NoReturnType, AnyReturnType + |NoReturnType, NoReturnType -> + NoReturnType + | AnyReturnType, Incomplete rt + |Incomplete rt, AnyReturnType + |Complete rt, NoReturnType + |NoReturnType, Complete rt + |NoReturnType, Incomplete rt + |Incomplete rt, NoReturnType -> + Incomplete rt + | Complete rt, AnyReturnType | AnyReturnType, Complete rt -> Complete rt + | AnyReturnType, AnyReturnType -> AnyReturnType + + (* statements which contain statements, and therefore need to be mutually recursive + with check_statement *) - let arg_types_internal = - List.map - ~f:(function - | UnsizedType.DataOnly, ut -> (Env.Data, ut) - | AutoDiffable, ut -> (Param, ut) ) - arg_types in - let tenv_body = - List.fold2_exn arg_names arg_types_internal ~init:tenv - ~f:(fun env name (origin, typ) -> - Env.add env name typ - (* readonly so that function args and loop identifiers - are not modified in function. (passed by const ref) *) - (`Variable {origin; readonly= true; global= false}) ) in - let context = - let is_udf_dist name = + let rec check_if_then_else loc cf tenv pred_e s_true s_false_opt = + (* we don't need these nested type environments *) + let _, ts_true = check_statement cf tenv s_true in + let ts_false_opt = + s_false_opt + |> Option.map ~f:(check_statement cf tenv) + |> Option.map ~f:snd in + let te = + check_expression_of_int_or_real_type cf tenv pred_e + "Condition in conditional" in + let stmt = IfThenElse (te, ts_true, ts_false_opt) in + let srt1 = ts_true.smeta.return_type in + let srt2 = + ts_false_opt + |> Option.map ~f:(fun s -> s.smeta.return_type) + |> Option.value ~default:NoReturnType in + let return_type = + try_compute_ifthenelse_statement_returntype loc srt1 srt2 in + mk_typed_statement ~stmt ~return_type ~loc + + and check_while loc cf tenv cond_e loop_body = + let _, ts = + check_statement {cf with loop_depth= cf.loop_depth + 1} tenv loop_body + and te = + check_expression_of_int_or_real_type cf tenv cond_e + "Condition in while-loop" in + mk_typed_statement + ~stmt:(While (te, ts)) + ~return_type:ts.smeta.return_type ~loc + + and check_for loc cf tenv loop_var lower_bound_e upper_bound_e loop_body = + let te1 = + check_expression_of_int_type cf tenv lower_bound_e + "Lower bound of for-loop" + and te2 = + check_expression_of_int_type cf tenv upper_bound_e + "Upper bound of for-loop" in + verify_identifier loop_var ; + let ts = check_loop_body cf tenv loop_var UnsizedType.UInt loop_body in + mk_typed_statement + ~stmt: + (For + { loop_variable= loop_var + ; lower_bound= te1 + ; upper_bound= te2 + ; loop_body= ts } ) + ~return_type:ts.smeta.return_type ~loc + + and check_foreach_loop_identifier_type loc ty = + match ty with + | UnsizedType.UArray ut -> ut + | UVector | URowVector | UMatrix -> UnsizedType.UReal + | _ -> Semantic_error.array_vector_rowvector_matrix_expected loc ty |> error + + and check_foreach loc cf tenv loop_var foreach_e loop_body = + let te = check_expression cf tenv foreach_e in + verify_identifier loop_var ; + let loop_var_ty = + check_foreach_loop_identifier_type te.emeta.loc te.emeta.type_ in + let ts = check_loop_body cf tenv loop_var loop_var_ty loop_body in + mk_typed_statement + ~stmt:(ForEach (loop_var, te, ts)) + ~return_type:ts.smeta.return_type ~loc + + and check_loop_body cf tenv loop_var loop_var_ty loop_body = + verify_name_fresh tenv loop_var ~is_udf:false ; + (* Add to type environment as readonly. + Check that function args and loop identifiers are not modified in + function. (passed by const ref) + *) + let tenv = + Env.add tenv loop_var.name loop_var_ty + (`Variable {origin= cf.current_block; global= false; readonly= true}) + in + snd (check_statement {cf with loop_depth= cf.loop_depth + 1} tenv loop_body) + + and check_block loc cf tenv stmts = + let _, checked_stmts = + List.fold_map stmts ~init:tenv ~f:(check_statement cf) in + let return_type = + checked_stmts |> list_until_escape + |> List.map ~f:(fun s -> s.smeta.return_type) + |> List.fold ~init:NoReturnType + ~f:(try_compute_block_statement_returntype loc) in + mk_typed_statement ~stmt:(Block checked_stmts) ~return_type ~loc + + and check_profile loc cf tenv name stmts = + let _, checked_stmts = + List.fold_map stmts ~init:tenv ~f:(check_statement cf) in + let return_type = + checked_stmts |> list_until_escape + |> List.map ~f:(fun s -> s.smeta.return_type) + |> List.fold ~init:NoReturnType + ~f:(try_compute_block_statement_returntype loc) in + mk_typed_statement ~stmt:(Profile (name, checked_stmts)) ~return_type ~loc + + (* variable declarations *) + and verify_valid_transformation_for_type loc is_global sized_ty trans = + let is_real {emeta; _} = emeta.type_ = UReal in + let is_real_transformation = + match trans with + | Transformation.Lower e -> is_real e + | Upper e -> is_real e + | LowerUpper (e1, e2) -> is_real e1 || is_real e2 + | _ -> false in + if is_global && sized_ty = SizedType.SInt && is_real_transformation then + Semantic_error.non_int_bounds loc |> error ; + let is_transformation = + match trans with Transformation.Identity -> false | _ -> true in + if is_global && SizedType.(contains_complex sized_ty) && is_transformation + then Semantic_error.complex_transform loc |> error + + and verify_transformed_param_ty loc cf is_global unsized_ty = + if + is_global + && (cf.current_block = Param || cf.current_block = TParam) + && UnsizedType.is_int_type unsized_ty + then Semantic_error.transformed_params_int loc |> error + + and check_sizedtype cf tenv sizedty = + let check e msg = check_expression_of_int_type cf tenv e msg in + match sizedty with + | SizedType.SInt -> SizedType.SInt + | SReal -> SReal + | SComplex -> SComplex + | SVector (mem_pattern, e) -> + let te = check e "Vector sizes" in + SVector (mem_pattern, te) + | SRowVector (mem_pattern, e) -> + let te = check e "Row vector sizes" in + SRowVector (mem_pattern, te) + | SMatrix (mem_pattern, e1, e2) -> + let te1 = check e1 "Matrix row size" in + let te2 = check e2 "Matrix column size" in + SMatrix (mem_pattern, te1, te2) + | SComplexVector e -> + let te = check e "complex vector sizes" in + SComplexVector te + | SComplexRowVector e -> + let te = check e "complex row vector sizes" in + SComplexRowVector te + | SComplexMatrix (e1, e2) -> + let te1 = check e1 "Complex matrix row size" in + let te2 = check e2 "Complex matrix column size" in + SComplexMatrix (te1, te2) + | SArray (st, e) -> + let tst = check_sizedtype cf tenv st in + let te = check e "Array sizes" in + SArray (tst, te) + + and check_var_decl_initial_value loc cf tenv id init_val_opt = + match init_val_opt with + | Some e -> ( + let lhs = check_lvalue cf tenv {lval= LVariable id; lmeta= {loc}} in + let rhs = check_expression cf tenv e in + match + SignatureMismatch.check_of_same_type_mod_conv lhs.lmeta.type_ + rhs.emeta.type_ + with + | Ok p -> Some (Promotion.promote rhs p) + | Error _ -> + Semantic_error.illtyped_assignment loc Equals lhs.lmeta.type_ + rhs.emeta.type_ + |> error ) + | None -> None + + and check_transformation cf tenv ut trans = + let check e msg = check_expression_of_scalar_or_type cf tenv ut e msg in + match trans with + | Transformation.Identity -> Transformation.Identity + | Lower e -> check e "Lower bound" |> Lower + | Upper e -> check e "Upper bound" |> Upper + | LowerUpper (e1, e2) -> + (check e1 "Lower bound", check e2 "Upper bound") |> LowerUpper + | Offset e -> check e "Offset" |> Offset + | Multiplier e -> check e "Multiplier" |> Multiplier + | OffsetMultiplier (e1, e2) -> + (check e1 "Offset", check e2 "Multiplier") |> OffsetMultiplier + | Ordered -> Ordered + | PositiveOrdered -> PositiveOrdered + | Simplex -> Simplex + | UnitVector -> UnitVector + | CholeskyCorr -> CholeskyCorr + | CholeskyCov -> CholeskyCov + | Correlation -> Correlation + | Covariance -> Covariance + + and check_var_decl loc cf tenv sized_ty trans id init is_global = + let checked_type = + check_sizedtype {cf with in_toplevel_decl= is_global} tenv sized_ty in + let unsized_type = SizedType.to_unsized checked_type in + let checked_trans = check_transformation cf tenv unsized_type trans in + verify_identifier id ; + verify_name_fresh tenv id ~is_udf:false ; + let tenv = + Env.add tenv id.name unsized_type + (`Variable + {origin= cf.current_block; global= is_global; readonly= false} ) in + let tinit = check_var_decl_initial_value loc cf tenv id init in + verify_valid_transformation_for_type loc is_global checked_type + checked_trans ; + verify_transformed_param_ty loc cf is_global unsized_type ; + let stmt = + VarDecl + { decl_type= Sized checked_type + ; transformation= checked_trans + ; identifier= id + ; initial_value= tinit + ; is_global } in + (tenv, mk_typed_statement ~stmt ~loc ~return_type:NoReturnType) + + (* function definitions *) + and exists_matching_fn_declared tenv id arg_tys rt = + let options = + List.concat_map ~f:(Env.find tenv) (distribution_name_variants id.name) + in + let f = function + | Env.{kind= `UserDeclared _; type_= UFun (listedtypes, rt', _, _)} + when arg_tys = listedtypes && rt = rt' -> + true + | _ -> false in + List.exists ~f options + + and verify_unique_signature tenv loc id arg_tys rt = + let existing = + List.concat_map ~f:(Env.find tenv) (distribution_name_variants id.name) + in + let same_args = function + | Env.{type_= UFun (listedtypes, _, _, _); _} + when List.map ~f:snd arg_tys = List.map ~f:snd listedtypes -> + true + | _ -> false in + match List.filter existing ~f:same_args with + | [] -> () + | {type_= UFun (_, rt', _, _); _} :: _ when rt <> rt' -> + Semantic_error.fn_overload_rt_only loc id.name rt rt' |> error + | {kind; _} :: _ -> + Semantic_error.fn_decl_redefined loc id.name + ~stan_math:(kind = `StanMath) + (UnsizedType.UFun (arg_tys, rt, Fun_kind.suffix_from_name id.name, AoS) + ) + |> error + + and verify_fundef_overloaded loc tenv id arg_tys rt = + if exists_matching_fn_declared tenv id arg_tys rt then + (* this is the definition to an existing forward declaration *) + () + else + (* this should be an overload with a unique signature *) + verify_unique_signature tenv loc id arg_tys rt ; + verify_name_fresh tenv id ~is_udf:true + + and get_fn_decl_or_defn loc tenv id arg_tys rt body = + match body with + | {stmt= Skip; _} -> + if exists_matching_fn_declared tenv id arg_tys rt then + Semantic_error.fn_decl_without_def loc |> error + else `UserDeclared id.id_loc + | _ -> `UserDefined + + and verify_fundef_dist_rt loc id return_ty = + let is_dist = List.exists - ~f:(fun suffix -> String.is_suffix name ~suffix) - Utils.distribution_suffices in - { cf with - in_fun_def= true - ; in_rng_fun_def= String.is_suffix id.name ~suffix:"_rng" - ; in_lp_fun_def= String.is_suffix id.name ~suffix:"_lp" - ; in_udf_dist_def= is_udf_dist id.name - ; in_returning_fun_def= return_ty <> Void } in - let _, checked_body = check_statement context tenv_body body in - verify_fundef_return_tys loc return_ty checked_body ; - let stmt = - FunDef - {returntype= return_ty; funname= id; arguments= args; body= checked_body} - in - (* NB: **not** tenv_body, so args don't leak out *) - (tenv, mk_typed_statement ~return_type:NoReturnType ~loc ~stmt) - -and check_statement (cf : context_flags_record) (tenv : Env.t) - (s : Ast.untyped_statement) : Env.t * typed_statement = - let loc = s.smeta.loc in - match s.stmt with - | NRFunApp (_, id, es) -> (tenv, check_nr_fn_app loc cf tenv id es) - | Assignment {assign_lhs; assign_op; assign_rhs} -> - (tenv, check_assignment loc cf tenv assign_lhs assign_op assign_rhs) - | TargetPE e -> (tenv, check_target_pe loc cf tenv e) - | IncrementLogProb e -> (tenv, check_incr_logprob loc cf tenv e) - | Tilde {arg; distribution; args; truncation} -> - (tenv, check_tilde loc cf tenv distribution truncation arg args) - | Break -> (tenv, check_break loc cf) - | Continue -> (tenv, check_continue loc cf) - | Return e -> (tenv, check_return loc cf tenv e) - | ReturnVoid -> (tenv, check_returnvoid loc cf) - | Print ps -> (tenv, check_print loc cf tenv ps) - | Reject ps -> (tenv, check_reject loc cf tenv ps) - | Skip -> (tenv, check_skip loc) - (* the following can contain further statements *) - | IfThenElse (e, s1, os2) -> (tenv, check_if_then_else loc cf tenv e s1 os2) - | While (e, s) -> (tenv, check_while loc cf tenv e s) - | For {loop_variable; lower_bound; upper_bound; loop_body} -> - ( tenv - , check_for loc cf tenv loop_variable lower_bound upper_bound loop_body ) - | ForEach (id, e, s) -> (tenv, check_foreach loc cf tenv id e s) - | Block stmts -> (tenv, check_block loc cf tenv stmts) - | Profile (name, vdsl) -> (tenv, check_profile loc cf tenv name vdsl) - | VarDecl {decl_type= Unsized _; _} -> - (* currently unallowed by parser *) + ~f:(fun x -> String.is_suffix id.name ~suffix:x) + Utils.conditioning_suffices_w_log in + if is_dist then + match return_ty with + | UnsizedType.ReturnType UReal -> () + | _ -> Semantic_error.non_real_prob_fn_def loc |> error + + and verify_pdf_fundef_first_arg_ty loc id arg_tys = + if String.is_suffix id.name ~suffix:"_lpdf" then + let rt = List.hd arg_tys |> Option.map ~f:snd in + match rt with + | Some rt when not (UnsizedType.is_int_type rt) -> () + | _ -> Semantic_error.prob_density_non_real_variate loc rt |> error + + and verify_pmf_fundef_first_arg_ty loc id arg_tys = + if String.is_suffix id.name ~suffix:"_lpmf" then + let rt = List.hd arg_tys |> Option.map ~f:snd in + match rt with + | Some rt when UnsizedType.is_int_type rt -> () + | _ -> Semantic_error.prob_mass_non_int_variate loc rt |> error + + and verify_fundef_distinct_arg_ids loc arg_names = + let dup_exists l = + List.find_a_dup ~compare:String.compare l |> Option.is_some in + if dup_exists arg_names then Semantic_error.duplicate_arg_names loc |> error + + and verify_fundef_return_tys loc return_type body = + if + body.stmt = Skip + || is_of_compatible_return_type return_type body.smeta.return_type + then () + else Semantic_error.incompatible_return_types loc |> error + + and add_function tenv name type_ defined = + (* if we're providing a definition, we remove prior declarations + to simplify the environment *) + if defined = `UserDefined then + let existing_defns = Env.find tenv name in + let defns = + List.filter + ~f:(function + | Env.{kind= `UserDeclared _; type_= type'} when type' = type_ -> + false + | _ -> true ) + existing_defns in + let new_fn = Env.{kind= `UserDefined; type_} in + Env.set_raw tenv name (new_fn :: defns) + else Env.add tenv name type_ defined + + and check_fundef loc cf tenv return_ty id args body = + List.iter args ~f:(fun (_, _, id) -> verify_identifier id) ; + verify_identifier id ; + let arg_types = List.map ~f:(fun (w, y, _) -> (w, y)) args in + let arg_identifiers = List.map ~f:(fun (_, _, z) -> z) args in + let arg_names = List.map ~f:(fun x -> x.name) arg_identifiers in + verify_fundef_overloaded loc tenv id arg_types return_ty ; + let defined = get_fn_decl_or_defn loc tenv id arg_types return_ty body in + verify_fundef_dist_rt loc id return_ty ; + verify_pdf_fundef_first_arg_ty loc id arg_types ; + verify_pmf_fundef_first_arg_ty loc id arg_types ; + let tenv = + add_function tenv id.name + (UFun (arg_types, return_ty, Fun_kind.suffix_from_name id.name, AoS)) + defined in + List.iter + ~f:(fun id -> verify_name_fresh tenv id ~is_udf:false) + arg_identifiers ; + verify_fundef_distinct_arg_ids loc arg_names ; + (* We treat DataOnly arguments as if they are data and AutoDiffable arguments + as if they are parameters, for the purposes of type checking. + *) + let arg_types_internal = + List.map + ~f:(function + | UnsizedType.DataOnly, ut -> (Env.Data, ut) + | AutoDiffable, ut -> (Param, ut) ) + arg_types in + let tenv_body = + List.fold2_exn arg_names arg_types_internal ~init:tenv + ~f:(fun env name (origin, typ) -> + Env.add env name typ + (* readonly so that function args and loop identifiers + are not modified in function. (passed by const ref) *) + (`Variable {origin; readonly= true; global= false}) ) in + let context = + let is_udf_dist name = + List.exists + ~f:(fun suffix -> String.is_suffix name ~suffix) + Utils.distribution_suffices in + { cf with + in_fun_def= true + ; in_rng_fun_def= String.is_suffix id.name ~suffix:"_rng" + ; in_lp_fun_def= String.is_suffix id.name ~suffix:"_lp" + ; in_udf_dist_def= is_udf_dist id.name + ; in_returning_fun_def= return_ty <> Void } in + let _, checked_body = check_statement context tenv_body body in + verify_fundef_return_tys loc return_ty checked_body ; + let stmt = + FunDef + {returntype= return_ty; funname= id; arguments= args; body= checked_body} + in + (* NB: **not** tenv_body, so args don't leak out *) + (tenv, mk_typed_statement ~return_type:NoReturnType ~loc ~stmt) + + and check_statement (cf : context_flags_record) (tenv : Env.t) + (s : Ast.untyped_statement) : Env.t * typed_statement = + let loc = s.smeta.loc in + match s.stmt with + | NRFunApp (_, id, es) -> (tenv, check_nr_fn_app loc cf tenv id es) + | Assignment {assign_lhs; assign_op; assign_rhs} -> + (tenv, check_assignment loc cf tenv assign_lhs assign_op assign_rhs) + | TargetPE e -> (tenv, check_target_pe loc cf tenv e) + | IncrementLogProb e -> (tenv, check_incr_logprob loc cf tenv e) + | Tilde {arg; distribution; args; truncation} -> + (tenv, check_tilde loc cf tenv distribution truncation arg args) + | Break -> (tenv, check_break loc cf) + | Continue -> (tenv, check_continue loc cf) + | Return e -> (tenv, check_return loc cf tenv e) + | ReturnVoid -> (tenv, check_returnvoid loc cf) + | Print ps -> (tenv, check_print loc cf tenv ps) + | Reject ps -> (tenv, check_reject loc cf tenv ps) + | Skip -> (tenv, check_skip loc) + (* the following can contain further statements *) + | IfThenElse (e, s1, os2) -> (tenv, check_if_then_else loc cf tenv e s1 os2) + | While (e, s) -> (tenv, check_while loc cf tenv e s) + | For {loop_variable; lower_bound; upper_bound; loop_body} -> + ( tenv + , check_for loc cf tenv loop_variable lower_bound upper_bound loop_body + ) + | ForEach (id, e, s) -> (tenv, check_foreach loc cf tenv id e s) + | Block stmts -> (tenv, check_block loc cf tenv stmts) + | Profile (name, vdsl) -> (tenv, check_profile loc cf tenv name vdsl) + | VarDecl {decl_type= Unsized _; _} -> + (* currently unallowed by parser *) + Common.FatalError.fatal_error_msg + [%message "Don't support unsized declarations yet."] + (* these two are special in that they're allowed to change the type environment *) + | VarDecl + { decl_type= Sized st + ; transformation + ; identifier + ; initial_value + ; is_global } -> + check_var_decl loc cf tenv st transformation identifier initial_value + is_global + | FunDef {returntype; funname; arguments; body} -> + check_fundef loc cf tenv returntype funname arguments body + + let verify_fun_def_body_in_block = function + | {stmt= FunDef {body= {stmt= Block _; _}; _}; _} + |{stmt= FunDef {body= {stmt= Skip; _}; _}; _} -> + () + | {stmt= FunDef {body= {stmt= _; smeta}; _}; _} -> + Semantic_error.fn_decl_needs_block smeta.loc |> error + | _ -> () + + let verify_functions_have_defn tenv function_block_stmts_opt = + let error_on_undefined funs = + List.iter funs ~f:(fun f -> + match f with + | Env.{kind= `UserDeclared loc; _} -> + Semantic_error.fn_decl_without_def loc |> error + | _ -> () ) in + if !check_that_all_functions_have_definition then + Env.iter tenv error_on_undefined ; + match function_block_stmts_opt with + | Some {stmts= []; _} | None -> () + | Some {stmts= ls; _} -> List.iter ~f:verify_fun_def_body_in_block ls + + let check_toplevel_block block tenv stmts_opt = + let cf = context block in + match stmts_opt with + | Some {stmts; xloc} -> + let tenv', stmts = + List.fold_map stmts ~init:tenv ~f:(check_statement cf) in + (tenv', Some {stmts; xloc}) + | None -> (tenv, None) + + let verify_correctness_invariant (ast : untyped_program) + (decorated_ast : typed_program) = + let detyped = untyped_program_of_typed_program decorated_ast in + if compare_untyped_program ast detyped = 0 then () + else Common.FatalError.fatal_error_msg - [%message "Don't support unsized declarations yet."] - (* these two are special in that they're allowed to change the type environment *) - | VarDecl - {decl_type= Sized st; transformation; identifier; initial_value; is_global} - -> - check_var_decl loc cf tenv st transformation identifier initial_value - is_global - | FunDef {returntype; funname; arguments; body} -> - check_fundef loc cf tenv returntype funname arguments body - -let verify_fun_def_body_in_block = function - | {stmt= FunDef {body= {stmt= Block _; _}; _}; _} - |{stmt= FunDef {body= {stmt= Skip; _}; _}; _} -> - () - | {stmt= FunDef {body= {stmt= _; smeta}; _}; _} -> - Semantic_error.fn_decl_needs_block smeta.loc |> error - | _ -> () - -let verify_functions_have_defn tenv function_block_stmts_opt = - let error_on_undefined funs = - List.iter funs ~f:(fun f -> - match f with - | Env.{kind= `UserDeclared loc; _} -> - Semantic_error.fn_decl_without_def loc |> error - | _ -> () ) in - if !check_that_all_functions_have_definition then - Env.iter tenv error_on_undefined ; - match function_block_stmts_opt with - | Some {stmts= []; _} | None -> () - | Some {stmts= ls; _} -> List.iter ~f:verify_fun_def_body_in_block ls - -let check_toplevel_block block tenv stmts_opt = - let cf = context block in - match stmts_opt with - | Some {stmts; xloc} -> - let tenv', stmts = - List.fold_map stmts ~init:tenv ~f:(check_statement cf) in - (tenv', Some {stmts; xloc}) - | None -> (tenv, None) - -let verify_correctness_invariant (ast : untyped_program) - (decorated_ast : typed_program) = - let detyped = untyped_program_of_typed_program decorated_ast in - if compare_untyped_program ast detyped = 0 then () - else - Common.FatalError.fatal_error_msg - [%message - "Type checked AST does not match original AST. " - (detyped : untyped_program) - (ast : untyped_program)] - -let check_program_exn - ( { functionblock= fb - ; datablock= db - ; transformeddatablock= tdb - ; parametersblock= pb - ; transformedparametersblock= tpb - ; modelblock= mb - ; generatedquantitiesblock= gqb - ; comments } as ast ) = - warnings := [] ; - (* create a new type environment which has only stan-math functions *) - let tenv = Env.stan_math_environment in - let tenv, typed_fb = check_toplevel_block Functions tenv fb in - verify_functions_have_defn tenv typed_fb ; - let tenv, typed_db = check_toplevel_block Data tenv db in - let tenv, typed_tdb = check_toplevel_block TData tenv tdb in - let tenv, typed_pb = check_toplevel_block Param tenv pb in - let tenv, typed_tpb = check_toplevel_block TParam tenv tpb in - let _, typed_mb = check_toplevel_block Model tenv mb in - let _, typed_gqb = check_toplevel_block GQuant tenv gqb in - let prog = - { functionblock= typed_fb - ; datablock= typed_db - ; transformeddatablock= typed_tdb - ; parametersblock= typed_pb - ; transformedparametersblock= typed_tpb - ; modelblock= typed_mb - ; generatedquantitiesblock= typed_gqb - ; comments } in - verify_correctness_invariant ast prog ; - attach_warnings prog - -let check_program ast = - try Result.Ok (check_program_exn ast) - with Errors.SemanticError err -> Result.Error err + [%message + "Type checked AST does not match original AST. " + (detyped : untyped_program) + (ast : untyped_program)] + + let check_program_exn + ( { functionblock= fb + ; datablock= db + ; transformeddatablock= tdb + ; parametersblock= pb + ; transformedparametersblock= tpb + ; modelblock= mb + ; generatedquantitiesblock= gqb + ; comments } as ast ) = + warnings := [] ; + (* create a new type environment which has only stan-math functions *) + let tenv = Env.stan_math_environment in + let tenv, typed_fb = check_toplevel_block Functions tenv fb in + verify_functions_have_defn tenv typed_fb ; + let tenv, typed_db = check_toplevel_block Data tenv db in + let tenv, typed_tdb = check_toplevel_block TData tenv tdb in + let tenv, typed_pb = check_toplevel_block Param tenv pb in + let tenv, typed_tpb = check_toplevel_block TParam tenv tpb in + let _, typed_mb = check_toplevel_block Model tenv mb in + let _, typed_gqb = check_toplevel_block GQuant tenv gqb in + let prog = + { functionblock= typed_fb + ; datablock= typed_db + ; transformeddatablock= typed_tdb + ; parametersblock= typed_pb + ; transformedparametersblock= typed_tpb + ; modelblock= typed_mb + ; generatedquantitiesblock= typed_gqb + ; comments } in + verify_correctness_invariant ast prog ; + attach_warnings prog + + let check_program ast = + try Result.Ok (check_program_exn ast) + with Errors.SemanticError err -> Result.Error err +end diff --git a/src/frontend/Typechecker.mli b/src/frontend/Typechecker.mli index 07662dfe11..21b75b5497 100644 --- a/src/frontend/Typechecker.mli +++ b/src/frontend/Typechecker.mli @@ -15,33 +15,40 @@ open Ast -val check_program_exn : untyped_program -> typed_program * Warnings.t list -(** - Type check a full Stan program. - Can raise [Errors.SemanticError] -*) - -val check_program : - untyped_program -> (typed_program * Warnings.t list, Semantic_error.t) result -(** - The safe version of [check_program_exn]. This catches - all [Errors.SemanticError] exceptions and converts them - into a [Result.t] -*) - -val operator_stan_math_return_type : - Middle.Operator.t - -> (Middle.UnsizedType.autodifftype * Middle.UnsizedType.t) list - -> (Middle.UnsizedType.returntype * Promotion.t list) option - -val stan_math_return_type : - string - -> (Middle.UnsizedType.autodifftype * Middle.UnsizedType.t) list - -> Middle.UnsizedType.returntype option val model_name : string ref (** A reference to hold the model name. Relevant for checking variable - clashes and used in code generation. *) + clashes and used in code generation. *) val check_that_all_functions_have_definition : bool ref (** A switch to determine whether we check that all functions have a definition *) + +module type Typechecker = sig + val check_program_exn : untyped_program -> typed_program * Warnings.t list + (** + Type check a full Stan program. + Can raise [Errors.SemanticError] + *) + + val check_program : + untyped_program + -> (typed_program * Warnings.t list, Semantic_error.t) result + (** + The safe version of [check_program_exn]. This catches + all [Errors.SemanticError] exceptions and converts them + into a [Result.t] + *) + + val operator_stan_math_return_type : + Middle.Operator.t + -> (Middle.UnsizedType.autodifftype * Middle.UnsizedType.t) list + -> (Middle.UnsizedType.returntype * Promotion.t list) option + + val stan_math_return_type : + string + -> (Middle.UnsizedType.autodifftype * Middle.UnsizedType.t) list + -> Middle.UnsizedType.returntype option + +end + +module Typecheck (StdLibrary : Std_library_utils.Library): Typechecker diff --git a/src/middle/Stan_math_signatures.ml b/src/middle/Stan_math_signatures.ml index e7c929dab3..66435f7b84 100644 --- a/src/middle/Stan_math_signatures.ml +++ b/src/middle/Stan_math_signatures.ml @@ -441,14 +441,6 @@ let make_assignmentoperator_stan_math_signatures assop = else [(Void, [(ad1, lhs); (ad2, rhs)], SoA)] | _ -> [] ) -let pp_math_sig ppf (rt, args, mem_pattern) = - UnsizedType.pp ppf (UFun (args, rt, FnPlain, mem_pattern)) - -let pp_math_sigs ppf name = - (Fmt.list ~sep:Fmt.cut pp_math_sig) ppf (get_sigs name) - -let pretty_print_math_sigs = Fmt.str "@[@,%a@]" pp_math_sigs - let string_operator_to_stan_math_fns str = match str with | "Plus__" -> "add" @@ -499,8 +491,11 @@ let pretty_print_all_math_distributions ppf () = let pretty_print_math_lib_operator_sigs op = if op = Operator.IntDivide then - [Fmt.str "@[@,%a@]" pp_math_sig int_divide_type] - else operator_to_stan_math_fns op |> List.map ~f:pretty_print_math_sigs + [Fmt.str "@[@,%a@]" Std_library_utils.pp_math_sig int_divide_type] + else + operator_to_stan_math_fns op + |> List.map + ~f:(Fn.compose Std_library_utils.pretty_print_math_sigs get_sigs) (* -- Some helper definitions to populate stan_math_signatures -- *) let bare_types = diff --git a/src/middle/Stan_math_signatures.mli b/src/middle/Stan_math_signatures.mli index 30b906c4d7..8d3ffb1bc7 100644 --- a/src/middle/Stan_math_signatures.mli +++ b/src/middle/Stan_math_signatures.mli @@ -21,9 +21,6 @@ val stan_math_signatures : (string, signature list) Hashtbl.t val is_stan_math_function_name : string -> bool (** Equivalent to [Hashtbl.mem stan_math_signatures s]*) -(** Pretty printers *) - -val pp_math_sig : signature Fmt.t val pretty_print_all_math_sigs : unit Fmt.t val pretty_print_all_math_distributions : unit Fmt.t From 138326050290b43af668d25ecef9191d06199844 Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Wed, 27 Apr 2022 11:57:45 -0400 Subject: [PATCH 02/14] Start parameterizing modules --- src/analysis_and_optimization/Mem_pattern.ml | 999 +++++++++--------- src/analysis_and_optimization/Optimize.ml | 16 +- .../Partial_evaluator.ml | 9 +- src/frontend/Canonicalize.ml | 448 ++++---- src/frontend/Canonicalize.mli | 16 +- src/frontend/Deprecation_analysis.ml | 379 +++---- src/frontend/Deprecation_analysis.mli | 29 +- src/frontend/Environment.ml | 4 +- src/frontend/Environment.mli | 12 +- src/frontend/Semantic_error.ml | 106 +- src/frontend/Semantic_error.mli | 50 +- src/frontend/SignatureMismatch.ml | 3 - src/frontend/SignatureMismatch.mli | 6 - src/frontend/Std_library_utils.ml | 33 +- src/frontend/Typechecker.ml | 57 +- src/frontend/Typechecker.mli | 8 +- src/frontend/dune | 10 +- src/middle/dune | 7 +- .../Stan_math_signatures.ml | 44 +- .../Stan_math_signatures.mli | 2 +- src/stan_math_backend/dune | 2 +- 21 files changed, 1139 insertions(+), 1101 deletions(-) rename src/{middle => stan_math_backend}/Stan_math_signatures.ml (99%) rename src/{middle => stan_math_backend}/Stan_math_signatures.mli (97%) diff --git a/src/analysis_and_optimization/Mem_pattern.ml b/src/analysis_and_optimization/Mem_pattern.ml index 7613ea45d0..5b252626d4 100644 --- a/src/analysis_and_optimization/Mem_pattern.ml +++ b/src/analysis_and_optimization/Mem_pattern.ml @@ -2,74 +2,75 @@ open Core_kernel open Core_kernel.Poly open Middle -(** +module Make (StdLib : Frontend.Std_library_utils.Library) = struct + (** Return a Var expression of the name for each type containing an eigen matrix *) -let rec matrix_set Expr.Fixed.{pattern; meta= Expr.Typed.Meta.{type_; _} as meta} - = - let union_recur exprs = Set.Poly.union_list (List.map exprs ~f:matrix_set) in - if UnsizedType.contains_eigen_type type_ then - match pattern with - | Var s -> Set.Poly.singleton (Dataflow_types.VVar s, meta) - | Lit _ -> Set.Poly.empty - | FunApp (_, exprs) -> - if UnsizedType.contains_eigen_type type_ then union_recur exprs - else Set.Poly.empty - | TernaryIf (_, expr2, expr3) -> union_recur [expr2; expr3] - | Indexed (expr, _) | Promotion (expr, _, _) -> matrix_set expr - | EAnd (expr1, expr2) | EOr (expr1, expr2) -> union_recur [expr1; expr2] - else Set.Poly.empty + let rec matrix_set + Expr.Fixed.{pattern; meta= Expr.Typed.Meta.{type_; _} as meta} = + let union_recur exprs = Set.Poly.union_list (List.map exprs ~f:matrix_set) in + if UnsizedType.contains_eigen_type type_ then + match pattern with + | Var s -> Set.Poly.singleton (Dataflow_types.VVar s, meta) + | Lit _ -> Set.Poly.empty + | FunApp (_, exprs) -> + if UnsizedType.contains_eigen_type type_ then union_recur exprs + else Set.Poly.empty + | TernaryIf (_, expr2, expr3) -> union_recur [expr2; expr3] + | Indexed (expr, _) | Promotion (expr, _, _) -> matrix_set expr + | EAnd (expr1, expr2) | EOr (expr1, expr2) -> union_recur [expr1; expr2] + else Set.Poly.empty -(** + (** Return a set of all types containing autodiffable Eigen matrices in an expression. *) -let query_var_eigen_names (expr : Expr.Typed.t) : string Set.Poly.t = - let get_expr_eigen_names - (Dataflow_types.VVar s, Expr.Typed.Meta.{adlevel; type_; _}) = - if - UnsizedType.contains_eigen_type type_ - && UnsizedType.is_autodifftype adlevel - then Some s - else None in - Set.Poly.filter_map ~f:get_expr_eigen_names (matrix_set expr) + let query_var_eigen_names (expr : Expr.Typed.t) : string Set.Poly.t = + let get_expr_eigen_names + (Dataflow_types.VVar s, Expr.Typed.Meta.{adlevel; type_; _}) = + if + UnsizedType.contains_eigen_type type_ + && UnsizedType.is_autodifftype adlevel + then Some s + else None in + Set.Poly.filter_map ~f:get_expr_eigen_names (matrix_set expr) -(** + (** Check whether one set is a nonzero subset of another set. *) -let is_nonzero_subset ~set ~subset = - Set.Poly.is_subset subset ~of_:set - && (not (Set.Poly.is_empty set)) - && not (Set.Poly.is_empty subset) + let is_nonzero_subset ~set ~subset = + Set.Poly.is_subset subset ~of_:set + && (not (Set.Poly.is_empty set)) + && not (Set.Poly.is_empty subset) -(** + (** Check an expression to count how many times we see a single index. @param acc An accumulator from previous folds of multiple expressions. @param pattern The expression patterns to match against *) -let rec count_single_idx_exprs (acc : int) Expr.Fixed.{pattern; _} : int = - match pattern with - | Expr.Fixed.Pattern.FunApp (_, (exprs : Expr.Typed.t list)) -> - List.fold_left ~init:acc ~f:count_single_idx_exprs exprs - | TernaryIf (predicate, texpr, fexpr) -> - acc - + count_single_idx_exprs 0 predicate - + count_single_idx_exprs 0 texpr - + count_single_idx_exprs 0 fexpr - | Indexed (idx_expr, indexed) -> - acc - + count_single_idx_exprs 0 idx_expr - + List.fold_left ~init:0 ~f:count_single_idx indexed - | Promotion (expr, _, _) -> count_single_idx_exprs acc expr - | EAnd (lhs, rhs) -> - acc + count_single_idx_exprs 0 lhs + count_single_idx_exprs 0 rhs - | EOr (lhs, rhs) -> - acc + count_single_idx_exprs 0 lhs + count_single_idx_exprs 0 rhs - | Var (_ : string) | Lit ((_ : Expr.Fixed.Pattern.litType), (_ : string)) -> - acc + let rec count_single_idx_exprs (acc : int) Expr.Fixed.{pattern; _} : int = + match pattern with + | Expr.Fixed.Pattern.FunApp (_, (exprs : Expr.Typed.t list)) -> + List.fold_left ~init:acc ~f:count_single_idx_exprs exprs + | TernaryIf (predicate, texpr, fexpr) -> + acc + + count_single_idx_exprs 0 predicate + + count_single_idx_exprs 0 texpr + + count_single_idx_exprs 0 fexpr + | Indexed (idx_expr, indexed) -> + acc + + count_single_idx_exprs 0 idx_expr + + List.fold_left ~init:0 ~f:count_single_idx indexed + | Promotion (expr, _, _) -> count_single_idx_exprs acc expr + | EAnd (lhs, rhs) -> + acc + count_single_idx_exprs 0 lhs + count_single_idx_exprs 0 rhs + | EOr (lhs, rhs) -> + acc + count_single_idx_exprs 0 lhs + count_single_idx_exprs 0 rhs + | Var (_ : string) | Lit ((_ : Expr.Fixed.Pattern.litType), (_ : string)) -> + acc -(** + (** Check an Index to count how many times we see a single index. @param acc An accumulator from previous folds of multiple expressions. @param idx An Index to match. For Single types this adds 1 to the @@ -77,12 +78,12 @@ let rec count_single_idx_exprs (acc : int) Expr.Fixed.{pattern; _} : int = for a Single index. All and Between cannot be Single cell access and so pass acc along. *) -and count_single_idx (acc : int) (idx : Expr.Typed.t Index.t) = - match idx with - | Index.All | Between _ | Upfrom _ | MultiIndex _ -> acc - | Single _ -> acc + 1 + and count_single_idx (acc : int) (idx : Expr.Typed.t Index.t) = + match idx with + | Index.All | Between _ | Upfrom _ | MultiIndex _ -> acc + | Single _ -> acc + 1 -(** + (** Find indices on Matrix and Vector types that perform single cell access. Returns true if it finds a vector, row vector, matrix, or matrix with single cell access @@ -92,52 +93,49 @@ and count_single_idx (acc : int) (idx : Expr.Typed.t Index.t) = @param index This list is checked for Single cell access either at the top level or within the [Index] types of the list. *) -let rec is_uni_eigen_loop_indexing in_loop (ut : UnsizedType.t) - (index : Expr.Typed.t Index.t list) = - match in_loop with - | false -> false - | true -> ( - let contains_single_idx = - List.fold_left ~init:0 ~f:count_single_idx index in - match (ut, index) with - | (UnsizedType.UVector | URowVector), _ when contains_single_idx > 0 -> - true - | UMatrix, _ when contains_single_idx > 1 -> true - | (UArray t | UFun (_, ReturnType t, _, _)), index -> ( - match List.tl index with - | Some cut_list -> is_uni_eigen_loop_indexing in_loop t cut_list - | None -> false ) - | _ -> false ) + let rec is_uni_eigen_loop_indexing in_loop (ut : UnsizedType.t) + (index : Expr.Typed.t Index.t list) = + match in_loop with + | false -> false + | true -> ( + let contains_single_idx = + List.fold_left ~init:0 ~f:count_single_idx index in + match (ut, index) with + | (UnsizedType.UVector | URowVector), _ when contains_single_idx > 0 -> + true + | UMatrix, _ when contains_single_idx > 1 -> true + | (UArray t | UFun (_, ReturnType t, _, _)), index -> ( + match List.tl index with + | Some cut_list -> is_uni_eigen_loop_indexing in_loop t cut_list + | None -> false ) + | _ -> false ) -let query_stan_math_mem_pattern_support (name : string) - (args : (UnsizedType.autodifftype * UnsizedType.t) list) = - let open Stan_math_signatures in - match name with - | x when is_reduce_sum_fn x -> false - | x when is_variadic_ode_fn x -> false - | x when is_variadic_dae_fn x -> false - | _ -> - let name = - string_operator_to_stan_math_fns (Utils.stdlib_distribution_name name) - in - let namematches = Hashtbl.find_multi stan_math_signatures name in - let filteredmatches = - List.filter - ~f:(fun x -> - Frontend.SignatureMismatch.check_compatible_arguments_mod_conv - (snd3 x) args - |> Result.is_ok ) - namematches in - let is_soa ((_ : UnsizedType.returntype), _, mem) = - mem = Common.Helpers.SoA in - List.exists ~f:is_soa filteredmatches + let query_stan_math_mem_pattern_support (name : string) + (args : (UnsizedType.autodifftype * UnsizedType.t) list) = + match name with + | x when StdLib.is_variadic_function_name x -> false + | _ -> + let name = + StdLib.string_operator_to_function_name + (Utils.stdlib_distribution_name name) in + let namematches = StdLib.get_signatures name in + let filteredmatches = + List.filter + ~f:(fun x -> + Frontend.SignatureMismatch.check_compatible_arguments_mod_conv + (snd3 x) args + |> Result.is_ok ) + namematches in + let is_soa ((_ : UnsizedType.returntype), _, mem) = + mem = Common.Helpers.SoA in + List.exists ~f:is_soa filteredmatches -(*Validate whether a function can support SoA matrices*) -let is_fun_soa_supported name exprs = - let fun_args = List.map ~f:Expr.Typed.fun_arg exprs in - query_stan_math_mem_pattern_support name fun_args + (*Validate whether a function can support SoA matrices*) + let is_fun_soa_supported name exprs = + let fun_args = List.map ~f:Expr.Typed.fun_arg exprs in + query_stan_math_mem_pattern_support name fun_args -(** + (** Query to find the initial set of objects that cannot be SoA. This is mostly recursing over expressions, with the exceptions being functions and indexing expressions. For the logic on functions @@ -147,41 +145,43 @@ let is_fun_soa_supported name exprs = will be returned if the matrix or vector is accessed by single cell indexing. *) -let rec query_initial_demotable_expr (in_loop : bool) ~(acc : string Set.Poly.t) - Expr.Fixed.{pattern; _} : string Set.Poly.t = - let query_expr (accum : string Set.Poly.t) = - query_initial_demotable_expr in_loop ~acc:accum in - match pattern with - | FunApp (kind, (exprs : Expr.Typed.t list)) -> - query_initial_demotable_funs in_loop acc kind exprs - | Indexed ((Expr.Fixed.{meta= {type_; _}; _} as expr), indexed) -> - let index_set = - Set.Poly.union_list - (List.map - ~f: - (Index.apply ~default:Set.Poly.empty ~merge:Set.Poly.union - (query_expr acc) ) - indexed ) in - let index_demotes = - if is_uni_eigen_loop_indexing in_loop type_ indexed then - Set.Poly.union (query_var_eigen_names expr) index_set - else Set.Poly.union (query_expr acc expr) index_set in - Set.Poly.union acc index_demotes - | Var (_ : string) | Lit ((_ : Expr.Fixed.Pattern.litType), (_ : string)) -> - acc - | Promotion (expr, _, _) -> query_expr acc expr - | TernaryIf (predicate, texpr, fexpr) -> - let predicate_demotes = query_expr acc predicate in - Set.Poly.union - (Set.Poly.union predicate_demotes (query_var_eigen_names texpr)) - (query_var_eigen_names fexpr) - | EAnd (lhs, rhs) | EOr (lhs, rhs) -> - (*We need to get the demotes from both sides*) - let full_lhs_rhs = - Set.Poly.union (query_expr acc lhs) (query_expr acc rhs) in - Set.Poly.union (query_expr full_lhs_rhs lhs) (query_expr full_lhs_rhs rhs) + let rec query_initial_demotable_expr (in_loop : bool) + ~(acc : string Set.Poly.t) Expr.Fixed.{pattern; _} : string Set.Poly.t = + let query_expr (accum : string Set.Poly.t) = + query_initial_demotable_expr in_loop ~acc:accum in + match pattern with + | FunApp (kind, (exprs : Expr.Typed.t list)) -> + query_initial_demotable_funs in_loop acc kind exprs + | Indexed ((Expr.Fixed.{meta= {type_; _}; _} as expr), indexed) -> + let index_set = + Set.Poly.union_list + (List.map + ~f: + (Index.apply ~default:Set.Poly.empty ~merge:Set.Poly.union + (query_expr acc) ) + indexed ) in + let index_demotes = + if is_uni_eigen_loop_indexing in_loop type_ indexed then + Set.Poly.union (query_var_eigen_names expr) index_set + else Set.Poly.union (query_expr acc expr) index_set in + Set.Poly.union acc index_demotes + | Var (_ : string) | Lit ((_ : Expr.Fixed.Pattern.litType), (_ : string)) -> + acc + | Promotion (expr, _, _) -> query_expr acc expr + | TernaryIf (predicate, texpr, fexpr) -> + let predicate_demotes = query_expr acc predicate in + Set.Poly.union + (Set.Poly.union predicate_demotes (query_var_eigen_names texpr)) + (query_var_eigen_names fexpr) + | EAnd (lhs, rhs) | EOr (lhs, rhs) -> + (*We need to get the demotes from both sides*) + let full_lhs_rhs = + Set.Poly.union (query_expr acc lhs) (query_expr acc rhs) in + Set.Poly.union + (query_expr full_lhs_rhs lhs) + (query_expr full_lhs_rhs rhs) -(** + (** Query a function to detect if it or any of its used expression's objects or expressions should be demoted to AoS. * @@ -198,134 +198,136 @@ let rec query_initial_demotable_expr (in_loop : bool) ~(acc : string Set.Poly.t) to the UDF. exprs The expression list passed to the functions. *) -and query_initial_demotable_funs (in_loop : bool) (acc : string Set.Poly.t) - (kind : 'a Fun_kind.t) (exprs : Expr.Typed.t list) : string Set.Poly.t = - let query_expr accum = query_initial_demotable_expr in_loop ~acc:accum in - let top_level_eigen_names = - Set.Poly.union_list (List.map ~f:query_var_eigen_names exprs) in - let demoted_eigen_names = List.fold ~init:acc ~f:query_expr exprs in - let demoted_and_top_level_names = - Set.Poly.union demoted_eigen_names top_level_eigen_names in - match kind with - | Fun_kind.StanLib (name, (_ : bool Fun_kind.suffix), _) -> ( - match name with - | "check_matching_dims" -> acc - | name -> ( - match is_fun_soa_supported name exprs with - | true -> Set.Poly.union acc demoted_eigen_names - | false -> Set.Poly.union acc demoted_and_top_level_names ) ) - | CompilerInternal (Internal_fun.FnMakeArray | FnMakeRowVec) -> - Set.Poly.union acc demoted_and_top_level_names - | CompilerInternal (_ : 'a Internal_fun.t) -> acc - | UserDefined ((_ : string), (_ : bool Fun_kind.suffix)) -> - Set.Poly.union acc demoted_and_top_level_names + and query_initial_demotable_funs (in_loop : bool) (acc : string Set.Poly.t) + (kind : 'a Fun_kind.t) (exprs : Expr.Typed.t list) : string Set.Poly.t = + let query_expr accum = query_initial_demotable_expr in_loop ~acc:accum in + let top_level_eigen_names = + Set.Poly.union_list (List.map ~f:query_var_eigen_names exprs) in + let demoted_eigen_names = List.fold ~init:acc ~f:query_expr exprs in + let demoted_and_top_level_names = + Set.Poly.union demoted_eigen_names top_level_eigen_names in + match kind with + | Fun_kind.StanLib (name, (_ : bool Fun_kind.suffix), _) -> ( + match name with + | "check_matching_dims" -> acc + | name -> ( + match is_fun_soa_supported name exprs with + | true -> Set.Poly.union acc demoted_eigen_names + | false -> Set.Poly.union acc demoted_and_top_level_names ) ) + | CompilerInternal (Internal_fun.FnMakeArray | FnMakeRowVec) -> + Set.Poly.union acc demoted_and_top_level_names + | CompilerInternal (_ : 'a Internal_fun.t) -> acc + | UserDefined ((_ : string), (_ : bool Fun_kind.suffix)) -> + Set.Poly.union acc demoted_and_top_level_names -(** + (** Check whether any functions in the right hand side expression of an assignment support SoA. If so then return true, otherwise return false. *) -let rec is_any_soa_supported_expr - Expr.Fixed.{pattern; meta= Expr.Typed.Meta.{adlevel; type_; _}} : bool = - if - UnsizedType.is_dataonlytype adlevel - || not (UnsizedType.contains_eigen_type type_) - then true - else - match pattern with - | FunApp (kind, (exprs : Expr.Typed.t list)) -> - is_any_soa_supported_fun_expr kind exprs - | Indexed (expr, (_ : Expr.Typed.t Index.t list)) | Promotion (expr, _, _) - -> - is_any_soa_supported_expr expr - | Var (_ : string) | Lit ((_ : Expr.Fixed.Pattern.litType), (_ : string)) -> - true - | TernaryIf (_, texpr, fexpr) -> - is_any_soa_supported_expr texpr && is_any_soa_supported_expr fexpr - | EAnd (lhs, rhs) | EOr (lhs, rhs) -> - is_any_soa_supported_expr lhs && is_any_soa_supported_expr rhs + let rec is_any_soa_supported_expr + Expr.Fixed.{pattern; meta= Expr.Typed.Meta.{adlevel; type_; _}} : bool = + if + UnsizedType.is_dataonlytype adlevel + || not (UnsizedType.contains_eigen_type type_) + then true + else + match pattern with + | FunApp (kind, (exprs : Expr.Typed.t list)) -> + is_any_soa_supported_fun_expr kind exprs + | Indexed (expr, (_ : Expr.Typed.t Index.t list)) | Promotion (expr, _, _) + -> + is_any_soa_supported_expr expr + | Var (_ : string) | Lit ((_ : Expr.Fixed.Pattern.litType), (_ : string)) + -> + true + | TernaryIf (_, texpr, fexpr) -> + is_any_soa_supported_expr texpr && is_any_soa_supported_expr fexpr + | EAnd (lhs, rhs) | EOr (lhs, rhs) -> + is_any_soa_supported_expr lhs && is_any_soa_supported_expr rhs -(** + (** Return false if the [Fun_kind.t] does not support [SoA] *) -and is_any_soa_supported_fun_expr (kind : 'a Fun_kind.t) - (exprs : Expr.Typed.t list) : bool = - match kind with - | CompilerInternal (Internal_fun.FnMakeArray | FnMakeRowVec) -> false - | UserDefined ((_ : string), (_ : bool Fun_kind.suffix)) -> false - | CompilerInternal (_ : 'a Internal_fun.t) -> true - | Fun_kind.StanLib (name, (_ : bool Fun_kind.suffix), _) -> ( - match name with - | "check_matching_dims" -> true - | _ -> - is_fun_soa_supported name exprs - && List.exists ~f:is_any_soa_supported_expr exprs ) + and is_any_soa_supported_fun_expr (kind : 'a Fun_kind.t) + (exprs : Expr.Typed.t list) : bool = + match kind with + | CompilerInternal (Internal_fun.FnMakeArray | FnMakeRowVec) -> false + | UserDefined ((_ : string), (_ : bool Fun_kind.suffix)) -> false + | CompilerInternal (_ : 'a Internal_fun.t) -> true + | Fun_kind.StanLib (name, (_ : bool Fun_kind.suffix), _) -> ( + match name with + | "check_matching_dims" -> true + | _ -> + is_fun_soa_supported name exprs + && List.exists ~f:is_any_soa_supported_expr exprs ) -(** + (** Return true if the rhs expression of an assignment contains only combinations of AutoDiffable Reals and Data Matrices *) -let rec is_any_ad_real_data_matrix_expr - Expr.Fixed.{pattern; meta= Expr.Typed.Meta.{adlevel; _}} : bool = - if UnsizedType.is_dataonlytype adlevel then false - else - match pattern with - | FunApp (kind, (exprs : Expr.Typed.t list)) -> - is_any_ad_real_data_matrix_expr_fun kind exprs - | Indexed (expr, _) | Promotion (expr, _, _) -> - is_any_ad_real_data_matrix_expr expr - | Var (_ : string) | Lit ((_ : Expr.Fixed.Pattern.litType), (_ : string)) -> - false - | TernaryIf (_, texpr, fexpr) -> - is_any_ad_real_data_matrix_expr texpr - || is_any_ad_real_data_matrix_expr fexpr - | EAnd (lhs, rhs) | EOr (lhs, rhs) -> - is_any_ad_real_data_matrix_expr lhs - && is_any_ad_real_data_matrix_expr rhs + let rec is_any_ad_real_data_matrix_expr + Expr.Fixed.{pattern; meta= Expr.Typed.Meta.{adlevel; _}} : bool = + if UnsizedType.is_dataonlytype adlevel then false + else + match pattern with + | FunApp (kind, (exprs : Expr.Typed.t list)) -> + is_any_ad_real_data_matrix_expr_fun kind exprs + | Indexed (expr, _) | Promotion (expr, _, _) -> + is_any_ad_real_data_matrix_expr expr + | Var (_ : string) | Lit ((_ : Expr.Fixed.Pattern.litType), (_ : string)) + -> + false + | TernaryIf (_, texpr, fexpr) -> + is_any_ad_real_data_matrix_expr texpr + || is_any_ad_real_data_matrix_expr fexpr + | EAnd (lhs, rhs) | EOr (lhs, rhs) -> + is_any_ad_real_data_matrix_expr lhs + && is_any_ad_real_data_matrix_expr rhs -(** + (** Return true if the expressions in a function call are all combinations of AutoDiffable Reals and Data Matrices *) -and is_any_ad_real_data_matrix_expr_fun (kind : 'a Fun_kind.t) - (exprs : Expr.Typed.t list) : bool = - match kind with - | Fun_kind.StanLib (name, (_ : bool Fun_kind.suffix), _) -> ( - match name with - | "check_matching_dims" -> false - | _ -> ( - let fun_args = List.map ~f:Expr.Typed.fun_arg exprs in - (*Right now we can't handle AD real and data matrix funcs - that return a matrix :-/*) - let is_args_autodiff_real_data_matrix = - (*If there are any autodiffable vars*) - List.exists - ~f:(fun (x, y) -> - match (x, y) with - | UnsizedType.AutoDiffable, UnsizedType.UReal -> true - | _ -> false ) - fun_args - (*And there are any data matrices*) - && List.exists - ~f:(fun (x, y) -> - match (x, UnsizedType.is_container y) with - | UnsizedType.DataOnly, true -> true - | _ -> false ) - fun_args - (*And there are no Autodiffable matrices*) - && List.exists - ~f:(fun (x, y) -> - match (x, UnsizedType.contains_eigen_type y) with - | UnsizedType.AutoDiffable, true -> false - | _ -> true ) - fun_args in - match is_args_autodiff_real_data_matrix with - | true -> true - | false -> List.exists ~f:is_any_ad_real_data_matrix_expr exprs ) ) - | CompilerInternal (Internal_fun.FnMakeArray | FnMakeRowVec) -> true - | CompilerInternal (_ : 'a Internal_fun.t) -> false - | UserDefined ((_ : string), (_ : bool Fun_kind.suffix)) -> false + and is_any_ad_real_data_matrix_expr_fun (kind : 'a Fun_kind.t) + (exprs : Expr.Typed.t list) : bool = + match kind with + | Fun_kind.StanLib (name, (_ : bool Fun_kind.suffix), _) -> ( + match name with + | "check_matching_dims" -> false + | _ -> ( + let fun_args = List.map ~f:Expr.Typed.fun_arg exprs in + (*Right now we can't handle AD real and data matrix funcs + that return a matrix :-/*) + let is_args_autodiff_real_data_matrix = + (*If there are any autodiffable vars*) + List.exists + ~f:(fun (x, y) -> + match (x, y) with + | UnsizedType.AutoDiffable, UnsizedType.UReal -> true + | _ -> false ) + fun_args + (*And there are any data matrices*) + && List.exists + ~f:(fun (x, y) -> + match (x, UnsizedType.is_container y) with + | UnsizedType.DataOnly, true -> true + | _ -> false ) + fun_args + (*And there are no Autodiffable matrices*) + && List.exists + ~f:(fun (x, y) -> + match (x, UnsizedType.contains_eigen_type y) with + | UnsizedType.AutoDiffable, true -> false + | _ -> true ) + fun_args in + match is_args_autodiff_real_data_matrix with + | true -> true + | false -> List.exists ~f:is_any_ad_real_data_matrix_expr exprs ) ) + | CompilerInternal (Internal_fun.FnMakeArray | FnMakeRowVec) -> true + | CompilerInternal (_ : 'a Internal_fun.t) -> false + | UserDefined ((_ : string), (_ : bool Fun_kind.suffix)) -> false -(** + (** Query to find the initial set of objects in statements that cannot be SoA. This is mostly recursive over expressions and statements, with the exception of functions and Assignments. @@ -349,96 +351,99 @@ and is_any_ad_real_data_matrix_expr_fun (kind : 'a Fun_kind.t) @param in_loop A boolean to specify the logic of indexing expressions. See [query_initial_demotable_expr] for an explanation of the logic. *) -let rec query_initial_demotable_stmt (in_loop : bool) (acc : string Set.Poly.t) - (Stmt.Fixed.{pattern; _} : Stmt.Located.t) : string Set.Poly.t = - let query_expr (accum : string Set.Poly.t) = - query_initial_demotable_expr in_loop ~acc:accum in - match pattern with - | Stmt.Fixed.Pattern.Assignment - ( ((name : string), (ut : UnsizedType.t), idx) - , (Expr.Fixed.{meta= Expr.Typed.Meta.{type_; adlevel; _}; _} as rhs) ) -> - let idx_list = - List.fold ~init:acc - ~f:(fun accum x -> - Index.folder accum - (fun acc -> query_initial_demotable_expr in_loop ~acc) - x ) - idx in - let idx_demotable = - (* RHS (2)*) - match is_uni_eigen_loop_indexing in_loop ut idx with - | true -> Set.Poly.add idx_list name - | false -> idx_list in - let rhs_demotable_names = query_expr acc rhs in - (* RHS (3)*) - let check_if_rhs_ad_real_data_matrix_expr = - match (UnsizedType.contains_eigen_type type_, adlevel) with - | true, UnsizedType.AutoDiffable -> - is_any_ad_real_data_matrix_expr rhs - || not (is_any_soa_supported_expr rhs) - | _ -> false in - (* RHS (1)*) - let is_all_rhs_aos = - let all_rhs_eigen_names = query_var_eigen_names rhs in - is_nonzero_subset ~subset:all_rhs_eigen_names ~set:rhs_demotable_names - in - let is_not_supported_func = - match rhs.pattern with - | FunApp (CompilerInternal _, _) -> false - | FunApp (UserDefined _, _) -> true - | _ -> false in - let is_eigen_stmt = UnsizedType.contains_eigen_type rhs.meta.type_ in - let assign_demotes = - if - is_eigen_stmt - && ( is_all_rhs_aos || check_if_rhs_ad_real_data_matrix_expr - || is_not_supported_func ) - then - let base_set = Set.Poly.union idx_demotable rhs_demotable_names in - Set.Poly.add - (Set.Poly.union base_set (query_var_eigen_names rhs)) - name - else Set.Poly.union idx_demotable rhs_demotable_names in - Set.Poly.union acc assign_demotes - | NRFunApp (kind, exprs) -> - query_initial_demotable_funs in_loop acc kind exprs - | IfElse (predicate, true_stmt, op_false_stmt) -> - let predicate_acc = query_expr acc predicate in - Set.Poly.union acc - (Set.Poly.union_list - [ predicate_acc - ; query_initial_demotable_stmt in_loop predicate_acc true_stmt - ; Option.value_map - ~f:(query_initial_demotable_stmt in_loop predicate_acc) - ~default:Set.Poly.empty op_false_stmt ] ) - | Return optional_expr -> - Option.value_map ~f:(query_expr acc) ~default:Set.Poly.empty optional_expr - | SList lst | Profile (_, lst) | Block lst -> - Set.Poly.union_list - (List.map ~f:(query_initial_demotable_stmt in_loop acc) lst) - | TargetPE expr -> query_expr acc expr - (*NOTE: loops generated by inlining are not actually loops*) - | For - { lower= Expr.Fixed.{pattern= Lit (Int, lb); _} - ; upper= Expr.Fixed.{pattern= Lit (Int, ub); _} - ; body - ; _ } - when lb = "1" && ub = "1" -> - query_initial_demotable_stmt false acc body - | For {lower; upper; body; _} -> - Set.Poly.union - (Set.Poly.union (query_expr acc lower) (query_expr acc upper)) - (query_initial_demotable_stmt true acc body) - | While (predicate, body) -> - Set.Poly.union_list - [ acc; query_expr acc predicate - ; query_initial_demotable_stmt true acc body ] - | Decl {decl_type= Type.Sized st; decl_id; _} - when SizedType.is_complex_type st -> - Set.Poly.add acc decl_id - | Skip | Break | Continue | Decl _ -> acc + let rec query_initial_demotable_stmt (in_loop : bool) + (acc : string Set.Poly.t) (Stmt.Fixed.{pattern; _} : Stmt.Located.t) : + string Set.Poly.t = + let query_expr (accum : string Set.Poly.t) = + query_initial_demotable_expr in_loop ~acc:accum in + match pattern with + | Stmt.Fixed.Pattern.Assignment + ( ((name : string), (ut : UnsizedType.t), idx) + , (Expr.Fixed.{meta= Expr.Typed.Meta.{type_; adlevel; _}; _} as rhs) ) + -> + let idx_list = + List.fold ~init:acc + ~f:(fun accum x -> + Index.folder accum + (fun acc -> query_initial_demotable_expr in_loop ~acc) + x ) + idx in + let idx_demotable = + (* RHS (2)*) + match is_uni_eigen_loop_indexing in_loop ut idx with + | true -> Set.Poly.add idx_list name + | false -> idx_list in + let rhs_demotable_names = query_expr acc rhs in + (* RHS (3)*) + let check_if_rhs_ad_real_data_matrix_expr = + match (UnsizedType.contains_eigen_type type_, adlevel) with + | true, UnsizedType.AutoDiffable -> + is_any_ad_real_data_matrix_expr rhs + || not (is_any_soa_supported_expr rhs) + | _ -> false in + (* RHS (1)*) + let is_all_rhs_aos = + let all_rhs_eigen_names = query_var_eigen_names rhs in + is_nonzero_subset ~subset:all_rhs_eigen_names ~set:rhs_demotable_names + in + let is_not_supported_func = + match rhs.pattern with + | FunApp (CompilerInternal _, _) -> false + | FunApp (UserDefined _, _) -> true + | _ -> false in + let is_eigen_stmt = UnsizedType.contains_eigen_type rhs.meta.type_ in + let assign_demotes = + if + is_eigen_stmt + && ( is_all_rhs_aos || check_if_rhs_ad_real_data_matrix_expr + || is_not_supported_func ) + then + let base_set = Set.Poly.union idx_demotable rhs_demotable_names in + Set.Poly.add + (Set.Poly.union base_set (query_var_eigen_names rhs)) + name + else Set.Poly.union idx_demotable rhs_demotable_names in + Set.Poly.union acc assign_demotes + | NRFunApp (kind, exprs) -> + query_initial_demotable_funs in_loop acc kind exprs + | IfElse (predicate, true_stmt, op_false_stmt) -> + let predicate_acc = query_expr acc predicate in + Set.Poly.union acc + (Set.Poly.union_list + [ predicate_acc + ; query_initial_demotable_stmt in_loop predicate_acc true_stmt + ; Option.value_map + ~f:(query_initial_demotable_stmt in_loop predicate_acc) + ~default:Set.Poly.empty op_false_stmt ] ) + | Return optional_expr -> + Option.value_map ~f:(query_expr acc) ~default:Set.Poly.empty + optional_expr + | SList lst | Profile (_, lst) | Block lst -> + Set.Poly.union_list + (List.map ~f:(query_initial_demotable_stmt in_loop acc) lst) + | TargetPE expr -> query_expr acc expr + (*NOTE: loops generated by inlining are not actually loops*) + | For + { lower= Expr.Fixed.{pattern= Lit (Int, lb); _} + ; upper= Expr.Fixed.{pattern= Lit (Int, ub); _} + ; body + ; _ } + when lb = "1" && ub = "1" -> + query_initial_demotable_stmt false acc body + | For {lower; upper; body; _} -> + Set.Poly.union + (Set.Poly.union (query_expr acc lower) (query_expr acc upper)) + (query_initial_demotable_stmt true acc body) + | While (predicate, body) -> + Set.Poly.union_list + [ acc; query_expr acc predicate + ; query_initial_demotable_stmt true acc body ] + | Decl {decl_type= Type.Sized st; decl_id; _} + when SizedType.is_complex_type st -> + Set.Poly.add acc decl_id + | Skip | Break | Continue | Decl _ -> acc -(** Look through a statement to see whether the objects used in it need to be + (** Look through a statement to see whether the objects used in it need to be modified from SoA to AoS. Returns the set of object names that need demoted in a statement, if any. This function looks at Assignment statements, and returns back the @@ -450,25 +455,27 @@ let rec query_initial_demotable_stmt (in_loop : bool) (acc : string Set.Poly.t) @param aos_exits A set of variables that can be demoted. @param pattern The Stmt pattern to query. *) -let query_demotable_stmt (aos_exits : string Set.Poly.t) - (pattern : (Expr.Typed.t, int) Stmt.Fixed.Pattern.t) : string Set.Poly.t = - match pattern with - | Stmt.Fixed.Pattern.Assignment - ( ( (assign_name : string) - , (_ : UnsizedType.t) - , (_ : Expr.Typed.t Index.t list) ) - , (rhs : Expr.Typed.t) ) -> ( - let all_rhs_eigen_names = query_var_eigen_names rhs in - if Set.Poly.mem aos_exits assign_name then - Set.Poly.add all_rhs_eigen_names assign_name - else - match is_nonzero_subset ~set:aos_exits ~subset:all_rhs_eigen_names with - | true -> Set.Poly.add all_rhs_eigen_names assign_name - | false -> Set.Poly.empty ) - (* All other statements do not need logic here*) - | _ -> Set.Poly.empty + let query_demotable_stmt (aos_exits : string Set.Poly.t) + (pattern : (Expr.Typed.t, int) Stmt.Fixed.Pattern.t) : string Set.Poly.t = + match pattern with + | Stmt.Fixed.Pattern.Assignment + ( ( (assign_name : string) + , (_ : UnsizedType.t) + , (_ : Expr.Typed.t Index.t list) ) + , (rhs : Expr.Typed.t) ) -> ( + let all_rhs_eigen_names = query_var_eigen_names rhs in + if Set.Poly.mem aos_exits assign_name then + Set.Poly.add all_rhs_eigen_names assign_name + else + match + is_nonzero_subset ~set:aos_exits ~subset:all_rhs_eigen_names + with + | true -> Set.Poly.add all_rhs_eigen_names assign_name + | false -> Set.Poly.empty ) + (* All other statements do not need logic here*) + | _ -> Set.Poly.empty -(** + (** Modify a function and it's subexpressions from SoA <-> AoS and vice versa. This performs demotion for sub expressions recursively. The top level expression and it's sub expressions are demoted to SoA if @@ -483,31 +490,36 @@ let query_demotable_stmt (aos_exits : string Set.Poly.t) @param kind A [Fun_kind.t] @param exprs A list of expressions going into the function. **) -let rec modify_kind ?force_demotion:(force = false) - (modifiable_set : string Set.Poly.t) (kind : 'a Fun_kind.t) - (exprs : Expr.Typed.t list) = - let expr_names = - Set.Poly.union_list (List.map ~f:query_var_eigen_names exprs) in - let is_all_in_list = - is_nonzero_subset ~set:modifiable_set ~subset:expr_names in - match kind with - | Fun_kind.StanLib (name, sfx, (_ : Common.Helpers.mem_pattern)) -> - if is_all_in_list || (not (is_fun_soa_supported name exprs)) || force then - (*Force demotion of all subexprs*) - let exprs' = - List.map ~f:(modify_expr ~force_demotion:true expr_names) exprs in - (Fun_kind.StanLib (name, sfx, Common.Helpers.AoS), exprs') - else - ( Fun_kind.StanLib (name, sfx, SoA) + let rec modify_kind ?force_demotion:(force = false) + (modifiable_set : string Set.Poly.t) (kind : 'a Fun_kind.t) + (exprs : Expr.Typed.t list) = + let expr_names = + Set.Poly.union_list (List.map ~f:query_var_eigen_names exprs) in + let is_all_in_list = + is_nonzero_subset ~set:modifiable_set ~subset:expr_names in + match kind with + | Fun_kind.StanLib (name, sfx, (_ : Common.Helpers.mem_pattern)) -> + if is_all_in_list || (not (is_fun_soa_supported name exprs)) || force + then + (*Force demotion of all subexprs*) + let exprs' = + List.map ~f:(modify_expr ~force_demotion:true expr_names) exprs + in + (Fun_kind.StanLib (name, sfx, Common.Helpers.AoS), exprs') + else + ( Fun_kind.StanLib (name, sfx, SoA) + , List.map ~f:(modify_expr ~force_demotion:force modifiable_set) exprs + ) + | UserDefined _ as udf -> + ( udf + , List.map ~f:(modify_expr ~force_demotion:force modifiable_set) exprs + ) + | (_ : 'a Fun_kind.t) -> + ( kind , List.map ~f:(modify_expr ~force_demotion:force modifiable_set) exprs ) - | UserDefined _ as udf -> - (udf, List.map ~f:(modify_expr ~force_demotion:force modifiable_set) exprs) - | (_ : 'a Fun_kind.t) -> - ( kind - , List.map ~f:(modify_expr ~force_demotion:force modifiable_set) exprs ) -(** + (** Modify an expression and it's subexpressions from SoA <-> AoS and vice versa. The only real paths in the below is on the functions and ternary expressions. @@ -521,42 +533,42 @@ let rec modify_kind ?force_demotion:(force = false) associated expressions we want to modify. @param pattern The expression to modify. *) -and modify_expr_pattern ?force_demotion:(force = false) - (modifiable_set : string Set.Poly.t) - (pattern : Expr.Typed.t Expr.Fixed.Pattern.t) = - let mod_expr ?force_demotion:(forced = false) = - modify_expr ~force_demotion:forced modifiable_set in - match pattern with - | Expr.Fixed.Pattern.FunApp (kind, (exprs : Expr.Typed.t list)) -> - let kind', expr' = - modify_kind ~force_demotion:force modifiable_set kind exprs in - Expr.Fixed.Pattern.FunApp (kind', expr') - | TernaryIf (predicate, texpr, fexpr) -> - let is_eigen_return = - UnsizedType.contains_eigen_type fexpr.meta.type_ - || UnsizedType.contains_eigen_type texpr.meta.type_ in - if is_eigen_return then - TernaryIf - ( mod_expr ~force_demotion:force predicate - , mod_expr ~force_demotion:true texpr - , mod_expr ~force_demotion:true fexpr ) - else - TernaryIf - ( mod_expr ~force_demotion:force predicate - , mod_expr ~force_demotion:force texpr - , mod_expr ~force_demotion:force fexpr ) - | Indexed (idx_expr, indexed) -> - Indexed - ( mod_expr idx_expr - , List.map ~f:(Index.map (mod_expr ~force_demotion:force)) indexed ) - | EAnd (lhs, rhs) -> EAnd (mod_expr lhs, mod_expr rhs) - | EOr (lhs, rhs) -> EOr (mod_expr lhs, mod_expr rhs) - | Promotion (expr, type_, ad_level) -> - Promotion (mod_expr expr, type_, ad_level) - | Var (_ : string) | Lit ((_ : Expr.Fixed.Pattern.litType), (_ : string)) -> - pattern + and modify_expr_pattern ?force_demotion:(force = false) + (modifiable_set : string Set.Poly.t) + (pattern : Expr.Typed.t Expr.Fixed.Pattern.t) = + let mod_expr ?force_demotion:(forced = false) = + modify_expr ~force_demotion:forced modifiable_set in + match pattern with + | Expr.Fixed.Pattern.FunApp (kind, (exprs : Expr.Typed.t list)) -> + let kind', expr' = + modify_kind ~force_demotion:force modifiable_set kind exprs in + Expr.Fixed.Pattern.FunApp (kind', expr') + | TernaryIf (predicate, texpr, fexpr) -> + let is_eigen_return = + UnsizedType.contains_eigen_type fexpr.meta.type_ + || UnsizedType.contains_eigen_type texpr.meta.type_ in + if is_eigen_return then + TernaryIf + ( mod_expr ~force_demotion:force predicate + , mod_expr ~force_demotion:true texpr + , mod_expr ~force_demotion:true fexpr ) + else + TernaryIf + ( mod_expr ~force_demotion:force predicate + , mod_expr ~force_demotion:force texpr + , mod_expr ~force_demotion:force fexpr ) + | Indexed (idx_expr, indexed) -> + Indexed + ( mod_expr idx_expr + , List.map ~f:(Index.map (mod_expr ~force_demotion:force)) indexed ) + | EAnd (lhs, rhs) -> EAnd (mod_expr lhs, mod_expr rhs) + | EOr (lhs, rhs) -> EOr (mod_expr lhs, mod_expr rhs) + | Promotion (expr, type_, ad_level) -> + Promotion (mod_expr expr, type_, ad_level) + | Var (_ : string) | Lit ((_ : Expr.Fixed.Pattern.litType), (_ : string)) -> + pattern -(** + (** Given a Set of strings containing the names of objects that can be modified from AoS <-> SoA and vice versa, modify them within the expression. @param mem_pattern The memory pattern to change expressions to. @@ -564,12 +576,13 @@ and modify_expr_pattern ?force_demotion:(force = false) associated expressions we want to modify. @param expr the expression to modify. *) -and modify_expr ?force_demotion:(force = false) - (modifiable_set : string Set.Poly.t) (Expr.Fixed.{pattern; _} as expr) = - { expr with - pattern= modify_expr_pattern ~force_demotion:force modifiable_set pattern } + and modify_expr ?force_demotion:(force = false) + (modifiable_set : string Set.Poly.t) (Expr.Fixed.{pattern; _} as expr) = + { expr with + pattern= modify_expr_pattern ~force_demotion:force modifiable_set pattern + } -(** + (** Modify statement patterns in the MIR from AoS <-> SoA and vice versa For [Decl] and [Assignment]'s reading in parameters, we demote to AoS if the [decl_id] (or assign name) is in the modifiable set and @@ -581,81 +594,82 @@ and modify_expr ?force_demotion:(force = false) @param pattern The statement pattern to modify @param modifiable_set The name of the variable we are searching for. *) -let rec modify_stmt_pattern - (pattern : (Expr.Typed.t, Stmt.Located.t) Stmt.Fixed.Pattern.t) - (modifiable_set : string Core_kernel.Set.Poly.t) = - let mod_expr force = modify_expr ~force_demotion:force modifiable_set in - let mod_stmt stmt = modify_stmt stmt modifiable_set in - match pattern with - | Stmt.Fixed.Pattern.Decl - ({decl_id; decl_type= Type.Sized sized_type; _} as decl) -> - if Set.Poly.mem modifiable_set decl_id then - Stmt.Fixed.Pattern.Decl - { decl with - decl_type= - Type.Sized (SizedType.modify_sizedtype_mem AoS sized_type) } - else - Decl - { decl with - decl_type= - Type.Sized (SizedType.modify_sizedtype_mem SoA sized_type) } - | NRFunApp (kind, (exprs : Expr.Typed.t list)) -> - let kind', exprs' = modify_kind modifiable_set kind exprs in - NRFunApp (kind', exprs') - | Assignment - ( (name, ut, lhs) - , ( {pattern= FunApp (CompilerInternal (FnReadParam read_param), args); _} - as assigner ) ) -> - if Set.Poly.mem modifiable_set name then - Assignment - ( (name, ut, List.map ~f:(Index.map (mod_expr false)) lhs) - , { assigner with - pattern= - FunApp - ( CompilerInternal - (FnReadParam {read_param with mem_pattern= AoS}) - , List.map ~f:(mod_expr true) args ) } ) - else - Assignment - ( (name, ut, List.map ~f:(Index.map (mod_expr false)) lhs) - , { assigner with - pattern= - FunApp - ( CompilerInternal - (FnReadParam {read_param with mem_pattern= SoA}) - , List.map ~f:(mod_expr false) args ) } ) - | Assignment (((name : string), (ut : UnsizedType.t), idx), rhs) -> - if Set.Poly.mem modifiable_set name then - (*If assignee is in bad set, force demotion of rhs functions*) - Assignment - ( (name, ut, List.map ~f:(Index.map (mod_expr false)) idx) - , mod_expr true rhs ) - else - Assignment - ( (name, ut, List.map ~f:(Index.map (mod_expr false)) idx) - , (mod_expr false) rhs ) - | IfElse (predicate, true_stmt, op_false_stmt) -> - IfElse - ( (mod_expr false) predicate - , mod_stmt true_stmt - , Option.map ~f:mod_stmt op_false_stmt ) - | Block stmts -> Block (List.map ~f:mod_stmt stmts) - | SList stmts -> SList (List.map ~f:mod_stmt stmts) - | For ({lower; upper; body; _} as loop) -> - Stmt.Fixed.Pattern.For - { loop with - lower= mod_expr false lower - ; upper= mod_expr false upper - ; body= mod_stmt body } - | TargetPE expr -> TargetPE ((mod_expr false) expr) - | Return optional_expr -> - Return (Option.map ~f:(mod_expr false) optional_expr) - | Profile ((p_name : string), stmt) -> - Profile (p_name, List.map ~f:mod_stmt stmt) - | While (predicate, body) -> While ((mod_expr false) predicate, mod_stmt body) - | Skip | Break | Continue | Decl _ -> pattern + let rec modify_stmt_pattern + (pattern : (Expr.Typed.t, Stmt.Located.t) Stmt.Fixed.Pattern.t) + (modifiable_set : string Core_kernel.Set.Poly.t) = + let mod_expr force = modify_expr ~force_demotion:force modifiable_set in + let mod_stmt stmt = modify_stmt stmt modifiable_set in + match pattern with + | Stmt.Fixed.Pattern.Decl + ({decl_id; decl_type= Type.Sized sized_type; _} as decl) -> + if Set.Poly.mem modifiable_set decl_id then + Stmt.Fixed.Pattern.Decl + { decl with + decl_type= + Type.Sized (SizedType.modify_sizedtype_mem AoS sized_type) } + else + Decl + { decl with + decl_type= + Type.Sized (SizedType.modify_sizedtype_mem SoA sized_type) } + | NRFunApp (kind, (exprs : Expr.Typed.t list)) -> + let kind', exprs' = modify_kind modifiable_set kind exprs in + NRFunApp (kind', exprs') + | Assignment + ( (name, ut, lhs) + , ( { pattern= FunApp (CompilerInternal (FnReadParam read_param), args) + ; _ } as assigner ) ) -> + if Set.Poly.mem modifiable_set name then + Assignment + ( (name, ut, List.map ~f:(Index.map (mod_expr false)) lhs) + , { assigner with + pattern= + FunApp + ( CompilerInternal + (FnReadParam {read_param with mem_pattern= AoS}) + , List.map ~f:(mod_expr true) args ) } ) + else + Assignment + ( (name, ut, List.map ~f:(Index.map (mod_expr false)) lhs) + , { assigner with + pattern= + FunApp + ( CompilerInternal + (FnReadParam {read_param with mem_pattern= SoA}) + , List.map ~f:(mod_expr false) args ) } ) + | Assignment (((name : string), (ut : UnsizedType.t), idx), rhs) -> + if Set.Poly.mem modifiable_set name then + (*If assignee is in bad set, force demotion of rhs functions*) + Assignment + ( (name, ut, List.map ~f:(Index.map (mod_expr false)) idx) + , mod_expr true rhs ) + else + Assignment + ( (name, ut, List.map ~f:(Index.map (mod_expr false)) idx) + , (mod_expr false) rhs ) + | IfElse (predicate, true_stmt, op_false_stmt) -> + IfElse + ( (mod_expr false) predicate + , mod_stmt true_stmt + , Option.map ~f:mod_stmt op_false_stmt ) + | Block stmts -> Block (List.map ~f:mod_stmt stmts) + | SList stmts -> SList (List.map ~f:mod_stmt stmts) + | For ({lower; upper; body; _} as loop) -> + Stmt.Fixed.Pattern.For + { loop with + lower= mod_expr false lower + ; upper= mod_expr false upper + ; body= mod_stmt body } + | TargetPE expr -> TargetPE ((mod_expr false) expr) + | Return optional_expr -> + Return (Option.map ~f:(mod_expr false) optional_expr) + | Profile ((p_name : string), stmt) -> + Profile (p_name, List.map ~f:mod_stmt stmt) + | While (predicate, body) -> + While ((mod_expr false) predicate, mod_stmt body) + | Skip | Break | Continue | Decl _ -> pattern -(** + (** Modify statement patterns in the MIR from AoS <-> SoA and vice versa @param mem_pattern A mem_pattern to modify expressions to. For the given memory pattern, this modifies @@ -663,6 +677,7 @@ let rec modify_stmt_pattern @param stmt The statement to modify. @param modifiable_set The name of the variable we are searching for. *) -and modify_stmt (Stmt.Fixed.{pattern; _} as stmt) - (modifiable_set : string Set.Poly.t) = - {stmt with pattern= modify_stmt_pattern pattern modifiable_set} + and modify_stmt (Stmt.Fixed.{pattern; _} as stmt) + (modifiable_set : string Set.Poly.t) = + {stmt with pattern= modify_stmt_pattern pattern modifiable_set} +end diff --git a/src/analysis_and_optimization/Optimize.ml b/src/analysis_and_optimization/Optimize.ml index eaa793e617..aa05c06f17 100644 --- a/src/analysis_and_optimization/Optimize.ml +++ b/src/analysis_and_optimization/Optimize.ml @@ -1174,26 +1174,22 @@ let optimize_ad_levels (mir : Program.Typed.t) = * @param mir: The program's whole MIR. *) let optimize_soa (mir : Program.Typed.t) = + let module Mem = Mem_pattern.Make (Frontend.Std_library_utils.NullLibrary) + (*TODO*) in let gen_aos_variables (flowgraph_to_mir : (int, Stmt.Located.Non_recursive.t) Map.Poly.t) (l : int) (aos_variables : string Set.Poly.t) = let mir_node mir_idx = Map.find_exn flowgraph_to_mir mir_idx in match (mir_node l).pattern with - | stmt -> Mem_pattern.query_demotable_stmt aos_variables stmt in + | stmt -> Mem.query_demotable_stmt aos_variables stmt in let initial_variables = List.fold ~init:Set.Poly.empty - ~f:(Mem_pattern.query_initial_demotable_stmt false) + ~f:(Mem.query_initial_demotable_stmt false) mir.log_prob in - (* - let print_set s = - Set.Poly.iter ~f:print_endline s in - let () = print_set initial_variables in - *) let mod_exprs aos_exits mod_expr = - Mir_utils.map_rec_expr (Mem_pattern.modify_expr_pattern aos_exits) mod_expr - in + Mir_utils.map_rec_expr (Mem.modify_expr_pattern aos_exits) mod_expr in let modify_stmt_patt stmt_pattern variable_set = - Mem_pattern.modify_stmt_pattern stmt_pattern variable_set in + Mem.modify_stmt_pattern stmt_pattern variable_set in let transform stmt = optimize_minimal_variables ~gen_variables:gen_aos_variables ~update_expr:mod_exprs ~update_stmt:modify_stmt_patt ~initial_variables diff --git a/src/analysis_and_optimization/Partial_evaluator.ml b/src/analysis_and_optimization/Partial_evaluator.ml index 0b898eaa03..50929f28a4 100644 --- a/src/analysis_and_optimization/Partial_evaluator.ml +++ b/src/analysis_and_optimization/Partial_evaluator.ml @@ -100,18 +100,19 @@ let rec eval_expr ?(preserve_stability = false) (e : Expr.Typed.t) = | UserDefined _ | CompilerInternal _ -> FunApp (kind, l) | StanLib (f, suffix, mem_type) -> let get_fun_or_op_rt_opt name l' = + let module TC = Frontend.Typechecker.Make ((* TODO *) + Frontend.Std_library_utils.NullLibrary) in let argument_types = List.map ~f:(fun x -> Expr.Typed.(adlevel_of x, type_of x)) l' in Operator.of_string_opt name |> Option.value_map ~f:(fun op -> - Frontend.Typechecker.operator_stan_math_return_type op - argument_types + TC.operator_return_type op argument_types |> Option.map ~f:fst ) ~default: - (Frontend.Typechecker.stan_math_return_type name - argument_types ) in + (TC.library_function_return_type name argument_types) + in let try_partially_evaluate_stanlib e = Expr.Fixed.Pattern.( match e with diff --git a/src/frontend/Canonicalize.ml b/src/frontend/Canonicalize.ml index d4bec27795..c26b900015 100644 --- a/src/frontend/Canonicalize.ml +++ b/src/frontend/Canonicalize.ml @@ -1,6 +1,5 @@ open Core_kernel open Ast -open Deprecation_analysis type canonicalizer_settings = {deprecations: bool; parentheses: bool; braces: bool; inline_includes: bool} @@ -14,232 +13,249 @@ let none = ; inline_includes= false ; braces= false } -let rec repair_syntax_stmt user_dists {stmt; smeta} = - match stmt with - | Tilde {arg; distribution= {name; id_loc}; args; truncation} -> - { stmt= - Tilde - { arg - ; distribution= {name= without_suffix user_dists name; id_loc} - ; args - ; truncation } - ; smeta } - | _ -> - { stmt= - map_statement ident (repair_syntax_stmt user_dists) ident ident stmt - ; smeta } - -let rec replace_deprecated_expr - (deprecated_userdefined : Middle.UnsizedType.t Core_kernel.String.Map.t) - {expr; emeta} = - let expr = - match expr with - | GetLP -> GetTarget - | FunApp (StanLib FnPlain, {name= "abs"; id_loc}, [e]) - when Middle.UnsizedType.is_real_type e.emeta.type_ -> - FunApp - ( StanLib FnPlain - , {name= "fabs"; id_loc} - , [replace_deprecated_expr deprecated_userdefined e] ) - | FunApp (StanLib FnPlain, {name= "if_else"; _}, [c; t; e]) -> - Paren - (replace_deprecated_expr deprecated_userdefined - {expr= TernaryIf ({expr= Paren c; emeta= c.emeta}, t, e); emeta} ) - | FunApp (StanLib suffix, {name; id_loc}, e) -> - if is_deprecated_distribution name then - CondDistApp - ( StanLib suffix - , {name= rename_deprecated deprecated_distributions name; id_loc} - , List.map ~f:(replace_deprecated_expr deprecated_userdefined) e ) - else if String.is_suffix name ~suffix:"_cdf" then - CondDistApp - ( StanLib suffix - , {name; id_loc} - , List.map ~f:(replace_deprecated_expr deprecated_userdefined) e ) - else +module type Canonicalizer = sig + val repair_syntax : + untyped_program -> canonicalizer_settings -> untyped_program + + val canonicalize_program : + typed_program -> canonicalizer_settings -> typed_program +end + +module Make (Deprecation : Deprecation_analysis.Deprecation_analizer) = struct + let rec repair_syntax_stmt user_dists {stmt; smeta} = + match stmt with + | Tilde {arg; distribution= {name; id_loc}; args; truncation} -> + { stmt= + Tilde + { arg + ; distribution= + {name= Deprecation.without_suffix user_dists name; id_loc} + ; args + ; truncation } + ; smeta } + | _ -> + { stmt= + map_statement ident (repair_syntax_stmt user_dists) ident ident stmt + ; smeta } + + let rec replace_deprecated_expr + (deprecated_userdefined : Middle.UnsizedType.t Core_kernel.String.Map.t) + {expr; emeta} = + let expr = + match expr with + | GetLP -> GetTarget + | FunApp (StanLib FnPlain, {name= "abs"; id_loc}, [e]) + when Middle.UnsizedType.is_real_type e.emeta.type_ -> FunApp - ( StanLib suffix - , {name= rename_deprecated deprecated_functions name; id_loc} - , List.map ~f:(replace_deprecated_expr deprecated_userdefined) e ) - | FunApp (UserDefined suffix, {name; id_loc}, e) -> ( - match String.Map.find deprecated_userdefined name with - | Some type_ -> - CondDistApp - ( UserDefined suffix - , {name= update_suffix name type_; id_loc} - , List.map ~f:(replace_deprecated_expr deprecated_userdefined) e ) - | None -> - if String.is_suffix name ~suffix:"_cdf" then + ( StanLib FnPlain + , {name= "fabs"; id_loc} + , [replace_deprecated_expr deprecated_userdefined e] ) + | FunApp (StanLib FnPlain, {name= "if_else"; _}, [c; t; e]) -> + Paren + (replace_deprecated_expr deprecated_userdefined + {expr= TernaryIf ({expr= Paren c; emeta= c.emeta}, t, e); emeta} ) + | FunApp (StanLib suffix, {name; id_loc}, e) -> + if Deprecation.is_deprecated_distribution name then CondDistApp - ( UserDefined suffix + ( StanLib suffix + , {name= Deprecation.rename_deprecated_distribution name; id_loc} + , List.map ~f:(replace_deprecated_expr deprecated_userdefined) e + ) + else if String.is_suffix name ~suffix:"_cdf" then + CondDistApp + ( StanLib suffix , {name; id_loc} , List.map ~f:(replace_deprecated_expr deprecated_userdefined) e ) else FunApp + ( StanLib suffix + , {name= Deprecation.rename_deprecated_function name; id_loc} + , List.map ~f:(replace_deprecated_expr deprecated_userdefined) e + ) + | FunApp (UserDefined suffix, {name; id_loc}, e) -> ( + match String.Map.find deprecated_userdefined name with + | Some type_ -> + CondDistApp ( UserDefined suffix - , {name; id_loc} + , {name= Deprecation.update_suffix name type_; id_loc} , List.map ~f:(replace_deprecated_expr deprecated_userdefined) e - ) ) - | _ -> - map_expression - (replace_deprecated_expr deprecated_userdefined) - ident expr in - {expr; emeta} - -let replace_deprecated_lval deprecated_userdefined {lval; lmeta} = - let is_multiindex = function - | Single {emeta= {type_= Middle.UnsizedType.UInt; _}; _} -> false - | _ -> true in - let rec flatten_multi = function - | LVariable id -> (LVariable id, None) - | LIndexed ({lval; lmeta}, idcs) -> ( - let outer = - List.map idcs - ~f:(map_index (replace_deprecated_expr deprecated_userdefined)) - in - let unwrap = Option.value_map ~default:[] ~f:fst in - match flatten_multi lval with - | lval, inner when List.exists ~f:is_multiindex outer -> - (lval, Some (unwrap inner @ outer, lmeta)) - | lval, None -> (LIndexed ({lval; lmeta}, outer), None) - | lval, Some (inner, _) -> (lval, Some (inner @ outer, lmeta)) ) in - let lval = - match flatten_multi lval with - | lval, None -> lval - | lval, Some (idcs, lmeta) -> LIndexed ({lval; lmeta}, idcs) in - {lval; lmeta} - -let rec replace_deprecated_stmt - (deprecated_userdefined : Middle.UnsizedType.t Core_kernel.String.Map.t) - ({stmt; smeta} : typed_statement) = - let stmt = - match stmt with - | IncrementLogProb e -> - TargetPE (replace_deprecated_expr deprecated_userdefined e) - | Assignment {assign_lhs= l; assign_op= ArrowAssign; assign_rhs= e} -> - Assignment - { assign_lhs= replace_deprecated_lval deprecated_userdefined l - ; assign_op= Assign - ; assign_rhs= (replace_deprecated_expr deprecated_userdefined) e } - | FunDef {returntype; funname= {name; id_loc}; arguments; body} -> - let newname = - match String.Map.find deprecated_userdefined name with - | Some type_ -> update_suffix name type_ - | None -> name in - FunDef - { returntype - ; funname= {name= newname; id_loc} - ; arguments - ; body= replace_deprecated_stmt deprecated_userdefined body } - | _ -> - map_statement - (replace_deprecated_expr deprecated_userdefined) - (replace_deprecated_stmt deprecated_userdefined) - (replace_deprecated_lval deprecated_userdefined) - ident stmt in - {stmt; smeta} - -let rec no_parens {expr; emeta} = - match expr with - | Paren e -> no_parens e - | Variable _ | IntNumeral _ | RealNumeral _ | ImagNumeral _ | GetLP - |GetTarget -> - {expr; emeta} - | TernaryIf _ | BinOp _ | PrefixOp _ | PostfixOp _ -> - {expr= map_expression keep_parens ident expr; emeta} - | Indexed (e, l) -> - { expr= - Indexed - ( keep_parens e - , List.map - ~f:(function - | Single e -> Single (no_parens e) - | i -> map_index keep_parens i ) - l ) - ; emeta } - | ArrayExpr _ | RowVectorExpr _ | FunApp _ | CondDistApp _ | Promotion _ -> - {expr= map_expression no_parens ident expr; emeta} - -and keep_parens {expr; emeta} = - match expr with - | Promotion (e, ut, ad) -> {expr= Promotion (keep_parens e, ut, ad); emeta} - | Paren ({expr= Paren _; _} as e) -> keep_parens e - | Paren ({expr= BinOp _; _} as e) - |Paren ({expr= PrefixOp _; _} as e) - |Paren ({expr= PostfixOp _; _} as e) - |Paren ({expr= TernaryIf _; _} as e) -> - {expr= Paren (no_parens e); emeta} - | _ -> no_parens {expr; emeta} - -let parens_lval = map_lval_with no_parens ident - -let rec parens_stmt ({stmt; smeta} : typed_statement) : typed_statement = - let stmt = - match stmt with - | VarDecl - { decl_type= d - ; transformation= t - ; identifier - ; initial_value= init - ; is_global } -> - VarDecl - { decl_type= Middle.Type.map no_parens d - ; transformation= Middle.Transformation.map keep_parens t + ) + | None -> + if String.is_suffix name ~suffix:"_cdf" then + CondDistApp + ( UserDefined suffix + , {name; id_loc} + , List.map ~f:(replace_deprecated_expr deprecated_userdefined) e + ) + else + FunApp + ( UserDefined suffix + , {name; id_loc} + , List.map ~f:(replace_deprecated_expr deprecated_userdefined) e + ) ) + | _ -> + map_expression + (replace_deprecated_expr deprecated_userdefined) + ident expr in + {expr; emeta} + + let replace_deprecated_lval deprecated_userdefined {lval; lmeta} = + let is_multiindex = function + | Single {emeta= {type_= Middle.UnsizedType.UInt; _}; _} -> false + | _ -> true in + let rec flatten_multi = function + | LVariable id -> (LVariable id, None) + | LIndexed ({lval; lmeta}, idcs) -> ( + let outer = + List.map idcs + ~f:(map_index (replace_deprecated_expr deprecated_userdefined)) + in + let unwrap = Option.value_map ~default:[] ~f:fst in + match flatten_multi lval with + | lval, inner when List.exists ~f:is_multiindex outer -> + (lval, Some (unwrap inner @ outer, lmeta)) + | lval, None -> (LIndexed ({lval; lmeta}, outer), None) + | lval, Some (inner, _) -> (lval, Some (inner @ outer, lmeta)) ) in + let lval = + match flatten_multi lval with + | lval, None -> lval + | lval, Some (idcs, lmeta) -> LIndexed ({lval; lmeta}, idcs) in + {lval; lmeta} + + let rec replace_deprecated_stmt + (deprecated_userdefined : Middle.UnsizedType.t Core_kernel.String.Map.t) + ({stmt; smeta} : typed_statement) = + let stmt = + match stmt with + | IncrementLogProb e -> + TargetPE (replace_deprecated_expr deprecated_userdefined e) + | Assignment {assign_lhs= l; assign_op= ArrowAssign; assign_rhs= e} -> + Assignment + { assign_lhs= replace_deprecated_lval deprecated_userdefined l + ; assign_op= Assign + ; assign_rhs= (replace_deprecated_expr deprecated_userdefined) e } + | FunDef {returntype; funname= {name; id_loc}; arguments; body} -> + let newname = + match String.Map.find deprecated_userdefined name with + | Some type_ -> Deprecation.update_suffix name type_ + | None -> name in + FunDef + { returntype + ; funname= {name= newname; id_loc} + ; arguments + ; body= replace_deprecated_stmt deprecated_userdefined body } + | _ -> + map_statement + (replace_deprecated_expr deprecated_userdefined) + (replace_deprecated_stmt deprecated_userdefined) + (replace_deprecated_lval deprecated_userdefined) + ident stmt in + {stmt; smeta} + + let rec no_parens {expr; emeta} = + match expr with + | Paren e -> no_parens e + | Variable _ | IntNumeral _ | RealNumeral _ | ImagNumeral _ | GetLP + |GetTarget -> + {expr; emeta} + | TernaryIf _ | BinOp _ | PrefixOp _ | PostfixOp _ -> + {expr= map_expression keep_parens ident expr; emeta} + | Indexed (e, l) -> + { expr= + Indexed + ( keep_parens e + , List.map + ~f:(function + | Single e -> Single (no_parens e) + | i -> map_index keep_parens i ) + l ) + ; emeta } + | ArrayExpr _ | RowVectorExpr _ | FunApp _ | CondDistApp _ | Promotion _ -> + {expr= map_expression no_parens ident expr; emeta} + + and keep_parens {expr; emeta} = + match expr with + | Promotion (e, ut, ad) -> {expr= Promotion (keep_parens e, ut, ad); emeta} + | Paren ({expr= Paren _; _} as e) -> keep_parens e + | Paren ({expr= BinOp _; _} as e) + |Paren ({expr= PrefixOp _; _} as e) + |Paren ({expr= PostfixOp _; _} as e) + |Paren ({expr= TernaryIf _; _} as e) -> + {expr= Paren (no_parens e); emeta} + | _ -> no_parens {expr; emeta} + + let parens_lval = map_lval_with no_parens ident + + let rec parens_stmt ({stmt; smeta} : typed_statement) : typed_statement = + let stmt = + match stmt with + | VarDecl + { decl_type= d + ; transformation= t ; identifier - ; initial_value= Option.map ~f:no_parens init - ; is_global } - | For {loop_variable; lower_bound; upper_bound; loop_body} -> - For - { loop_variable - ; lower_bound= keep_parens lower_bound - ; upper_bound= keep_parens upper_bound - ; loop_body= parens_stmt loop_body } - | _ -> map_statement no_parens parens_stmt parens_lval ident stmt in - {stmt; smeta} - -let rec blocks_stmt ({stmt; smeta} : typed_statement) : typed_statement = - let stmt_to_block ({stmt; smeta} : typed_statement) : typed_statement = - match stmt with - | Block _ -> blocks_stmt {stmt; smeta} - | _ -> - blocks_stmt - @@ mk_typed_statement - ~stmt:(Block [{stmt; smeta}]) - ~return_type:smeta.return_type ~loc:smeta.loc in - let stmt = - match stmt with - | While (e, s) -> While (e, stmt_to_block s) - | IfThenElse (e, s1, Some ({stmt= IfThenElse _; _} as s2)) - |IfThenElse (e, s1, Some {stmt= Block [({stmt= IfThenElse _; _} as s2)]; _}) - -> - (* Flatten if ... else if ... constructs *) - IfThenElse (e, stmt_to_block s1, Some (blocks_stmt s2)) - | IfThenElse (e, s1, s2) -> - IfThenElse (e, stmt_to_block s1, Option.map ~f:stmt_to_block s2) - | For ({loop_body; _} as f) -> - For {f with loop_body= stmt_to_block loop_body} - | _ -> map_statement ident blocks_stmt ident ident stmt in - {stmt; smeta} - -let repair_syntax program settings = - if settings.deprecations then - program - |> map_program - (repair_syntax_stmt (userdef_distributions program.functionblock)) - else program + ; initial_value= init + ; is_global } -> + VarDecl + { decl_type= Middle.Type.map no_parens d + ; transformation= Middle.Transformation.map keep_parens t + ; identifier + ; initial_value= Option.map ~f:no_parens init + ; is_global } + | For {loop_variable; lower_bound; upper_bound; loop_body} -> + For + { loop_variable + ; lower_bound= keep_parens lower_bound + ; upper_bound= keep_parens upper_bound + ; loop_body= parens_stmt loop_body } + | _ -> map_statement no_parens parens_stmt parens_lval ident stmt in + {stmt; smeta} + + let rec blocks_stmt ({stmt; smeta} : typed_statement) : typed_statement = + let stmt_to_block ({stmt; smeta} : typed_statement) : typed_statement = + match stmt with + | Block _ -> blocks_stmt {stmt; smeta} + | _ -> + blocks_stmt + @@ mk_typed_statement + ~stmt:(Block [{stmt; smeta}]) + ~return_type:smeta.return_type ~loc:smeta.loc in + let stmt = + match stmt with + | While (e, s) -> While (e, stmt_to_block s) + | IfThenElse (e, s1, Some ({stmt= IfThenElse _; _} as s2)) + |IfThenElse + (e, s1, Some {stmt= Block [({stmt= IfThenElse _; _} as s2)]; _}) -> + (* Flatten if ... else if ... constructs *) + IfThenElse (e, stmt_to_block s1, Some (blocks_stmt s2)) + | IfThenElse (e, s1, s2) -> + IfThenElse (e, stmt_to_block s1, Option.map ~f:stmt_to_block s2) + | For ({loop_body; _} as f) -> + For {f with loop_body= stmt_to_block loop_body} + | _ -> map_statement ident blocks_stmt ident ident stmt in + {stmt; smeta} -let canonicalize_program program settings : typed_program = - let program = + let repair_syntax program settings = if settings.deprecations then program |> map_program - (replace_deprecated_stmt (collect_userdef_distributions program)) - else program in - let program = - if settings.parentheses then program |> map_program parens_stmt else program - in - let program = - if settings.braces then program |> map_program blocks_stmt else program - in - program + (repair_syntax_stmt + (Deprecation.userdef_distributions program.functionblock) ) + else program + + let canonicalize_program program settings : typed_program = + let program = + if settings.deprecations then + program + |> map_program + (replace_deprecated_stmt + (Deprecation.collect_userdef_distributions program) ) + else program in + let program = + if settings.parentheses then program |> map_program parens_stmt + else program in + let program = + if settings.braces then program |> map_program blocks_stmt else program + in + program +end diff --git a/src/frontend/Canonicalize.mli b/src/frontend/Canonicalize.mli index 362567d649..b2282c0e91 100644 --- a/src/frontend/Canonicalize.mli +++ b/src/frontend/Canonicalize.mli @@ -13,11 +13,17 @@ type canonicalizer_settings = val all : canonicalizer_settings val none : canonicalizer_settings -val repair_syntax : untyped_program -> canonicalizer_settings -> untyped_program -(** When deprecation canonicalization is enabled, this runs before typechecking +module type Canonicalizer = sig + val repair_syntax : + untyped_program -> canonicalizer_settings -> untyped_program + (** When deprecation canonicalization is enabled, this runs before typechecking and removes suffixes from ~ statements, which are otherwise forbidden by the typechecker *) -val canonicalize_program : - typed_program -> canonicalizer_settings -> typed_program -(** "Canonicalize" the program by removing deprecations, adding or removing parenthesis + val canonicalize_program : + typed_program -> canonicalizer_settings -> typed_program + (** "Canonicalize" the program by removing deprecations, adding or removing parenthesis and braces, etc. *) +end + +module Make (Deprecation : Deprecation_analysis.Deprecation_analizer) : + Canonicalizer diff --git a/src/frontend/Deprecation_analysis.ml b/src/frontend/Deprecation_analysis.ml index a1f5db637f..5f4d5cc2a8 100644 --- a/src/frontend/Deprecation_analysis.ml +++ b/src/frontend/Deprecation_analysis.ml @@ -2,184 +2,207 @@ open Core_kernel open Ast open Middle -let deprecated_functions = - String.Map.of_alist_exn - [ ("multiply_log", ("lmultiply", "2.32.0")) - ; ("binomial_coefficient_log", ("lchoose", "2.32.0")) - ; ("cov_exp_quad", ("gp_exp_quad_cov", "2.32.0")) ] - -let deprecated_odes = - String.Map.of_alist_exn - [ ("integrate_ode", ("ode_rk45", "3.0")) - ; ("integrate_ode_rk45", ("ode_rk45", "3.0")) - ; ("integrate_ode_bdf", ("ode_bdf", "3.0")) - ; ("integrate_ode_adams", ("ode_adams", "3.0")) ] - -let deprecated_distributions = - String.Map.of_alist_exn - (List.map - ~f:(fun (x, y) -> (x, (y, "2.32.0"))) - (List.concat_map Middle.Stan_math_signatures.distributions - ~f:(fun (fnkinds, name, _, _) -> - List.filter_map fnkinds ~f:(function - | Lpdf -> Some (name ^ "_log", name ^ "_lpdf") - | Lpmf -> Some (name ^ "_log", name ^ "_lpmf") - | Cdf -> Some (name ^ "_cdf_log", name ^ "_lcdf") - | Ccdf -> Some (name ^ "_ccdf_log", name ^ "_lccdf") - | Rng | UnaryVectorized -> None ) ) ) ) - -let stan_lib_deprecations = - Map.merge_skewed deprecated_distributions deprecated_functions - ~combine:(fun ~key x y -> - Common.FatalError.fatal_error_msg - [%message - "Common key in deprecation map" - (key : string) - (x : string * string) - (y : string * string)] ) - -let is_deprecated_distribution name = - Option.is_some (Map.find deprecated_distributions name) - -let rename_deprecated map name = - Map.find map name |> Option.map ~f:fst |> Option.value ~default:name - -let distribution_suffix name = - let open String in - is_suffix ~suffix:"_lpdf" name - || is_suffix ~suffix:"_lpmf" name - || is_suffix ~suffix:"_lcdf" name - || is_suffix ~suffix:"_lccdf" name - -let userdef_distributions stmts = - let open String in - List.filter_map - ~f:(function - | {stmt= FunDef {funname= {name; _}; _}; _} -> - if - is_suffix ~suffix:"_log_lpdf" name - || is_suffix ~suffix:"_log_lpmf" name - then Some (drop_suffix name 5) - else if is_suffix ~suffix:"_log_log" name then - Some (drop_suffix name 4) - else None - | _ -> None ) - (Ast.get_stmts stmts) - -let without_suffix user_dists name = - let open String in - if is_suffix ~suffix:"_lpdf" name || is_suffix ~suffix:"_lpmf" name then - drop_suffix name 5 - else if - is_suffix ~suffix:"_log" name - && not - ( is_deprecated_distribution (name ^ "_log") - || List.exists ~f:(( = ) name) user_dists ) - then drop_suffix name 4 - else name - -let update_suffix name type_ = - let open String in - if is_suffix ~suffix:"_cdf_log" name then drop_suffix name 8 ^ "_lcdf" - else if is_suffix ~suffix:"_ccdf_log" name then drop_suffix name 9 ^ "_lccdf" - else if Middle.UnsizedType.is_int_type type_ then drop_suffix name 4 ^ "_lpmf" - else drop_suffix name 4 ^ "_lpdf" - -let find_udf_log_suffix = function - | { stmt= - FunDef - { funname= {name; _} - ; arguments= (_, ((UReal | UInt) as type_), _) :: _ - ; _ } - ; smeta= _ } - when String.is_suffix ~suffix:"_log" name -> - Some (name, type_) - | _ -> None - -let rec collect_deprecated_expr (acc : (Location_span.t * string) list) - ({expr; emeta} : (typed_expr_meta, fun_kind) expr_with) : - (Location_span.t * string) list = - match expr with - | FunApp (StanLib FnPlain, {name= "abs"; _}, [e]) - when Middle.UnsizedType.is_real_type e.emeta.type_ -> - collect_deprecated_expr - ( acc +module type Deprecation_analizer = sig + val find_udf_log_suffix : + typed_statement -> (string * Middle.UnsizedType.t) option + + val update_suffix : string -> Middle.UnsizedType.t -> string + + val collect_userdef_distributions : + typed_program -> Middle.UnsizedType.t String.Map.t + + val distribution_suffix : string -> bool + val without_suffix : string list -> string -> string + val is_deprecated_distribution : string -> bool + val rename_deprecated_distribution : string -> string + val rename_deprecated_function : string -> string + val userdef_distributions : untyped_statement block option -> string list + val collect_warnings : typed_program -> Warnings.t list +end + +module Make (StdLib : Std_library_utils.Library) : Deprecation_analizer = struct + (* String.Map.of_alist_exn + (List.map + ~f:(fun (x, y) -> (x, (y, "2.32.0"))) + (List.concat_map StdLib.distributions + ~f:(fun (fnkinds, name, _, _) -> + List.filter_map fnkinds ~f:(function + | Lpdf -> Some (name ^ "_log", name ^ "_lpdf") + | Lpmf -> Some (name ^ "_log", name ^ "_lpmf") + | Cdf -> Some (name ^ "_cdf_log", name ^ "_lcdf") + | Ccdf -> Some (name ^ "_ccdf_log", name ^ "_lccdf") + | Rng | UnaryVectorized -> None ) ) ) ) *) + (* String.Map.of_alist_exn + [ ("multiply_log", ("lmultiply", "2.32.0")) + ; ("binomial_coefficient_log", ("lchoose", "2.32.0")) + ; ("cov_exp_quad", ("gp_exp_quad_cov", "2.32.0")) ] + + + + This can be automatically changed using the \ + canonicalize flag for stanc + *) + + (* String.Map.of_alist_exn + [ ("integrate_ode", ("ode_rk45", "3.0")) + ; ("integrate_ode_rk45", ("ode_rk45", "3.0")) + ; ("integrate_ode_bdf", ("ode_bdf", "3.0")) + ; ("integrate_ode_adams", ("ode_adams", "3.0")) ] + + + + The new interface is slightly different, see: + https://mc-stan.org/users/documentation/case-studies/convert_odes.html + *) + + let stan_lib_deprecations = + Map.merge_skewed StdLib.deprecated_distributions StdLib.deprecated_functions + ~combine:(fun ~key x y -> + Common.FatalError.fatal_error_msg + [%message + "Common key in deprecation map" + (key : string) + (x : Std_library_utils.deprecation_info) + (y : Std_library_utils.deprecation_info)] ) + + let is_deprecated_distribution name = + Map.mem StdLib.deprecated_distributions name + + let rename_deprecated map name = + Map.find map name + |> Option.map ~f:(fun Std_library_utils.{replacement; _} -> replacement) + |> Option.value ~default:name + + let rename_deprecated_distribution = + rename_deprecated StdLib.deprecated_distributions + + let rename_deprecated_function = rename_deprecated StdLib.deprecated_functions + + let distribution_suffix name = + let open String in + is_suffix ~suffix:"_lpdf" name + || is_suffix ~suffix:"_lpmf" name + || is_suffix ~suffix:"_lcdf" name + || is_suffix ~suffix:"_lccdf" name + + let userdef_distributions stmts = + let open String in + List.filter_map + ~f:(function + | {stmt= FunDef {funname= {name; _}; _}; _} -> + if + is_suffix ~suffix:"_log_lpdf" name + || is_suffix ~suffix:"_log_lpmf" name + then Some (drop_suffix name 5) + else if is_suffix ~suffix:"_log_log" name then + Some (drop_suffix name 4) + else None + | _ -> None ) + (Ast.get_stmts stmts) + + let without_suffix user_dists name = + let open String in + if is_suffix ~suffix:"_lpdf" name || is_suffix ~suffix:"_lpmf" name then + drop_suffix name 5 + else if + is_suffix ~suffix:"_log" name + && not + ( is_deprecated_distribution (name ^ "_log") + || List.exists ~f:(( = ) name) user_dists ) + then drop_suffix name 4 + else name + + let update_suffix name type_ = + let open String in + if is_suffix ~suffix:"_cdf_log" name then drop_suffix name 8 ^ "_lcdf" + else if is_suffix ~suffix:"_ccdf_log" name then + drop_suffix name 9 ^ "_lccdf" + else if Middle.UnsizedType.is_int_type type_ then + drop_suffix name 4 ^ "_lpmf" + else drop_suffix name 4 ^ "_lpdf" + + let find_udf_log_suffix = function + | { stmt= + FunDef + { funname= {name; _} + ; arguments= (_, ((UReal | UInt) as type_), _) :: _ + ; _ } + ; smeta= _ } + when String.is_suffix ~suffix:"_log" name -> + Some (name, type_) + | _ -> None + + let rec collect_deprecated_expr (acc : (Location_span.t * string) list) + ({expr; emeta} : (typed_expr_meta, fun_kind) expr_with) : + (Location_span.t * string) list = + match expr with + | FunApp (StanLib FnPlain, {name= "abs"; _}, [e]) + when Middle.UnsizedType.is_real_type e.emeta.type_ -> + collect_deprecated_expr + ( acc + @ [ ( emeta.loc + , "Use of the `abs` function with real-valued arguments is \ + deprecated; use function `fabs` instead." ) ] ) + e + | FunApp (StanLib FnPlain, {name= "if_else"; _}, l) -> + acc @ [ ( emeta.loc - , "Use of the `abs` function with real-valued arguments is \ - deprecated; use function `fabs` instead." ) ] ) - e - | FunApp (StanLib FnPlain, {name= "if_else"; _}, l) -> - acc - @ [ ( emeta.loc - , "The function `if_else` is deprecated and will be removed in Stan \ - 2.32.0. Use the conditional operator (x ? y : z) instead; this \ - can be automatically changed using the canonicalize flag for \ - stanc" ) ] - @ List.concat_map l ~f:(fun e -> collect_deprecated_expr [] e) - | FunApp ((StanLib _ | UserDefined _), {name; _}, l) -> - let w = - match Map.find stan_lib_deprecations name with - | Some (rename, version) -> - [ ( emeta.loc - , name ^ " is deprecated and will be removed in Stan " ^ version - ^ ". Use " ^ rename - ^ " instead. This can be automatically changed using the \ - canonicalize flag for stanc" ) ] - | _ when String.is_suffix name ~suffix:"_cdf" -> - [ ( emeta.loc - , "Use of " ^ name - ^ " without a vertical bar (|) between the first two arguments \ - of a CDF is deprecated and will be removed in Stan 2.32.0. \ - This can be automatically changed using the canonicalize \ - flag for stanc" ) ] - | _ -> ( - match Map.find deprecated_odes name with - | Some (rename, version) -> + , "The function `if_else` is deprecated and will be removed in \ + Stan 2.32.0. Use the conditional operator (x ? y : z) instead; \ + this can be automatically changed using the canonicalize flag \ + for stanc" ) ] + @ List.concat_map l ~f:(fun e -> collect_deprecated_expr [] e) + | FunApp ((StanLib _ | UserDefined _), {name; _}, l) -> + let w = + match Map.find stan_lib_deprecations name with + | Some {replacement; version; extra_message} -> [ ( emeta.loc , name ^ " is deprecated and will be removed in Stan " ^ version - ^ ". Use " ^ rename - ^ " instead. \n\ - The new interface is slightly different, see: \ - https://mc-stan.org/users/documentation/case-studies/convert_odes.html" - ) ] - | _ -> [] ) in - acc @ w @ List.concat_map l ~f:(fun e -> collect_deprecated_expr [] e) - | _ -> fold_expression collect_deprecated_expr (fun l _ -> l) acc expr - -let collect_deprecated_lval acc l = - fold_lval_with collect_deprecated_expr (fun x _ -> x) acc l - -let rec collect_deprecated_stmt (acc : (Location_span.t * string) list) {stmt; _} - : (Location_span.t * string) list = - match stmt with - | FunDef - { body - ; funname= {name; id_loc} - ; arguments= (_, ((UReal | UInt) as type_), _) :: _ - ; _ } - when String.is_suffix ~suffix:"_log" name -> - let acc = - acc - @ [ ( id_loc - , "Use of the _log suffix in user defined probability functions is \ - deprecated and will be removed in Stan 2.32.0, use name '" - ^ update_suffix name type_ - ^ "' instead if you intend on using this function in ~ \ - statements or calling unnormalized probability functions \ - inside of it." ) ] in - collect_deprecated_stmt acc body - | FunDef {body; _} -> collect_deprecated_stmt acc body - | _ -> - fold_statement collect_deprecated_expr collect_deprecated_stmt - collect_deprecated_lval - (fun l _ -> l) - acc stmt - -let collect_userdef_distributions program = - program.functionblock |> Ast.get_stmts - |> List.filter_map ~f:find_udf_log_suffix - |> List.dedup_and_sort ~compare:(fun (x, _) (y, _) -> String.compare x y) - |> String.Map.of_alist_exn - -let collect_warnings (program : typed_program) = - fold_program collect_deprecated_stmt [] program + ^ ". Use " ^ replacement ^ " instead. " ^ extra_message ) ] + | _ when String.is_suffix name ~suffix:"_cdf" -> + [ ( emeta.loc + , "Use of " ^ name + ^ " without a vertical bar (|) between the first two \ + arguments of a CDF is deprecated and will be removed in \ + Stan 2.32.0. This can be automatically changed using the \ + canonicalize flag for stanc" ) ] + | _ -> [] in + acc @ w @ List.concat_map l ~f:(fun e -> collect_deprecated_expr [] e) + | _ -> fold_expression collect_deprecated_expr (fun l _ -> l) acc expr + + let collect_deprecated_lval acc l = + fold_lval_with collect_deprecated_expr (fun x _ -> x) acc l + + let rec collect_deprecated_stmt (acc : (Location_span.t * string) list) + {stmt; _} : (Location_span.t * string) list = + match stmt with + | FunDef + { body + ; funname= {name; id_loc} + ; arguments= (_, ((UReal | UInt) as type_), _) :: _ + ; _ } + when String.is_suffix ~suffix:"_log" name -> + let acc = + acc + @ [ ( id_loc + , "Use of the _log suffix in user defined probability functions \ + is deprecated and will be removed in Stan 2.32.0, use name '" + ^ update_suffix name type_ + ^ "' instead if you intend on using this function in ~ \ + statements or calling unnormalized probability functions \ + inside of it." ) ] in + collect_deprecated_stmt acc body + | FunDef {body; _} -> collect_deprecated_stmt acc body + | _ -> + fold_statement collect_deprecated_expr collect_deprecated_stmt + collect_deprecated_lval + (fun l _ -> l) + acc stmt + + let collect_userdef_distributions program = + program.functionblock |> Ast.get_stmts + |> List.filter_map ~f:find_udf_log_suffix + |> List.dedup_and_sort ~compare:(fun (x, _) (y, _) -> String.compare x y) + |> String.Map.of_alist_exn + + let collect_warnings (program : typed_program) = + fold_program collect_deprecated_stmt [] program +end diff --git a/src/frontend/Deprecation_analysis.mli b/src/frontend/Deprecation_analysis.mli index 205e400afd..90261e8d6f 100644 --- a/src/frontend/Deprecation_analysis.mli +++ b/src/frontend/Deprecation_analysis.mli @@ -5,19 +5,22 @@ open Core_kernel open Ast -val find_udf_log_suffix : - typed_statement -> (string * Middle.UnsizedType.t) option +module type Deprecation_analizer = sig + val find_udf_log_suffix : + typed_statement -> (string * Middle.UnsizedType.t) option -val update_suffix : string -> Middle.UnsizedType.t -> string + val update_suffix : string -> Middle.UnsizedType.t -> string -val collect_userdef_distributions : - typed_program -> Middle.UnsizedType.t String.Map.t + val collect_userdef_distributions : + typed_program -> Middle.UnsizedType.t String.Map.t -val distribution_suffix : string -> bool -val without_suffix : string list -> string -> string -val is_deprecated_distribution : string -> bool -val deprecated_distributions : (string * string) String.Map.t -val deprecated_functions : (string * string) String.Map.t -val rename_deprecated : (string * string) String.Map.t -> string -> string -val userdef_distributions : untyped_statement block option -> string list -val collect_warnings : typed_program -> Warnings.t list + val distribution_suffix : string -> bool + val without_suffix : string list -> string -> string + val is_deprecated_distribution : string -> bool + val rename_deprecated_distribution : string -> string + val rename_deprecated_function : string -> string + val userdef_distributions : untyped_statement block option -> string list + val collect_warnings : typed_program -> Warnings.t list +end + +module Make (StdLib : Std_library_utils.Library) : Deprecation_analizer diff --git a/src/frontend/Environment.ml b/src/frontend/Environment.ml index edcdc89c83..0d41e0d460 100644 --- a/src/frontend/Environment.ml +++ b/src/frontend/Environment.ml @@ -27,9 +27,9 @@ type info = type t = info list String.Map.t -let stan_math_environment = +let make_from_library signatures : t = let functions = - Hashtbl.to_alist Stan_math_signatures.stan_math_signatures + Hashtbl.to_alist signatures |> List.map ~f:(fun (key, values) -> ( key , List.map values ~f:(fun (rt, args, mem) -> diff --git a/src/frontend/Environment.mli b/src/frontend/Environment.mli index a0436c3610..0464b2c1de 100644 --- a/src/frontend/Environment.mli +++ b/src/frontend/Environment.mli @@ -29,8 +29,16 @@ type info = type t -val stan_math_environment : t -(** A type environment which contains the Stan math library functions +val make_from_library : + ( string + , ( UnsizedType.returntype + * (UnsizedType.autodifftype * UnsizedType.t) list + * Common.Helpers.mem_pattern ) + list ) + Core_kernel.Hashtbl.t + -> t +(** Make a type environment from a hashtable of functions like those from + [Std_library_utils] *) val find : t -> string -> info list diff --git a/src/frontend/Semantic_error.ml b/src/frontend/Semantic_error.ml index 79594d24fb..38090618fe 100644 --- a/src/frontend/Semantic_error.ml +++ b/src/frontend/Semantic_error.ml @@ -16,17 +16,7 @@ module TypeError = struct | ArrayVectorRowVectorMatrixExpected of UnsizedType.t | IllTypedAssignment of Operator.t * UnsizedType.t * UnsizedType.t | IllTypedTernaryIf of UnsizedType.t * UnsizedType.t * UnsizedType.t - | IllTypedReduceSum of - string - * UnsizedType.t list - * (UnsizedType.autodifftype * UnsizedType.t) list - * SignatureMismatch.function_mismatch - | IllTypedReduceSumGeneric of - string - * UnsizedType.t list - * (UnsizedType.autodifftype * UnsizedType.t) list - * SignatureMismatch.function_mismatch - | IllTypedVariadicDE of + | IllTypedVariadicFn of string * UnsizedType.t list * (UnsizedType.autodifftype * UnsizedType.t) list @@ -50,9 +40,15 @@ module TypeError = struct string * UnsizedType.t list * (SignatureMismatch.signature_error list * bool) - | IllTypedBinaryOperator of Operator.t * UnsizedType.t * UnsizedType.t - | IllTypedPrefixOperator of Operator.t * UnsizedType.t - | IllTypedPostfixOperator of Operator.t * UnsizedType.t + | IllTypedBinaryOperator of + Operator.t + * UnsizedType.t + * UnsizedType.t + * Std_library_utils.signature list + | IllTypedPrefixOperator of + Operator.t * UnsizedType.t * Std_library_utils.signature list + | IllTypedPostfixOperator of + Operator.t * UnsizedType.t * Std_library_utils.signature list | NotIndexable of UnsizedType.t * int let pp ppf = function @@ -125,13 +121,7 @@ module TypeError = struct Fmt.pf ppf "Condition in ternary expression must be primitive int; found type=%a" UnsizedType.pp ut1 - | IllTypedReduceSum (name, arg_tys, expected_args, error) -> - SignatureMismatch.pp_signature_mismatch ppf - (name, arg_tys, ([((ReturnType UReal, expected_args), error)], false)) - | IllTypedReduceSumGeneric (name, arg_tys, expected_args, error) -> - SignatureMismatch.pp_signature_mismatch ppf - (name, arg_tys, ([((ReturnType UReal, expected_args), error)], false)) - | IllTypedVariadicDE (name, arg_tys, args, error, return_type) -> + | IllTypedVariadicFn (name, arg_tys, args, error, return_type) -> SignatureMismatch.pp_signature_mismatch ppf ( name , arg_tys @@ -224,32 +214,25 @@ module TypeError = struct prefix suffix prefix prefix newsuffix | IllTypedFunctionApp (name, arg_tys, errors) -> SignatureMismatch.pp_signature_mismatch ppf (name, arg_tys, errors) - | IllTypedBinaryOperator (op, lt, rt) -> + | IllTypedBinaryOperator (op, lt, rt, sigs) -> Fmt.pf ppf "Ill-typed arguments supplied to infix operator %a. Available \ - signatures: %s@[Instead supplied arguments of incompatible type: \ - %a, %a.@]" - Operator.pp op - ( Stan_math_signatures.pretty_print_math_lib_operator_sigs op - |> String.concat ~sep:"\n" ) - UnsizedType.pp lt UnsizedType.pp rt - | IllTypedPrefixOperator (op, ut) -> + signatures: @[%a@]@[Instead supplied arguments of \ + incompatible type: %a, %a.@]" + Operator.pp op Std_library_utils.pp_math_sigs sigs UnsizedType.pp lt + UnsizedType.pp rt + | IllTypedPrefixOperator (op, ut, sigs) -> Fmt.pf ppf "Ill-typed arguments supplied to prefix operator %a. Available \ - signatures: %s@[Instead supplied argument of incompatible type: \ - %a.@]" - Operator.pp op - ( Stan_math_signatures.pretty_print_math_lib_operator_sigs op - |> String.concat ~sep:"\n" ) - UnsizedType.pp ut - | IllTypedPostfixOperator (op, ut) -> + signatures: @[%a@]@[Instead supplied argument of incompatible \ + type: %a.@]" + Operator.pp op Std_library_utils.pp_math_sigs sigs UnsizedType.pp ut + | IllTypedPostfixOperator (op, ut, sigs) -> Fmt.pf ppf "Ill-typed arguments supplied to postfix operator %a. Available \ - signatures: %s\n\ - Instead supplied argument of incompatible type: %a." Operator.pp op - ( Stan_math_signatures.pretty_print_math_lib_operator_sigs op - |> String.concat ~sep:"\n" ) - UnsizedType.pp ut + signatures: @[%a@]@[Instead supplied argument of incompatible \ + type: %a.@]" + Operator.pp op Std_library_utils.pp_math_sigs sigs UnsizedType.pp ut end module IdentifierError = struct @@ -540,34 +523,9 @@ let illtyped_ternary_if loc predt lt rt = let returning_fn_expected_nonreturning_found loc name = TypeError (loc, TypeError.ReturningFnExpectedNonReturningFound name) -let illtyped_reduce_sum loc name arg_tys args error = - TypeError (loc, TypeError.IllTypedReduceSum (name, arg_tys, args, error)) - -let illtyped_reduce_sum_generic loc name arg_tys expected_args error = +let illtyped_variadic_fn loc name arg_tys args error return_type = TypeError - ( loc - , TypeError.IllTypedReduceSumGeneric (name, arg_tys, expected_args, error) - ) - -let illtyped_variadic_ode loc name arg_tys args error = - TypeError - ( loc - , TypeError.IllTypedVariadicDE - ( name - , arg_tys - , args - , error - , Stan_math_signatures.variadic_ode_fun_return_type ) ) - -let illtyped_variadic_dae loc name arg_tys args error = - TypeError - ( loc - , TypeError.IllTypedVariadicDE - ( name - , arg_tys - , args - , error - , Stan_math_signatures.variadic_dae_fun_return_type ) ) + (loc, TypeError.IllTypedVariadicFn (name, arg_tys, args, error, return_type)) let ambiguous_function_promotion loc name arg_tys signatures = TypeError @@ -601,14 +559,14 @@ let nonreturning_fn_expected_undeclaredident_found loc name sug = let illtyped_fn_app loc name errors arg_tys = TypeError (loc, TypeError.IllTypedFunctionApp (name, arg_tys, errors)) -let illtyped_binary_op loc op lt rt = - TypeError (loc, TypeError.IllTypedBinaryOperator (op, lt, rt)) +let illtyped_binary_op loc op lt rt sigs = + TypeError (loc, TypeError.IllTypedBinaryOperator (op, lt, rt, sigs)) -let illtyped_prefix_op loc op ut = - TypeError (loc, TypeError.IllTypedPrefixOperator (op, ut)) +let illtyped_prefix_op loc op ut sigs = + TypeError (loc, TypeError.IllTypedPrefixOperator (op, ut, sigs)) -let illtyped_postfix_op loc op ut = - TypeError (loc, TypeError.IllTypedPostfixOperator (op, ut)) +let illtyped_postfix_op loc op ut sigs = + TypeError (loc, TypeError.IllTypedPostfixOperator (op, ut, sigs)) let not_indexable loc ut nidcs = TypeError (loc, TypeError.NotIndexable (ut, nidcs)) diff --git a/src/frontend/Semantic_error.mli b/src/frontend/Semantic_error.mli index 525a320694..874cbacc41 100644 --- a/src/frontend/Semantic_error.mli +++ b/src/frontend/Semantic_error.mli @@ -42,28 +42,13 @@ val returning_fn_expected_undeclared_dist_suffix_found : val returning_fn_expected_wrong_dist_suffix_found : Location_span.t -> string * string -> t -val illtyped_reduce_sum : - Location_span.t - -> string - -> UnsizedType.t list - -> (UnsizedType.autodifftype * UnsizedType.t) list - -> SignatureMismatch.function_mismatch - -> t - -val illtyped_reduce_sum_generic : - Location_span.t - -> string - -> UnsizedType.t list - -> (UnsizedType.autodifftype * UnsizedType.t) list - -> SignatureMismatch.function_mismatch - -> t - -val illtyped_variadic_ode : +val illtyped_variadic_fn : Location_span.t -> string -> UnsizedType.t list -> (UnsizedType.autodifftype * UnsizedType.t) list -> SignatureMismatch.function_mismatch + -> UnsizedType.t -> t val ambiguous_function_promotion : @@ -74,14 +59,6 @@ val ambiguous_function_promotion : list -> t -val illtyped_variadic_dae : - Location_span.t - -> string - -> UnsizedType.t list - -> (UnsizedType.autodifftype * UnsizedType.t) list - -> SignatureMismatch.function_mismatch - -> t - val nonreturning_fn_expected_returning_found : Location_span.t -> string -> t val nonreturning_fn_expected_nonfn_found : Location_span.t -> string -> t @@ -96,10 +73,27 @@ val illtyped_fn_app : -> t val illtyped_binary_op : - Location_span.t -> Operator.t -> UnsizedType.t -> UnsizedType.t -> t + Location_span.t + -> Operator.t + -> UnsizedType.t + -> UnsizedType.t + -> Std_library_utils.signature list + -> t + +val illtyped_prefix_op : + Location_span.t + -> Operator.t + -> UnsizedType.t + -> Std_library_utils.signature list + -> t + +val illtyped_postfix_op : + Location_span.t + -> Operator.t + -> UnsizedType.t + -> Std_library_utils.signature list + -> t -val illtyped_prefix_op : Location_span.t -> Operator.t -> UnsizedType.t -> t -val illtyped_postfix_op : Location_span.t -> Operator.t -> UnsizedType.t -> t val not_indexable : Location_span.t -> UnsizedType.t -> int -> t val ident_is_keyword : Location_span.t -> string -> t val ident_is_model_name : Location_span.t -> string -> t diff --git a/src/frontend/SignatureMismatch.ml b/src/frontend/SignatureMismatch.ml index f204920ac3..e1b678b046 100644 --- a/src/frontend/SignatureMismatch.ml +++ b/src/frontend/SignatureMismatch.ml @@ -245,9 +245,6 @@ let matching_function env name args = UnsizedType.compare_returntype ret1 ret2 ) in find_compatible_rt function_types args -let matching_stanlib_function = - matching_function Environment.stan_math_environment - let check_variadic_args allow_lpdf mandatory_arg_tys mandatory_fun_arg_tys fun_return args = let minimal_func_type = diff --git a/src/frontend/SignatureMismatch.mli b/src/frontend/SignatureMismatch.mli index 66a951903c..4871bf49f5 100644 --- a/src/frontend/SignatureMismatch.mli +++ b/src/frontend/SignatureMismatch.mli @@ -55,12 +55,6 @@ val matching_function : Requires a unique minimum option under type promotion *) -val matching_stanlib_function : - string -> (UnsizedType.autodifftype * UnsizedType.t) list -> match_result -(** Same as [matching_function] but requires specifically that the function - be from StanMath (uses [Environment.stan_math_environment]) -*) - val check_variadic_args : bool -> (UnsizedType.autodifftype * UnsizedType.t) list diff --git a/src/frontend/Std_library_utils.ml b/src/frontend/Std_library_utils.ml index 4b0628e9d8..6022d1540e 100644 --- a/src/frontend/Std_library_utils.ml +++ b/src/frontend/Std_library_utils.ml @@ -1,6 +1,7 @@ (** General functions and signatures for a Standard Library *) open Middle +open Core_kernel (* Types for the module representing the standard library *) type fun_arg = UnsizedType.autodifftype * UnsizedType.t @@ -23,16 +24,44 @@ let pp_math_sig ppf (rt, args, mem_pattern) = let pp_math_sigs ppf sigs = (Fmt.list ~sep:Fmt.cut pp_math_sig) ppf sigs let pretty_print_math_sigs = Fmt.str "@[@,%a@]" pp_math_sigs +type deprecation_info = + {replacement: string; version: string; extra_message: string} +[@@deriving sexp] + module type Library = sig - val stan_math_signatures : (string, signature list) Hashtbl.t + (** This module is used as a parameter for many functors which + rely on information about a backend-specific Stan library. *) + + val function_signatures : (string, signature list) Hashtbl.t (** Mapping from names to signature(s) of functions *) val distribution_families : string list - val is_stan_math_function_name : string -> bool + val is_stdlib_function_name : string -> bool (** Equivalent to [Hashtbl.mem stan_math_signatures s]*) + val get_signatures : string -> signature list + val get_operator_signatures : Operator.t -> signature list val is_not_overloadable : string -> bool val is_variadic_function_name : string -> bool val operator_to_function_names : Operator.t -> string list + val string_operator_to_function_name : string -> string + val deprecated_distributions : deprecation_info String.Map.t + val deprecated_functions : deprecation_info String.Map.t +end + +module NullLibrary : Library = struct + let function_signatures : (string, signature list) Hashtbl.t = + String.Table.create () + + let distribution_families : string list = [] + let is_stdlib_function_name _ = false + let get_signatures _ = [] + let get_operator_signatures _ = [] + let is_not_overloadable _ = false + let is_variadic_function_name _ = false + let operator_to_function_names _ = [] + let string_operator_to_function_name s = s + let deprecated_distributions = String.Map.empty + let deprecated_functions = String.Map.empty end diff --git a/src/frontend/Typechecker.ml b/src/frontend/Typechecker.ml index ccea526f94..b279c90fc5 100644 --- a/src/frontend/Typechecker.ml +++ b/src/frontend/Typechecker.ml @@ -85,32 +85,26 @@ let reserved_keywords = module type Typechecker = sig val check_program_exn : untyped_program -> typed_program * Warnings.t list - (** - Type check a full Stan program. - Can raise [Errors.SemanticError] - *) val check_program : untyped_program -> (typed_program * Warnings.t list, Semantic_error.t) result - (** - The safe version of [check_program_exn]. This catches - all [Errors.SemanticError] exceptions and converts them - into a [Result.t] - *) - val operator_stan_math_return_type : + val operator_return_type : Middle.Operator.t -> (Middle.UnsizedType.autodifftype * Middle.UnsizedType.t) list -> (Middle.UnsizedType.returntype * Promotion.t list) option - val stan_math_return_type : + val library_function_return_type : string -> (Middle.UnsizedType.autodifftype * Middle.UnsizedType.t) list -> Middle.UnsizedType.returntype option end -module Typecheck (StdLibrary : Std_library_utils.Library) : Typechecker = struct +module Make (StdLibrary : Std_library_utils.Library) : Typechecker = struct + let std_library_tenv : Env.t = + Env.make_from_library StdLibrary.function_signatures + let verify_identifier id : unit = if id.name = !model_name then Semantic_error.ident_is_model_name id.id_loc id.name |> error @@ -217,14 +211,14 @@ module Typecheck (StdLibrary : Std_library_utils.Library) : Typechecker = struct | SignatureMismatch.UniqueMatch (rt, _, _) -> Some rt | _ -> None - let stan_math_return_type name arg_tys = + let library_function_return_type name arg_tys = match name with | x when StdLibrary.is_variadic_function_name x -> Some (failwith "TODO") | _ -> - SignatureMismatch.matching_stanlib_function name arg_tys + SignatureMismatch.matching_function std_library_tenv name arg_tys |> match_to_rt_option - let operator_stan_math_return_type op arg_tys = + let operator_return_type op arg_tys = match (op, arg_tys) with | Operator.IntDivide, [(_, UnsizedType.UInt); (_, UInt)] -> Some @@ -233,19 +227,19 @@ module Typecheck (StdLibrary : Std_library_utils.Library) : Typechecker = struct | _ -> StdLibrary.operator_to_function_names op |> List.filter_map ~f:(fun name -> - SignatureMismatch.matching_stanlib_function name arg_tys + SignatureMismatch.matching_function std_library_tenv name arg_tys |> function | SignatureMismatch.UniqueMatch (rt, _, p) -> Some (rt, p) | _ -> None ) |> List.hd - let assignmentoperator_stan_math_return_type assop arg_tys = + let assignmentoperator_return_type assop arg_tys = ( match assop with | Operator.Divide -> - SignatureMismatch.matching_stanlib_function "divide" arg_tys + SignatureMismatch.matching_function std_library_tenv "divide" arg_tys |> match_to_rt_option | Plus | Minus | Times | EltTimes | EltDivide -> - operator_stan_math_return_type assop arg_tys |> Option.map ~f:fst + operator_return_type assop arg_tys |> Option.map ~f:fst | _ -> None ) |> Option.bind ~f:(function | ReturnType rtype @@ -257,7 +251,7 @@ module Typecheck (StdLibrary : Std_library_utils.Library) : Typechecker = struct | _ -> None ) let check_binop loc op le re = - let rt = [le; re] |> get_arg_types |> operator_stan_math_return_type op in + let rt = [le; re] |> get_arg_types |> operator_return_type op in match rt with | Some (ReturnType type_, [p1; p2]) -> mk_typed_expression @@ -266,27 +260,34 @@ module Typecheck (StdLibrary : Std_library_utils.Library) : Typechecker = struct ~type_ ~loc | _ -> Semantic_error.illtyped_binary_op loc op le.emeta.type_ re.emeta.type_ + (StdLibrary.get_operator_signatures op) |> error let check_prefixop loc op te = - let rt = operator_stan_math_return_type op [arg_type te] in + let rt = operator_return_type op [arg_type te] in match rt with | Some (ReturnType type_, _) -> mk_typed_expression ~expr:(PrefixOp (op, te)) ~ad_level:(expr_ad_lub [te]) ~type_ ~loc - | _ -> Semantic_error.illtyped_prefix_op loc op te.emeta.type_ |> error + | _ -> + Semantic_error.illtyped_prefix_op loc op te.emeta.type_ + (StdLibrary.get_operator_signatures op) + |> error let check_postfixop loc op te = - let rt = operator_stan_math_return_type op [arg_type te] in + let rt = operator_return_type op [arg_type te] in match rt with | Some (ReturnType type_, _) -> mk_typed_expression ~expr:(PostfixOp (te, op)) ~ad_level:(expr_ad_lub [te]) ~type_ ~loc - | _ -> Semantic_error.illtyped_postfix_op loc op te.emeta.type_ |> error + | _ -> + Semantic_error.illtyped_postfix_op loc op te.emeta.type_ + (StdLibrary.get_operator_signatures op) + |> error let check_id cf loc tenv id = match Env.find tenv (Utils.stdlib_distribution_name id.name) with @@ -493,7 +494,7 @@ module Typecheck (StdLibrary : Std_library_utils.Library) : Typechecker = struct | {kind= `Variable _; _} :: _ (* variables can sometimes shadow stanlib functions, so we have to check this *) when not - (StdLibrary.is_stan_math_function_name + (StdLibrary.is_stdlib_function_name (Utils.normalized_name id.name) ) -> Semantic_error.returning_fn_expected_nonfn_found loc id.name |> error | [] -> @@ -941,7 +942,7 @@ module Typecheck (StdLibrary : Std_library_utils.Library) : Typechecker = struct match Env.find tenv id.name with | {kind= `Variable _; _} :: _ (* variables can shadow stanlib functions, so we have to check this *) - when not (StdLibrary.is_stan_math_function_name id.name) -> + when not (StdLibrary.is_stdlib_function_name id.name) -> Semantic_error.nonreturning_fn_expected_nonfn_found loc id.name |> error | [] -> Semantic_error.nonreturning_fn_expected_undeclaredident_found loc @@ -1015,7 +1016,7 @@ module Typecheck (StdLibrary : Std_library_utils.Library) : Typechecker = struct | Error _ -> err Operator.Equals |> error ) | OperatorAssign op -> ( let args = List.map ~f:arg_type [Ast.expr_of_lvalue lhs; rhs] in - let return_type = assignmentoperator_stan_math_return_type op args in + let return_type = assignmentoperator_return_type op args in match return_type with Some Void -> rhs | _ -> err op |> error ) let check_lvalue cf tenv = function @@ -1797,7 +1798,7 @@ module Typecheck (StdLibrary : Std_library_utils.Library) : Typechecker = struct ; comments } as ast ) = warnings := [] ; (* create a new type environment which has only stan-math functions *) - let tenv = Env.stan_math_environment in + let tenv = std_library_tenv in let tenv, typed_fb = check_toplevel_block Functions tenv fb in verify_functions_have_defn tenv typed_fb ; let tenv, typed_db = check_toplevel_block Data tenv db in diff --git a/src/frontend/Typechecker.mli b/src/frontend/Typechecker.mli index 21b75b5497..397b30b60e 100644 --- a/src/frontend/Typechecker.mli +++ b/src/frontend/Typechecker.mli @@ -15,7 +15,6 @@ open Ast - val model_name : string ref (** A reference to hold the model name. Relevant for checking variable clashes and used in code generation. *) @@ -39,16 +38,15 @@ module type Typechecker = sig into a [Result.t] *) - val operator_stan_math_return_type : + val operator_return_type : Middle.Operator.t -> (Middle.UnsizedType.autodifftype * Middle.UnsizedType.t) list -> (Middle.UnsizedType.returntype * Promotion.t list) option - val stan_math_return_type : + val library_function_return_type : string -> (Middle.UnsizedType.autodifftype * Middle.UnsizedType.t) list -> Middle.UnsizedType.returntype option - end -module Typecheck (StdLibrary : Std_library_utils.Library): Typechecker +module Make (StdLibrary : Std_library_utils.Library) : Typechecker diff --git a/src/frontend/dune b/src/frontend/dune index 0359a63c1d..f5e0020097 100644 --- a/src/frontend/dune +++ b/src/frontend/dune @@ -1,7 +1,15 @@ (library (name frontend) (public_name stanc.frontend) - (libraries core_kernel re menhirLib fmt middle common yojson) + (libraries + core_kernel + re + menhirLib + fmt + middle + common + yojson + stan_math_backend) (inline_tests) (preprocess (pps ppx_jane ppx_deriving.fold ppx_deriving.map))) diff --git a/src/middle/dune b/src/middle/dune index 65e20b7aba..220ee14c6e 100644 --- a/src/middle/dune +++ b/src/middle/dune @@ -4,9 +4,4 @@ (libraries core_kernel str fmt common re) (inline_tests) (preprocess - (pps - ppx_jane - ppx_deriving.map - ppx_deriving.fold - ppx_deriving.create - ppx_deriving.show))) + (pps ppx_jane ppx_deriving.map ppx_deriving.fold ppx_deriving.create))) diff --git a/src/middle/Stan_math_signatures.ml b/src/stan_math_backend/Stan_math_signatures.ml similarity index 99% rename from src/middle/Stan_math_signatures.ml rename to src/stan_math_backend/Stan_math_signatures.ml index 66435f7b84..0bf5de07a2 100644 --- a/src/middle/Stan_math_signatures.ml +++ b/src/stan_math_backend/Stan_math_signatures.ml @@ -1,7 +1,8 @@ (** The signatures of the Stan Math library, which are used for type checking *) -open Core_kernel +open Core_kernel open Core_kernel.Poly +open Middle (** The "dimensionality" (bad name?) is supposed to help us represent the vectorized nature of many Stan functions. It allows us to represent when @@ -377,16 +378,10 @@ let is_stan_math_function_name name = let dist_name_suffix udf_names name = let is_udf_name s = List.exists ~f:(fun (n, _) -> n = s) udf_names in - match - Utils.distribution_suffices - |> List.filter ~f:(fun sfx -> - is_stan_math_function_name (name ^ sfx) || is_udf_name (name ^ sfx) ) - |> List.hd - with - | Some hd -> hd - | None -> - Common.FatalError.fatal_error_msg - [%message "Couldn't find distribution " name] + Utils.distribution_suffices + |> List.filter ~f:(fun sfx -> + is_stan_math_function_name (name ^ sfx) || is_udf_name (name ^ sfx) ) + |> List.hd_exn let operator_to_stan_math_fns op = match op with @@ -414,12 +409,6 @@ let operator_to_stan_math_fns op = | PNot -> ["logical_negation"] | Transpose -> ["transpose"] -let int_divide_type = - UnsizedType. - ( ReturnType UInt - , [(AutoDiffable, UInt); (AutoDiffable, UInt)] - , Common.Helpers.AoS ) - let get_sigs name = let name = Utils.stdlib_distribution_name name in Hashtbl.find_multi stan_math_signatures name |> List.sort ~compare @@ -489,13 +478,20 @@ let pretty_print_all_math_distributions ppf () = (List.map ~f:(Fn.compose String.lowercase show_fkind) kinds) in pf ppf "@[%a@]" (list ~sep:cut pp_dist) distributions -let pretty_print_math_lib_operator_sigs op = - if op = Operator.IntDivide then - [Fmt.str "@[@,%a@]" Std_library_utils.pp_math_sig int_divide_type] - else - operator_to_stan_math_fns op - |> List.map - ~f:(Fn.compose Std_library_utils.pretty_print_math_sigs get_sigs) +(* let int_divide_type = + UnsizedType. + ( ReturnType UInt + , [(AutoDiffable, UInt); (AutoDiffable, UInt)] + , Common.Helpers.AoS ) *) + +(* TODO turn into a get_sigs version + let pretty_print_math_lib_operator_sigs op = + if op = Operator.IntDivide then + [Fmt.str "@[@,%a@]" Std_library_utils.pp_math_sig int_divide_type] + else + operator_to_stan_math_fns op + |> List.map + ~f:(Fn.compose Std_library_utils.pretty_print_math_sigs get_sigs) *) (* -- Some helper definitions to populate stan_math_signatures -- *) let bare_types = diff --git a/src/middle/Stan_math_signatures.mli b/src/stan_math_backend/Stan_math_signatures.mli similarity index 97% rename from src/middle/Stan_math_signatures.mli rename to src/stan_math_backend/Stan_math_signatures.mli index 8d3ffb1bc7..181333aca6 100644 --- a/src/middle/Stan_math_signatures.mli +++ b/src/stan_math_backend/Stan_math_signatures.mli @@ -4,6 +4,7 @@ *) open Core_kernel +open Middle (** Function arguments are represented by their type an autodiff type. This is [AutoDiffable] for everything except arguments @@ -39,7 +40,6 @@ val dist_name_suffix : (string * 'a) list -> string -> string val operator_to_stan_math_fns : Operator.t -> string list val string_operator_to_stan_math_fns : string -> string -val pretty_print_math_lib_operator_sigs : Operator.t -> string list val make_assignmentoperator_stan_math_signatures : Operator.t -> signature list (** Special functions for the variadic signatures exposed *) diff --git a/src/stan_math_backend/dune b/src/stan_math_backend/dune index d8dd800d2a..07581843c4 100644 --- a/src/stan_math_backend/dune +++ b/src/stan_math_backend/dune @@ -12,4 +12,4 @@ statement_gen) (inline_tests) (preprocess - (pps ppx_jane ppx_deriving.map ppx_deriving.fold))) + (pps ppx_jane ppx_deriving.map ppx_deriving.fold ppx_deriving.show))) From a913d4f457e3e3b51eac5ac6d0936a56674f8a8f Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Wed, 27 Apr 2022 14:50:11 -0400 Subject: [PATCH 03/14] Finish modularizing --- .../Partial_evaluator.ml | 2 +- src/frontend/Ast_to_Mir.ml | 1637 +++++++++-------- src/frontend/Ast_to_Mir.mli | 7 + src/frontend/Info.ml | 108 +- src/frontend/Info.mli | 8 +- src/frontend/Semantic_error.ml | 16 +- src/frontend/Semantic_error.mli | 7 +- src/frontend/SignatureMismatch.ml | 4 +- src/frontend/SignatureMismatch.mli | 6 +- src/frontend/Std_library_utils.ml | 22 +- .../{Typechecker.ml => Typechecking.ml} | 13 +- .../{Typechecker.mli => Typechecking.mli} | 0 src/stanc/stanc.ml | 32 +- 13 files changed, 972 insertions(+), 890 deletions(-) rename src/frontend/{Typechecker.ml => Typechecking.ml} (99%) rename src/frontend/{Typechecker.mli => Typechecking.mli} (100%) diff --git a/src/analysis_and_optimization/Partial_evaluator.ml b/src/analysis_and_optimization/Partial_evaluator.ml index 50929f28a4..e63e77ee98 100644 --- a/src/analysis_and_optimization/Partial_evaluator.ml +++ b/src/analysis_and_optimization/Partial_evaluator.ml @@ -100,7 +100,7 @@ let rec eval_expr ?(preserve_stability = false) (e : Expr.Typed.t) = | UserDefined _ | CompilerInternal _ -> FunApp (kind, l) | StanLib (f, suffix, mem_type) -> let get_fun_or_op_rt_opt name l' = - let module TC = Frontend.Typechecker.Make ((* TODO *) + let module TC = Frontend.Typechecking.Make ((* TODO *) Frontend.Std_library_utils.NullLibrary) in let argument_types = List.map ~f:(fun x -> Expr.Typed.(adlevel_of x, type_of x)) l' diff --git a/src/frontend/Ast_to_Mir.ml b/src/frontend/Ast_to_Mir.ml index d8db6c1e3d..26da24d737 100644 --- a/src/frontend/Ast_to_Mir.ml +++ b/src/frontend/Ast_to_Mir.ml @@ -2,816 +2,849 @@ open Core_kernel open Core_kernel.Poly open Middle -let trans_fn_kind kind name = - let fname = Utils.stdlib_distribution_name name in - match kind with - | Ast.StanLib suffix -> Fun_kind.StanLib (fname, suffix, AoS) - | UserDefined suffix -> UserDefined (fname, suffix) - -let without_underscores = String.filter ~f:(( <> ) '_') - -let drop_leading_zeros s = - match String.lfindi ~f:(fun _ c -> c <> '0') s with - | Some p when p > 0 -> ( - match s.[p] with - | 'e' | '.' -> String.drop_prefix s (p - 1) - | _ -> String.drop_prefix s p ) - | Some _ -> s - | None -> "0" - -let format_number s = s |> without_underscores |> drop_leading_zeros - -let%expect_test "format_number0" = - format_number "0_000." |> print_endline ; - [%expect "0."] - -let%expect_test "format_number1" = - format_number ".123_456" |> print_endline ; - [%expect ".123456"] - -let rec op_to_funapp op args type_ = - let loc = Ast.expr_loc_lub args in - let adlevel = Ast.expr_ad_lub args in - Expr. - { Fixed.pattern= - FunApp (StanLib (Operator.to_string op, FnPlain, AoS), trans_exprs args) - ; meta= Expr.Typed.Meta.create ~type_ ~adlevel ~loc () } - -and trans_expr {Ast.expr; Ast.emeta} = - let ewrap pattern = +module type Ast_Mir_translator = sig + val gather_data : + Ast.typed_program + -> (Expr.Typed.t SizedType.t * Expr.Typed.t Transformation.t * string) list + + val trans_prog : string -> Ast.typed_program -> Program.Typed.t +end + +module Make (StdLib : Std_library_utils.Library) = struct + let trans_fn_kind kind name = + let fname = Utils.stdlib_distribution_name name in + match kind with + | Ast.StanLib suffix -> Fun_kind.StanLib (fname, suffix, AoS) + | UserDefined suffix -> UserDefined (fname, suffix) + + let without_underscores = String.filter ~f:(( <> ) '_') + + let drop_leading_zeros s = + match String.lfindi ~f:(fun _ c -> c <> '0') s with + | Some p when p > 0 -> ( + match s.[p] with + | 'e' | '.' -> String.drop_prefix s (p - 1) + | _ -> String.drop_prefix s p ) + | Some _ -> s + | None -> "0" + + let format_number s = s |> without_underscores |> drop_leading_zeros + + let%expect_test "format_number0" = + format_number "0_000." |> print_endline ; + [%expect "0."] + + let%expect_test "format_number1" = + format_number ".123_456" |> print_endline ; + [%expect ".123456"] + + let rec op_to_funapp op args type_ = + let loc = Ast.expr_loc_lub args in + let adlevel = Ast.expr_ad_lub args in Expr. - { Fixed.pattern - ; meta= - Typed.Meta. - {type_= emeta.Ast.type_; adlevel= emeta.ad_level; loc= emeta.loc} } - in - match expr with - | Ast.Paren x -> trans_expr x - | BinOp (lhs, And, rhs) -> EAnd (trans_expr lhs, trans_expr rhs) |> ewrap - | BinOp (lhs, Or, rhs) -> EOr (trans_expr lhs, trans_expr rhs) |> ewrap - | BinOp (lhs, op, rhs) -> op_to_funapp op [lhs; rhs] emeta.type_ - | PrefixOp (op, e) | Ast.PostfixOp (e, op) -> op_to_funapp op [e] emeta.type_ - | Ast.TernaryIf (cond, ifb, elseb) -> - Expr.Fixed.Pattern.TernaryIf - (trans_expr cond, trans_expr ifb, trans_expr elseb) - |> ewrap - | Variable {name; _} -> Var name |> ewrap - | IntNumeral x -> Lit (Int, format_number x) |> ewrap - | RealNumeral x -> Lit (Real, format_number x) |> ewrap - | ImagNumeral x -> Lit (Imaginary, format_number x) |> ewrap - | FunApp (fn_kind, {name; _}, args) | CondDistApp (fn_kind, {name; _}, args) - -> - FunApp (trans_fn_kind fn_kind name, trans_exprs args) |> ewrap - | GetLP | GetTarget -> FunApp (StanLib ("target", FnTarget, AoS), []) |> ewrap - | ArrayExpr eles -> - FunApp (CompilerInternal FnMakeArray, trans_exprs eles) |> ewrap - | RowVectorExpr eles -> - FunApp (CompilerInternal FnMakeRowVec, trans_exprs eles) |> ewrap - | Indexed (lhs, indices) -> - Indexed (trans_expr lhs, List.map ~f:trans_idx indices) |> ewrap - | Promotion (e, ty, ad) -> Promotion (trans_expr e, ty, ad) |> ewrap - -and trans_idx = function - | Ast.All -> All - | Ast.Upfrom e -> Upfrom (trans_expr e) - | Ast.Downfrom e -> Between (Expr.Helpers.loop_bottom, trans_expr e) - | Ast.Between (lb, ub) -> Between (trans_expr lb, trans_expr ub) - | Ast.Single e -> ( - match e.emeta.type_ with - | UInt -> Single (trans_expr e) - | UArray _ -> MultiIndex (trans_expr e) - | _ -> - Common.FatalError.fatal_error_msg - [%message "Expecting int or array" (e.emeta.type_ : UnsizedType.t)] ) - -and trans_exprs exprs = List.map ~f:trans_expr exprs - -let trans_sizedtype = SizedType.map trans_expr - -let neg_inf = - Expr. - { Fixed.pattern= FunApp (CompilerInternal FnNegInf, []) - ; meta= - Typed.Meta.{type_= UReal; loc= Location_span.empty; adlevel= DataOnly} - } - -let trans_arg (adtype, ut, ident) = (adtype, ident.Ast.name, ut) - -let truncate_dist ud_dists (id : Ast.identifier) ast_obs ast_args t = - let cdf_suffices = ["_lcdf"; "_cdf_log"] in - let ccdf_suffices = ["_lccdf"; "_ccdf_log"] in - let find_function_info sfx = - let possible_names = List.map ~f:(( ^ ) id.name) sfx |> String.Set.of_list in - match List.find ~f:(fun (n, _) -> Set.mem possible_names n) ud_dists with - | Some (name, tp) -> (Ast.UserDefined FnPlain, name, tp) - | None -> - ( Ast.StanLib FnPlain - , Set.to_list possible_names |> List.hd_exn - , if Stan_math_signatures.is_stan_math_function_name (id.name ^ "_lpmf") - then UnsizedType.UInt - else UnsizedType.UReal (* close enough *) ) in - let trunc cond_op (x : Ast.typed_expression) y = - let smeta = x.Ast.emeta.loc in - { Stmt.Fixed.meta= smeta - ; pattern= - IfElse - ( op_to_funapp cond_op [ast_obs; x] UInt - , {Stmt.Fixed.meta= smeta; pattern= TargetPE neg_inf} - , Some y ) } in - let targetme loc e = - { Stmt.Fixed.meta= loc - ; pattern= TargetPE (op_to_funapp Operator.PMinus [e] e.emeta.type_) } in - let funapp meta kind name args = - { Ast.emeta= meta - ; expr= Ast.FunApp (kind, {name; id_loc= Location_span.empty}, args) } in - let inclusive_bound tp (lb : Ast.typed_expression) = - let emeta = lb.emeta in - if UnsizedType.is_int_type tp then - Ast. - { emeta - ; expr= BinOp (lb, Operator.Minus, {emeta; expr= Ast.IntNumeral "1"}) } - else lb in - match t with - | Ast.NoTruncate -> [] - | TruncateUpFrom lb -> - let fk, fn, tp = find_function_info ccdf_suffices in - [ trunc Less lb - (targetme lb.emeta.loc - (funapp lb.emeta fk fn (inclusive_bound tp lb :: ast_args)) ) ] - | TruncateDownFrom ub -> - let fk, fn, _ = find_function_info cdf_suffices in - [ trunc Greater ub - (targetme ub.emeta.loc (funapp ub.emeta fk fn (ub :: ast_args))) ] - | TruncateBetween (lb, ub) -> - let fk, fn, tp = find_function_info cdf_suffices in - [ trunc Less lb - (trunc Greater ub - (targetme ub.emeta.loc - (funapp ub.emeta (Ast.StanLib FnPlain) "log_diff_exp" - [ funapp ub.emeta fk fn (ub :: ast_args) - ; funapp ub.emeta fk fn (inclusive_bound tp lb :: ast_args) - ] ) ) ) ] - -let unquote s = - if s.[0] = '"' && s.[String.length s - 1] = '"' then - String.drop_suffix (String.drop_prefix s 1) 1 - else s - -let trans_printables mloc (ps : Ast.typed_expression Ast.printable list) = - List.map - ~f:(function - | Ast.PString s -> - { (Expr.Helpers.str (unquote s)) with - meta= - Expr.Typed.Meta.create ~type_:UReal ~loc:mloc ~adlevel:DataOnly () - } - | Ast.PExpr e -> trans_expr e ) - ps + { Fixed.pattern= + FunApp + (StanLib (Operator.to_string op, FnPlain, AoS), trans_exprs args) + ; meta= Expr.Typed.Meta.create ~type_ ~adlevel ~loc () } + + and trans_expr {Ast.expr; Ast.emeta} = + let ewrap pattern = + Expr. + { Fixed.pattern + ; meta= + Typed.Meta. + {type_= emeta.Ast.type_; adlevel= emeta.ad_level; loc= emeta.loc} + } in + match expr with + | Ast.Paren x -> trans_expr x + | BinOp (lhs, And, rhs) -> EAnd (trans_expr lhs, trans_expr rhs) |> ewrap + | BinOp (lhs, Or, rhs) -> EOr (trans_expr lhs, trans_expr rhs) |> ewrap + | BinOp (lhs, op, rhs) -> op_to_funapp op [lhs; rhs] emeta.type_ + | PrefixOp (op, e) | Ast.PostfixOp (e, op) -> + op_to_funapp op [e] emeta.type_ + | Ast.TernaryIf (cond, ifb, elseb) -> + Expr.Fixed.Pattern.TernaryIf + (trans_expr cond, trans_expr ifb, trans_expr elseb) + |> ewrap + | Variable {name; _} -> Var name |> ewrap + | IntNumeral x -> Lit (Int, format_number x) |> ewrap + | RealNumeral x -> Lit (Real, format_number x) |> ewrap + | ImagNumeral x -> Lit (Imaginary, format_number x) |> ewrap + | FunApp (fn_kind, {name; _}, args) | CondDistApp (fn_kind, {name; _}, args) + -> + FunApp (trans_fn_kind fn_kind name, trans_exprs args) |> ewrap + | GetLP | GetTarget -> + FunApp (StanLib ("target", FnTarget, AoS), []) |> ewrap + | ArrayExpr eles -> + FunApp (CompilerInternal FnMakeArray, trans_exprs eles) |> ewrap + | RowVectorExpr eles -> + FunApp (CompilerInternal FnMakeRowVec, trans_exprs eles) |> ewrap + | Indexed (lhs, indices) -> + Indexed (trans_expr lhs, List.map ~f:trans_idx indices) |> ewrap + | Promotion (e, ty, ad) -> Promotion (trans_expr e, ty, ad) |> ewrap + + and trans_idx = function + | Ast.All -> All + | Ast.Upfrom e -> Upfrom (trans_expr e) + | Ast.Downfrom e -> Between (Expr.Helpers.loop_bottom, trans_expr e) + | Ast.Between (lb, ub) -> Between (trans_expr lb, trans_expr ub) + | Ast.Single e -> ( + match e.emeta.type_ with + | UInt -> Single (trans_expr e) + | UArray _ -> MultiIndex (trans_expr e) + | _ -> + Common.FatalError.fatal_error_msg + [%message "Expecting int or array" (e.emeta.type_ : UnsizedType.t)] + ) -(** These types signal the context for a declaration during statement translation. - They are only interpreted by trans_decl.*) -type transform_action = Check | Constrain | Unconstrain | IgnoreTransform -[@@deriving sexp] + and trans_exprs exprs = List.map ~f:trans_expr exprs -type decl_context = - {transform_action: transform_action; dadlevel: UnsizedType.autodifftype} + let trans_sizedtype = SizedType.map trans_expr -let same_shape decl_id decl_var id var meta = - if UnsizedType.is_scalar_type (Expr.Typed.type_of var) then [] - else - [ Stmt. + let neg_inf = + Expr. + { Fixed.pattern= FunApp (CompilerInternal FnNegInf, []) + ; meta= + Typed.Meta.{type_= UReal; loc= Location_span.empty; adlevel= DataOnly} + } + + let trans_arg (adtype, ut, ident) = (adtype, ident.Ast.name, ut) + + let truncate_dist ud_dists (id : Ast.identifier) ast_obs ast_args t = + let cdf_suffices = ["_lcdf"; "_cdf_log"] in + let ccdf_suffices = ["_lccdf"; "_ccdf_log"] in + let find_function_info sfx = + let possible_names = + List.map ~f:(( ^ ) id.name) sfx |> String.Set.of_list in + match List.find ~f:(fun (n, _) -> Set.mem possible_names n) ud_dists with + | Some (name, tp) -> (Ast.UserDefined FnPlain, name, tp) + | None -> + ( Ast.StanLib FnPlain + , Set.to_list possible_names |> List.hd_exn + , if StdLib.is_stdlib_function_name (id.name ^ "_lpmf") then + UnsizedType.UInt + else UnsizedType.UReal (* close enough *) ) in + let trunc cond_op (x : Ast.typed_expression) y = + let smeta = x.Ast.emeta.loc in + { Stmt.Fixed.meta= smeta + ; pattern= + IfElse + ( op_to_funapp cond_op [ast_obs; x] UInt + , {Stmt.Fixed.meta= smeta; pattern= TargetPE neg_inf} + , Some y ) } in + let targetme loc e = + { Stmt.Fixed.meta= loc + ; pattern= TargetPE (op_to_funapp Operator.PMinus [e] e.emeta.type_) } + in + let funapp meta kind name args = + { Ast.emeta= meta + ; expr= Ast.FunApp (kind, {name; id_loc= Location_span.empty}, args) } + in + let inclusive_bound tp (lb : Ast.typed_expression) = + let emeta = lb.emeta in + if UnsizedType.is_int_type tp then + Ast. + { emeta + ; expr= BinOp (lb, Operator.Minus, {emeta; expr= Ast.IntNumeral "1"}) + } + else lb in + match t with + | Ast.NoTruncate -> [] + | TruncateUpFrom lb -> + let fk, fn, tp = find_function_info ccdf_suffices in + [ trunc Less lb + (targetme lb.emeta.loc + (funapp lb.emeta fk fn (inclusive_bound tp lb :: ast_args)) ) ] + | TruncateDownFrom ub -> + let fk, fn, _ = find_function_info cdf_suffices in + [ trunc Greater ub + (targetme ub.emeta.loc (funapp ub.emeta fk fn (ub :: ast_args))) ] + | TruncateBetween (lb, ub) -> + let fk, fn, tp = find_function_info cdf_suffices in + [ trunc Less lb + (trunc Greater ub + (targetme ub.emeta.loc + (funapp ub.emeta (Ast.StanLib FnPlain) "log_diff_exp" + [ funapp ub.emeta fk fn (ub :: ast_args) + ; funapp ub.emeta fk fn (inclusive_bound tp lb :: ast_args) + ] ) ) ) ] + + let unquote s = + if s.[0] = '"' && s.[String.length s - 1] = '"' then + String.drop_suffix (String.drop_prefix s 1) 1 + else s + + let trans_printables mloc (ps : Ast.typed_expression Ast.printable list) = + List.map + ~f:(function + | Ast.PString s -> + { (Expr.Helpers.str (unquote s)) with + meta= + Expr.Typed.Meta.create ~type_:UReal ~loc:mloc ~adlevel:DataOnly + () } + | Ast.PExpr e -> trans_expr e ) + ps + + (** These types signal the context for a declaration during statement translation. + They are only interpreted by trans_decl.*) + type transform_action = Check | Constrain | Unconstrain | IgnoreTransform + [@@deriving sexp] + + type decl_context = + {transform_action: transform_action; dadlevel: UnsizedType.autodifftype} + + let same_shape decl_id decl_var id var meta = + if UnsizedType.is_scalar_type (Expr.Typed.type_of var) then [] + else + [ Stmt. + { Fixed.pattern= + NRFunApp + ( StanLib ("check_matching_dims", FnPlain, AoS) + , Expr.Helpers. + [str "constraint"; str decl_id; decl_var; str id; var] ) + ; meta } ] + + let check_transform_shape decl_id decl_var meta = function + | Transformation.Offset e -> same_shape decl_id decl_var "offset" e meta + | Multiplier e -> same_shape decl_id decl_var "multiplier" e meta + | Lower e -> same_shape decl_id decl_var "lower" e meta + | Upper e -> same_shape decl_id decl_var "upper" e meta + | OffsetMultiplier (e1, e2) -> + same_shape decl_id decl_var "offset" e1 meta + @ same_shape decl_id decl_var "multiplier" e2 meta + | LowerUpper (e1, e2) -> + same_shape decl_id decl_var "lower" e1 meta + @ same_shape decl_id decl_var "upper" e2 meta + | Covariance | Correlation | CholeskyCov | CholeskyCorr | Ordered + |PositiveOrdered | Simplex | UnitVector | Identity -> + [] + + let copy_indices indexed (var : Expr.Typed.t) = + if UnsizedType.is_scalar_type var.meta.type_ then var + else + match Expr.Helpers.collect_indices indexed with + | [] -> var + | indices -> + Expr.Fixed. + { pattern= Indexed (var, indices) + ; meta= + { var.meta with + type_= + Expr.Helpers.infer_type_of_indexed var.meta.type_ indices } + } + + let extract_transform_args var = function + | Transformation.Lower a | Upper a -> [copy_indices var a] + | Offset a -> + [copy_indices var a; {a with Expr.Fixed.pattern= Lit (Int, "1")}] + | Multiplier a -> [{a with pattern= Lit (Int, "0")}; copy_indices var a] + | LowerUpper (a1, a2) | OffsetMultiplier (a1, a2) -> + [copy_indices var a1; copy_indices var a2] + | Covariance | Correlation | CholeskyCov | CholeskyCorr | Ordered + |PositiveOrdered | Simplex | UnitVector | Identity -> + [] + + let param_size transform sizedtype = + let rec shrink_eigen f st = + match st with + | SizedType.SArray (t, d) -> SizedType.SArray (shrink_eigen f t, d) + | SVector (mem_pattern, d) | SMatrix (mem_pattern, d, _) -> + SVector (mem_pattern, f d) + | SInt | SReal | SComplex | SRowVector _ | SComplexRowVector _ + |SComplexVector _ | SComplexMatrix _ -> + Common.FatalError.fatal_error_msg + [%message + "Expecting SVector or SMatrix, got " + (st : Expr.Typed.t SizedType.t)] in + let rec shrink_eigen_mat f st = + match st with + | SizedType.SArray (t, d) -> SizedType.SArray (shrink_eigen_mat f t, d) + | SMatrix (mem_pattern, d1, d2) -> SVector (mem_pattern, f d1 d2) + | SInt | SReal | SComplex | SRowVector _ | SVector _ + |SComplexRowVector _ | SComplexVector _ | SComplexMatrix _ -> + Common.FatalError.fatal_error_msg + [%message "Expecting SMatrix, got " (st : Expr.Typed.t SizedType.t)] + in + let k_choose_2 k = + Expr.Helpers.( + binop (binop k Times (binop k Minus (int 1))) Divide (int 2)) in + match transform with + | Transformation.Identity | Lower _ | Upper _ + |LowerUpper (_, _) + |Offset _ | Multiplier _ + |OffsetMultiplier (_, _) + |Ordered | PositiveOrdered | UnitVector -> + sizedtype + | Simplex -> + shrink_eigen (fun d -> Expr.Helpers.(binop d Minus (int 1))) sizedtype + | CholeskyCorr | Correlation -> shrink_eigen k_choose_2 sizedtype + | CholeskyCov -> + (* (N * (N + 1)) / 2 + (M - N) * N *) + shrink_eigen_mat + (fun m n -> + Expr.Helpers.( + binop + (binop (k_choose_2 n) Plus n) + Plus + (binop (binop m Minus n) Times n)) ) + sizedtype + | Covariance -> + shrink_eigen + (fun k -> Expr.Helpers.(binop k Plus (k_choose_2 k))) + sizedtype + + let rec check_decl var decl_type' decl_id decl_trans smeta adlevel = + match decl_trans with + | Transformation.LowerUpper (lb, ub) -> + check_decl var decl_type' decl_id (Lower lb) smeta adlevel + @ check_decl var decl_type' decl_id (Upper ub) smeta adlevel + | _ when Transformation.has_check decl_trans -> + let check_id id = + let var_name = Fmt.str "%a" Expr.Typed.pp id in + let args = extract_transform_args id decl_trans in + Stmt.Helpers.internal_nrfunapp + (FnCheck {trans= decl_trans; var_name; var= id}) + args smeta in + [check_id var] + | _ -> [] + + let check_sizedtype name = + let check x = function + | {Expr.Fixed.pattern= Lit (Int, i); _} when float_of_string i >= 0. -> [] + | n -> + [ Stmt.Helpers.internal_nrfunapp FnValidateSize + Expr.Helpers. + [ str name + ; str (Fmt.str "%a" Pretty_printing.pp_typed_expression x); n ] + n.meta.loc ] in + let rec sizedtype = function + | SizedType.(SInt | SReal | SComplex) as t -> ([], t) + | SVector (mem_pattern, s) -> + let e = trans_expr s in + (check s e, SizedType.SVector (mem_pattern, e)) + | SRowVector (mem_pattern, s) -> + let e = trans_expr s in + (check s e, SizedType.SRowVector (mem_pattern, e)) + | SMatrix (mem_pattern, r, c) -> + let er = trans_expr r in + let ec = trans_expr c in + (check r er @ check c ec, SizedType.SMatrix (mem_pattern, er, ec)) + | SComplexVector s -> + let e = trans_expr s in + (check s e, SizedType.SComplexVector e) + | SComplexRowVector s -> + let e = trans_expr s in + (check s e, SizedType.SComplexRowVector e) + | SComplexMatrix (r, c) -> + let er = trans_expr r in + let ec = trans_expr c in + (check r er @ check c ec, SizedType.SComplexMatrix (er, ec)) + | SArray (t, s) -> + let e = trans_expr s in + let ll, t = sizedtype t in + (check s e @ ll, SizedType.SArray (t, e)) in + function + | Type.Sized st -> + let ll, st = sizedtype st in + (ll, Type.Sized st) + | Unsized ut -> ([], Unsized ut) + + let trans_decl {transform_action; dadlevel} smeta decl_type transform + identifier initial_value = + let decl_id = identifier.Ast.name in + let rhs = Option.map ~f:trans_expr initial_value in + let size_checks, dt = check_sizedtype identifier.name decl_type in + let decl_adtype = dadlevel in + let decl_var = + Expr. + { Fixed.pattern= Var decl_id + ; meta= + Typed.Meta.create ~adlevel:dadlevel ~loc:smeta + ~type_:(Type.to_unsized decl_type) + () } in + let decl = + Stmt. { Fixed.pattern= - NRFunApp - ( StanLib ("check_matching_dims", FnPlain, AoS) - , Expr.Helpers. - [str "constraint"; str decl_id; decl_var; str id; var] ) - ; meta } ] - -let check_transform_shape decl_id decl_var meta = function - | Transformation.Offset e -> same_shape decl_id decl_var "offset" e meta - | Multiplier e -> same_shape decl_id decl_var "multiplier" e meta - | Lower e -> same_shape decl_id decl_var "lower" e meta - | Upper e -> same_shape decl_id decl_var "upper" e meta - | OffsetMultiplier (e1, e2) -> - same_shape decl_id decl_var "offset" e1 meta - @ same_shape decl_id decl_var "multiplier" e2 meta - | LowerUpper (e1, e2) -> - same_shape decl_id decl_var "lower" e1 meta - @ same_shape decl_id decl_var "upper" e2 meta - | Covariance | Correlation | CholeskyCov | CholeskyCorr | Ordered - |PositiveOrdered | Simplex | UnitVector | Identity -> - [] - -let copy_indices indexed (var : Expr.Typed.t) = - if UnsizedType.is_scalar_type var.meta.type_ then var - else - match Expr.Helpers.collect_indices indexed with - | [] -> var - | indices -> - Expr.Fixed. - { pattern= Indexed (var, indices) - ; meta= - { var.meta with - type_= Expr.Helpers.infer_type_of_indexed var.meta.type_ indices - } } - -let extract_transform_args var = function - | Transformation.Lower a | Upper a -> [copy_indices var a] - | Offset a -> [copy_indices var a; {a with Expr.Fixed.pattern= Lit (Int, "1")}] - | Multiplier a -> [{a with pattern= Lit (Int, "0")}; copy_indices var a] - | LowerUpper (a1, a2) | OffsetMultiplier (a1, a2) -> - [copy_indices var a1; copy_indices var a2] - | Covariance | Correlation | CholeskyCov | CholeskyCorr | Ordered - |PositiveOrdered | Simplex | UnitVector | Identity -> - [] - -let param_size transform sizedtype = - let rec shrink_eigen f st = - match st with - | SizedType.SArray (t, d) -> SizedType.SArray (shrink_eigen f t, d) - | SVector (mem_pattern, d) | SMatrix (mem_pattern, d, _) -> - SVector (mem_pattern, f d) - | SInt | SReal | SComplex | SRowVector _ | SComplexRowVector _ - |SComplexVector _ | SComplexMatrix _ -> + Decl {decl_adtype; decl_id; decl_type= dt; initialize= true} + ; meta= smeta } in + let rhs_assignment = + Option.map + ~f:(fun e -> + Stmt.Fixed. + {pattern= Assignment ((decl_id, e.meta.type_, []), e); meta= smeta} + ) + rhs + |> Option.to_list in + if Utils.is_user_ident decl_id then + let constrain_checks = + match transform_action with + | Constrain | Unconstrain -> + Common.FatalError.fatal_error_msg + [%message "Constraints must use trans_sizedtype_decl instead"] + | Check -> + check_transform_shape decl_id decl_var smeta transform + @ check_decl decl_var dt decl_id transform smeta dadlevel + | IgnoreTransform -> [] in + size_checks @ (decl :: rhs_assignment) @ constrain_checks + else size_checks @ (decl :: rhs_assignment) + + let unwrap_block_or_skip = function + | [({Stmt.Fixed.pattern= Block _; _} as b)] -> Some b + | [{pattern= Skip; _}] -> None + | x -> + Common.FatalError.fatal_error_msg + [%message "Expecting a block or skip, not" (x : Stmt.Located.t list)] + + let rec trans_stmt ud_dists (declc : decl_context) (ts : Ast.typed_statement) + = + let stmt_typed = ts.stmt and smeta = ts.smeta.loc in + let trans_stmt = + trans_stmt ud_dists {declc with transform_action= IgnoreTransform} in + let trans_single_stmt s = + match trans_stmt s with + | [s] -> s + | s -> Stmt.Fixed.{pattern= SList s; meta= smeta} in + let swrap pattern = [Stmt.Fixed.{meta= smeta; pattern}] in + let mloc = smeta in + match stmt_typed with + | Ast.Assignment {assign_lhs; assign_rhs; assign_op} -> + let rec get_lhs_base = function + | {Ast.lval= Ast.LIndexed (l, _); _} -> get_lhs_base l + | {lval= LVariable s; lmeta} -> (s, lmeta) in + let assign_identifier, lmeta = get_lhs_base assign_lhs in + let id_ad_level = lmeta.Ast.ad_level in + let id_type_ = lmeta.Ast.type_ in + let lhs_type_ = assign_lhs.Ast.lmeta.type_ in + let lhs_ad_level = assign_lhs.Ast.lmeta.ad_level in + let rec get_lhs_indices = function + | {Ast.lval= Ast.LIndexed (l, i); _} -> get_lhs_indices l @ i + | {Ast.lval= Ast.LVariable _; _} -> [] in + let assign_indices = get_lhs_indices assign_lhs in + let assignee = + { Ast.expr= + ( match assign_indices with + | [] -> Ast.Variable assign_identifier + | _ -> + Ast.Indexed + ( { expr= Ast.Variable assign_identifier + ; emeta= + { Ast.loc= Location_span.empty + ; ad_level= id_ad_level + ; type_= id_type_ } } + , assign_indices ) ) + ; emeta= + { Ast.loc= assign_lhs.lmeta.loc + ; ad_level= lhs_ad_level + ; type_= lhs_type_ } } in + let rhs = + match assign_op with + | Ast.Assign | Ast.ArrowAssign -> trans_expr assign_rhs + | Ast.OperatorAssign op -> + op_to_funapp op [assignee; assign_rhs] assignee.emeta.type_ in + Assignment + ( ( assign_identifier.Ast.name + , id_type_ + , List.map ~f:trans_idx assign_indices ) + , rhs ) + |> swrap + | Ast.NRFunApp (fn_kind, {name; _}, args) -> + NRFunApp (trans_fn_kind fn_kind name, trans_exprs args) |> swrap + | Ast.IncrementLogProb e | Ast.TargetPE e -> + TargetPE (trans_expr e) |> swrap + | Ast.Tilde {arg; distribution; args; truncation} -> + let suffix = + Std_library_utils.dist_name_suffix + (module StdLib) + ud_dists distribution.name in + let name = distribution.name ^ suffix in + let kind = + let possible_names = + List.map ~f:(( ^ ) distribution.name) Utils.distribution_suffices + |> String.Set.of_list in + if List.exists ~f:(fun (n, _) -> Set.mem possible_names n) ud_dists + then Fun_kind.UserDefined (name, FnLpdf true) + else StanLib (name, FnLpdf true, AoS) in + let add_dist = + Stmt.Fixed.Pattern.TargetPE + Expr. + { Fixed.pattern= FunApp (kind, trans_exprs (arg :: args)) + ; meta= + Typed.Meta.create ~type_:UReal ~loc:mloc + ~adlevel:(Ast.expr_ad_lub (arg :: args)) + () } in + truncate_dist ud_dists distribution arg args truncation @ swrap add_dist + | Ast.Print ps -> + NRFunApp (CompilerInternal FnPrint, trans_printables smeta ps) |> swrap + | Ast.Reject ps -> + NRFunApp (CompilerInternal FnReject, trans_printables smeta ps) |> swrap + | Ast.IfThenElse (cond, ifb, elseb) -> + IfElse + ( trans_expr cond + , trans_single_stmt ifb + , Option.map ~f:trans_single_stmt elseb ) + |> swrap + | Ast.While (cond, body) -> + While (trans_expr cond, trans_single_stmt body) |> swrap + | Ast.For {loop_variable; lower_bound; upper_bound; loop_body} -> + let body = + match trans_single_stmt loop_body with + | {pattern= Block _; _} as b -> b + | x -> {x with pattern= Block [x]} in + For + { loopvar= loop_variable.Ast.name + ; lower= trans_expr lower_bound + ; upper= trans_expr upper_bound + ; body } + |> swrap + | Ast.ForEach (loopvar, iteratee, body) -> + let iteratee' = trans_expr iteratee in + let body_stmts = + match trans_single_stmt body with + | {pattern= Block body_stmts; _} -> body_stmts + | b -> [b] in + let decl_type = + match Expr.Typed.type_of iteratee' with + | UMatrix -> UnsizedType.UReal + | t -> + Expr.Helpers.(infer_type_of_indexed t [Index.Single loop_bottom]) + in + let decl_loopvar = + Stmt.Fixed. + { meta= smeta + ; pattern= + Decl + { decl_adtype= Expr.Typed.adlevel_of iteratee' + ; decl_id= loopvar.name + ; decl_type= Unsized decl_type + ; initialize= true } } in + let assignment var = + Stmt.Fixed. + { pattern= Assignment ((loopvar.name, decl_type, []), var) + ; meta= smeta } in + let bodyfn var = + Stmt.Fixed. + { pattern= Block (decl_loopvar :: assignment var :: body_stmts) + ; meta= smeta } in + Stmt.Helpers.[ensure_var (for_each bodyfn) iteratee' smeta] + | Ast.FunDef _ -> Common.FatalError.fatal_error_msg [%message - "Expecting SVector or SMatrix, got " (st : Expr.Typed.t SizedType.t)] - in - let rec shrink_eigen_mat f st = - match st with - | SizedType.SArray (t, d) -> SizedType.SArray (shrink_eigen_mat f t, d) - | SMatrix (mem_pattern, d1, d2) -> SVector (mem_pattern, f d1 d2) - | SInt | SReal | SComplex | SRowVector _ | SVector _ | SComplexRowVector _ - |SComplexVector _ | SComplexMatrix _ -> + "Found function definition statement outside of function block"] + | Ast.VarDecl + {decl_type; transformation; identifier; initial_value; is_global= _} -> + trans_decl declc smeta decl_type + (Transformation.map trans_expr transformation) + identifier initial_value + | Ast.Block stmts -> Block (List.concat_map ~f:trans_stmt stmts) |> swrap + | Ast.Profile (name, stmts) -> + Profile (name, List.concat_map ~f:trans_stmt stmts) |> swrap + | Ast.Return e -> Return (Some (trans_expr e)) |> swrap + | Ast.ReturnVoid -> Return None |> swrap + | Ast.Break -> Break |> swrap + | Ast.Continue -> Continue |> swrap + | Ast.Skip -> Skip |> swrap + + let trans_fun_def ud_dists (ts : Ast.typed_statement) = + match ts.stmt with + | Ast.FunDef {returntype; funname; arguments; body} -> + [ Program. + { fdrt= + (match returntype with Void -> None | ReturnType ut -> Some ut) + ; fdname= funname.name + ; fdsuffix= + Fun_kind.(suffix_from_name funname.name |> without_propto) + ; fdargs= List.map ~f:trans_arg arguments + ; fdbody= + trans_stmt ud_dists + {transform_action= IgnoreTransform; dadlevel= AutoDiffable} + body + |> unwrap_block_or_skip + ; fdloc= ts.smeta.loc } ] + | _ -> Common.FatalError.fatal_error_msg - [%message "Expecting SMatrix, got " (st : Expr.Typed.t SizedType.t)] - in - let k_choose_2 k = - Expr.Helpers.(binop (binop k Times (binop k Minus (int 1))) Divide (int 2)) - in - match transform with - | Transformation.Identity | Lower _ | Upper _ - |LowerUpper (_, _) - |Offset _ | Multiplier _ - |OffsetMultiplier (_, _) - |Ordered | PositiveOrdered | UnitVector -> - sizedtype - | Simplex -> - shrink_eigen (fun d -> Expr.Helpers.(binop d Minus (int 1))) sizedtype - | CholeskyCorr | Correlation -> shrink_eigen k_choose_2 sizedtype - | CholeskyCov -> - (* (N * (N + 1)) / 2 + (M - N) * N *) - shrink_eigen_mat - (fun m n -> - Expr.Helpers.( - binop - (binop (k_choose_2 n) Plus n) - Plus - (binop (binop m Minus n) Times n)) ) - sizedtype - | Covariance -> - shrink_eigen - (fun k -> Expr.Helpers.(binop k Plus (k_choose_2 k))) - sizedtype - -let rec check_decl var decl_type' decl_id decl_trans smeta adlevel = - match decl_trans with - | Transformation.LowerUpper (lb, ub) -> - check_decl var decl_type' decl_id (Lower lb) smeta adlevel - @ check_decl var decl_type' decl_id (Upper ub) smeta adlevel - | _ when Transformation.has_check decl_trans -> - let check_id id = - let var_name = Fmt.str "%a" Expr.Typed.pp id in - let args = extract_transform_args id decl_trans in - Stmt.Helpers.internal_nrfunapp - (FnCheck {trans= decl_trans; var_name; var= id}) - args smeta in - [check_id var] - | _ -> [] - -let check_sizedtype name = - let check x = function - | {Expr.Fixed.pattern= Lit (Int, i); _} when float_of_string i >= 0. -> [] - | n -> - [ Stmt.Helpers.internal_nrfunapp FnValidateSize - Expr.Helpers. - [ str name - ; str (Fmt.str "%a" Pretty_printing.pp_typed_expression x); n ] - n.meta.loc ] in - let rec sizedtype = function - | SizedType.(SInt | SReal | SComplex) as t -> ([], t) - | SVector (mem_pattern, s) -> - let e = trans_expr s in - (check s e, SizedType.SVector (mem_pattern, e)) - | SRowVector (mem_pattern, s) -> - let e = trans_expr s in - (check s e, SizedType.SRowVector (mem_pattern, e)) - | SMatrix (mem_pattern, r, c) -> - let er = trans_expr r in - let ec = trans_expr c in - (check r er @ check c ec, SizedType.SMatrix (mem_pattern, er, ec)) - | SComplexVector s -> - let e = trans_expr s in - (check s e, SizedType.SComplexVector e) - | SComplexRowVector s -> - let e = trans_expr s in - (check s e, SizedType.SComplexRowVector e) - | SComplexMatrix (r, c) -> - let er = trans_expr r in - let ec = trans_expr c in - (check r er @ check c ec, SizedType.SComplexMatrix (er, ec)) - | SArray (t, s) -> - let e = trans_expr s in - let ll, t = sizedtype t in - (check s e @ ll, SizedType.SArray (t, e)) in - function - | Type.Sized st -> - let ll, st = sizedtype st in - (ll, Type.Sized st) - | Unsized ut -> ([], Unsized ut) - -let trans_decl {transform_action; dadlevel} smeta decl_type transform identifier - initial_value = - let decl_id = identifier.Ast.name in - let rhs = Option.map ~f:trans_expr initial_value in - let size_checks, dt = check_sizedtype identifier.name decl_type in - let decl_adtype = dadlevel in - let decl_var = - Expr. - { Fixed.pattern= Var decl_id - ; meta= - Typed.Meta.create ~adlevel:dadlevel ~loc:smeta - ~type_:(Type.to_unsized decl_type) - () } in - let decl = - Stmt. - { Fixed.pattern= - Decl {decl_adtype; decl_id; decl_type= dt; initialize= true} - ; meta= smeta } in - let rhs_assignment = - Option.map - ~f:(fun e -> - Stmt.Fixed. - {pattern= Assignment ((decl_id, e.meta.type_, []), e); meta= smeta} ) - rhs - |> Option.to_list in - if Utils.is_user_ident decl_id then - let constrain_checks = - match transform_action with - | Constrain | Unconstrain -> - Common.FatalError.fatal_error_msg - [%message "Constraints must use trans_sizedtype_decl instead"] - | Check -> - check_transform_shape decl_id decl_var smeta transform - @ check_decl decl_var dt decl_id transform smeta dadlevel - | IgnoreTransform -> [] in - size_checks @ (decl :: rhs_assignment) @ constrain_checks - else size_checks @ (decl :: rhs_assignment) - -let unwrap_block_or_skip = function - | [({Stmt.Fixed.pattern= Block _; _} as b)] -> Some b - | [{pattern= Skip; _}] -> None - | x -> - Common.FatalError.fatal_error_msg - [%message "Expecting a block or skip, not" (x : Stmt.Located.t list)] - -let rec trans_stmt ud_dists (declc : decl_context) (ts : Ast.typed_statement) = - let stmt_typed = ts.stmt and smeta = ts.smeta.loc in - let trans_stmt = - trans_stmt ud_dists {declc with transform_action= IgnoreTransform} in - let trans_single_stmt s = - match trans_stmt s with - | [s] -> s - | s -> Stmt.Fixed.{pattern= SList s; meta= smeta} in - let swrap pattern = [Stmt.Fixed.{meta= smeta; pattern}] in - let mloc = smeta in - match stmt_typed with - | Ast.Assignment {assign_lhs; assign_rhs; assign_op} -> - let rec get_lhs_base = function - | {Ast.lval= Ast.LIndexed (l, _); _} -> get_lhs_base l - | {lval= LVariable s; lmeta} -> (s, lmeta) in - let assign_identifier, lmeta = get_lhs_base assign_lhs in - let id_ad_level = lmeta.Ast.ad_level in - let id_type_ = lmeta.Ast.type_ in - let lhs_type_ = assign_lhs.Ast.lmeta.type_ in - let lhs_ad_level = assign_lhs.Ast.lmeta.ad_level in - let rec get_lhs_indices = function - | {Ast.lval= Ast.LIndexed (l, i); _} -> get_lhs_indices l @ i - | {Ast.lval= Ast.LVariable _; _} -> [] in - let assign_indices = get_lhs_indices assign_lhs in - let assignee = - { Ast.expr= - ( match assign_indices with - | [] -> Ast.Variable assign_identifier - | _ -> - Ast.Indexed - ( { expr= Ast.Variable assign_identifier - ; emeta= - { Ast.loc= Location_span.empty - ; ad_level= id_ad_level - ; type_= id_type_ } } - , assign_indices ) ) - ; emeta= - { Ast.loc= assign_lhs.lmeta.loc - ; ad_level= lhs_ad_level - ; type_= lhs_type_ } } in - let rhs = - match assign_op with - | Ast.Assign | Ast.ArrowAssign -> trans_expr assign_rhs - | Ast.OperatorAssign op -> - op_to_funapp op [assignee; assign_rhs] assignee.emeta.type_ in - Assignment - ( ( assign_identifier.Ast.name - , id_type_ - , List.map ~f:trans_idx assign_indices ) - , rhs ) - |> swrap - | Ast.NRFunApp (fn_kind, {name; _}, args) -> - NRFunApp (trans_fn_kind fn_kind name, trans_exprs args) |> swrap - | Ast.IncrementLogProb e | Ast.TargetPE e -> TargetPE (trans_expr e) |> swrap - | Ast.Tilde {arg; distribution; args; truncation} -> - let suffix = - Stan_math_signatures.dist_name_suffix ud_dists distribution.name in - let name = distribution.name ^ suffix in - let kind = - let possible_names = - List.map ~f:(( ^ ) distribution.name) Utils.distribution_suffices - |> String.Set.of_list in - if List.exists ~f:(fun (n, _) -> Set.mem possible_names n) ud_dists then - Fun_kind.UserDefined (name, FnLpdf true) - else StanLib (name, FnLpdf true, AoS) in - let add_dist = - Stmt.Fixed.Pattern.TargetPE - Expr. - { Fixed.pattern= FunApp (kind, trans_exprs (arg :: args)) - ; meta= - Typed.Meta.create ~type_:UReal ~loc:mloc - ~adlevel:(Ast.expr_ad_lub (arg :: args)) - () } in - truncate_dist ud_dists distribution arg args truncation @ swrap add_dist - | Ast.Print ps -> - NRFunApp (CompilerInternal FnPrint, trans_printables smeta ps) |> swrap - | Ast.Reject ps -> - NRFunApp (CompilerInternal FnReject, trans_printables smeta ps) |> swrap - | Ast.IfThenElse (cond, ifb, elseb) -> - IfElse - ( trans_expr cond - , trans_single_stmt ifb - , Option.map ~f:trans_single_stmt elseb ) - |> swrap - | Ast.While (cond, body) -> - While (trans_expr cond, trans_single_stmt body) |> swrap - | Ast.For {loop_variable; lower_bound; upper_bound; loop_body} -> - let body = - match trans_single_stmt loop_body with - | {pattern= Block _; _} as b -> b - | x -> {x with pattern= Block [x]} in - For - { loopvar= loop_variable.Ast.name - ; lower= trans_expr lower_bound - ; upper= trans_expr upper_bound - ; body } - |> swrap - | Ast.ForEach (loopvar, iteratee, body) -> - let iteratee' = trans_expr iteratee in - let body_stmts = - match trans_single_stmt body with - | {pattern= Block body_stmts; _} -> body_stmts - | b -> [b] in - let decl_type = - match Expr.Typed.type_of iteratee' with - | UMatrix -> UnsizedType.UReal - | t -> Expr.Helpers.(infer_type_of_indexed t [Index.Single loop_bottom]) - in - let decl_loopvar = - Stmt.Fixed. - { meta= smeta - ; pattern= - Decl - { decl_adtype= Expr.Typed.adlevel_of iteratee' - ; decl_id= loopvar.name - ; decl_type= Unsized decl_type - ; initialize= true } } in - let assignment var = - Stmt.Fixed. - {pattern= Assignment ((loopvar.name, decl_type, []), var); meta= smeta} - in - let bodyfn var = - Stmt.Fixed. - { pattern= Block (decl_loopvar :: assignment var :: body_stmts) - ; meta= smeta } in - Stmt.Helpers.[ensure_var (for_each bodyfn) iteratee' smeta] - | Ast.FunDef _ -> - Common.FatalError.fatal_error_msg - [%message - "Found function definition statement outside of function block"] - | Ast.VarDecl - {decl_type; transformation; identifier; initial_value; is_global= _} -> - trans_decl declc smeta decl_type - (Transformation.map trans_expr transformation) - identifier initial_value - | Ast.Block stmts -> Block (List.concat_map ~f:trans_stmt stmts) |> swrap - | Ast.Profile (name, stmts) -> - Profile (name, List.concat_map ~f:trans_stmt stmts) |> swrap - | Ast.Return e -> Return (Some (trans_expr e)) |> swrap - | Ast.ReturnVoid -> Return None |> swrap - | Ast.Break -> Break |> swrap - | Ast.Continue -> Continue |> swrap - | Ast.Skip -> Skip |> swrap - -let trans_fun_def ud_dists (ts : Ast.typed_statement) = - match ts.stmt with - | Ast.FunDef {returntype; funname; arguments; body} -> - [ Program. - { fdrt= - (match returntype with Void -> None | ReturnType ut -> Some ut) - ; fdname= funname.name - ; fdsuffix= Fun_kind.(suffix_from_name funname.name |> without_propto) - ; fdargs= List.map ~f:trans_arg arguments - ; fdbody= - trans_stmt ud_dists - {transform_action= IgnoreTransform; dadlevel= AutoDiffable} - body - |> unwrap_block_or_skip - ; fdloc= ts.smeta.loc } ] - | _ -> - Common.FatalError.fatal_error_msg - [%message "Found non-function definition statement in function block"] - -let get_block block prog = - match block with - | Program.Parameters -> prog.Ast.parametersblock - | TransformedParameters -> prog.transformedparametersblock - | GeneratedQuantities -> prog.generatedquantitiesblock - -let trans_sizedtype_decl declc tr name = - let check fn x n = - Stmt.Helpers.internal_nrfunapp fn - Expr.Helpers. - [str name; str (Fmt.str "%a" Pretty_printing.pp_typed_expression x); n] - n.meta.loc in - let grab_size fn n = function - | Ast.{expr= IntNumeral i; _} as s when float_of_string i >= 2. -> - ([], trans_expr s) - | Ast.({expr= IntNumeral _; _} | {expr= Variable _; _}) as s -> - let e = trans_expr s in - ([check fn s e], e) - | s -> - let e = trans_expr s in - let decl_id = Fmt.str "%s_%ddim__" name n in - let decl = - { Stmt.Fixed.pattern= - Decl - { decl_type= Sized SInt - ; decl_id - ; decl_adtype= DataOnly - ; initialize= true } - ; meta= e.meta.loc } in - let assign = - { Stmt.Fixed.pattern= Assignment ((decl_id, UInt, []), e) - ; meta= e.meta.loc } in - let var = - Expr. - { Fixed.pattern= Var decl_id - ; meta= - Typed.Meta. - { type_= s.Ast.emeta.Ast.type_ - ; adlevel= s.emeta.ad_level - ; loc= s.emeta.loc } } in - ([decl; assign; check fn s var], var) in - let rec go n = function - | SizedType.(SInt | SReal | SComplex) as t -> ([], t) - | SVector (mem_pattern, s) -> - let fn = - match (declc.transform_action, tr) with - | Constrain, Transformation.Simplex -> - Internal_fun.FnValidateSizeSimplex - | Constrain, UnitVector -> FnValidateSizeUnitVector - | _ -> FnValidateSize in - let l, s = grab_size fn n s in - (l, SizedType.SVector (mem_pattern, s)) - | SRowVector (mem_pattern, s) -> - let l, s = grab_size FnValidateSize n s in - (l, SizedType.SRowVector (mem_pattern, s)) - | SComplexRowVector s -> - let l, s = grab_size FnValidateSize n s in - (l, SizedType.SComplexRowVector s) - | SComplexVector s -> - let l, s = grab_size FnValidateSize n s in - (l, SizedType.SComplexVector s) - | SMatrix (mem_pattern, r, c) -> - let l1, r = grab_size FnValidateSize n r in - let l2, c = grab_size FnValidateSize (n + 1) c in - let cf_cov = - match (declc.transform_action, tr) with - | Constrain, CholeskyCov -> - [ { Stmt.Fixed.pattern= - NRFunApp - ( StanLib ("check_greater_or_equal", FnPlain, AoS) - , Expr.Helpers. - [ str ("cholesky_factor_cov " ^ name) - ; str - "num rows (must be greater or equal to num cols)" - ; r; c ] ) - ; meta= r.Expr.Fixed.meta.Expr.Typed.Meta.loc } ] - | _ -> [] in - (l1 @ l2 @ cf_cov, SizedType.SMatrix (mem_pattern, r, c)) - | SComplexMatrix (r, c) -> - let l1, r = grab_size FnValidateSize n r in - let l2, c = grab_size FnValidateSize (n + 1) c in - (l1 @ l2, SizedType.SComplexMatrix (r, c)) - | SArray (t, s) -> - let l, s = grab_size FnValidateSize n s in - let ll, t = go (n + 1) t in - (l @ ll, SizedType.SArray (t, s)) in - go 1 - -let trans_block ud_dists declc block prog = - let f stmt (accum1, accum2, accum3) = - match stmt with - | { Ast.stmt= - VarDecl - { decl_type= Sized type_ - ; identifier - ; transformation - ; initial_value - ; is_global= true } - ; smeta } -> - let decl_id = identifier.Ast.name in - let transform = Transformation.map trans_expr transformation in - let rhs = Option.map ~f:trans_expr initial_value in - let size, type_ = - trans_sizedtype_decl declc transform identifier.name type_ in - let decl_adtype = declc.dadlevel in - let decl_var = - Expr. - { Fixed.pattern= Var decl_id - ; meta= - Typed.Meta.create ~adlevel:declc.dadlevel ~loc:smeta.Ast.loc - ~type_:(SizedType.to_unsized type_) - () } in - let decl = - Stmt. - { Fixed.pattern= + [%message "Found non-function definition statement in function block"] + + let get_block block prog = + match block with + | Program.Parameters -> prog.Ast.parametersblock + | TransformedParameters -> prog.transformedparametersblock + | GeneratedQuantities -> prog.generatedquantitiesblock + + let trans_sizedtype_decl declc tr name = + let check fn x n = + Stmt.Helpers.internal_nrfunapp fn + Expr.Helpers. + [str name; str (Fmt.str "%a" Pretty_printing.pp_typed_expression x); n] + n.meta.loc in + let grab_size fn n = function + | Ast.{expr= IntNumeral i; _} as s when float_of_string i >= 2. -> + ([], trans_expr s) + | Ast.({expr= IntNumeral _; _} | {expr= Variable _; _}) as s -> + let e = trans_expr s in + ([check fn s e], e) + | s -> + let e = trans_expr s in + let decl_id = Fmt.str "%s_%ddim__" name n in + let decl = + { Stmt.Fixed.pattern= Decl - { decl_adtype + { decl_type= Sized SInt ; decl_id - ; decl_type= Sized type_ + ; decl_adtype= DataOnly ; initialize= true } - ; meta= smeta.loc } in - let rhs_assignment = - Option.map - ~f:(fun e -> - Stmt.Fixed. - { pattern= Assignment ((decl_id, e.meta.type_, []), e) - ; meta= smeta.loc } ) - rhs - |> Option.to_list in - let outvar = - ( identifier.name - , Program. - { out_constrained_st= type_ - ; out_unconstrained_st= param_size transform type_ - ; out_block= block - ; out_trans= transform } ) in - let stmts = - if Utils.is_user_ident decl_id then - let constrain_checks = - match declc.transform_action with - | Constrain | Unconstrain -> - check_transform_shape decl_id decl_var smeta.loc transform - | Check -> - check_transform_shape decl_id decl_var smeta.loc transform - @ check_decl decl_var (Type.Sized type_) decl_id transform - smeta.loc declc.dadlevel - | IgnoreTransform -> [] in - (decl :: rhs_assignment) @ constrain_checks - else decl :: rhs_assignment in - (outvar :: accum1, size @ accum2, stmts @ accum3) - | stmt -> (accum1, accum2, trans_stmt ud_dists declc stmt @ accum3) in - Ast.get_stmts (get_block block prog) |> List.fold_right ~f ~init:([], [], []) - -let stmt_contains_check stmt = - let is_check = function - | Fun_kind.CompilerInternal (Internal_fun.FnCheck _) -> true - | _ -> false in - Stmt.Helpers.contains_fn_kind is_check stmt - -let migrate_checks_to_end_of_block stmts = - let checks, not_checks = List.partition_tf ~f:stmt_contains_check stmts in - not_checks @ checks - -let gather_data (p : Ast.typed_program) = - let data = Ast.get_stmts p.datablock in - List.filter_map data ~f:(function - | { stmt= - VarDecl - { decl_type= Sized sizedtype - ; transformation - ; identifier= {name; _} - ; _ } - ; _ } -> - Some - ( SizedType.map trans_expr sizedtype - , Transformation.map trans_expr transformation - , name ) - | _ -> None ) - -let trans_prog filename (p : Ast.typed_program) : Program.Typed.t = - let {Ast.functionblock; datablock; transformeddatablock; modelblock; _} = p in - let map f list_op = - Option.value_map ~default:[] - ~f:(fun {Ast.stmts; _} -> List.concat_map ~f stmts) - list_op in - let grab_fundef_names_and_types = function - | {Ast.stmt= Ast.FunDef {funname; arguments= (_, type_, _) :: _; _}; _} -> - [(funname.name, type_)] - | _ -> [] in - let ud_dists = map grab_fundef_names_and_types functionblock in - let trans_stmt = trans_stmt ud_dists in - let get_name_size s = - match s.Ast.stmt with - | Ast.VarDecl {decl_type= Sized st; identifier; transformation; _} -> - [(identifier.name, trans_sizedtype st, transformation)] - | _ -> [] in - let input_vars = - map get_name_size datablock |> List.map ~f:(fun (n, st, _) -> (n, st)) in - let declc = {transform_action= IgnoreTransform; dadlevel= DataOnly} in - let datab = map (trans_stmt {declc with transform_action= Check}) datablock in - let _, _, param = - trans_block ud_dists - {transform_action= Constrain; dadlevel= AutoDiffable} - Parameters p in - (* Backends will add to transform_inits as needed *) - let transform_inits = [] in - let out_param, paramsizes, param_gq = - trans_block ud_dists {declc with transform_action= Constrain} Parameters p - in - let _, _, txparam = - trans_block ud_dists - {transform_action= Check; dadlevel= AutoDiffable} - TransformedParameters p in - let out_tparam, tparamsizes, txparam_gq = - trans_block ud_dists - {declc with transform_action= Check} - TransformedParameters p in - let out_gq, gq_sizes, gq_stmts = - trans_block ud_dists - {declc with transform_action= Check} - GeneratedQuantities p in - let output_vars = out_param @ out_tparam @ out_gq in - let prepare_data = - datab - @ ( map - (trans_stmt {declc with transform_action= Check}) - transformeddatablock - |> migrate_checks_to_end_of_block ) - @ paramsizes @ tparamsizes @ gq_sizes in - let modelb = map (trans_stmt {declc with dadlevel= AutoDiffable}) modelblock in - let log_prob = - param - @ (txparam |> migrate_checks_to_end_of_block) - @ - match modelb with - | [] -> [] - | hd :: _ -> [{pattern= Block modelb; meta= hd.meta}] in - let txparam_decls, txparam_checks, txparam_stmts = - txparam_gq - |> List.partition3_map ~f:(function - | {pattern= Decl _; _} as d -> `Fst d - | s when stmt_contains_check s -> `Snd s - | s -> `Trd s ) in - let compiler_if_return cond = - Stmt.Fixed. - { pattern= - IfElse (cond, {pattern= Return None; meta= Location_span.empty}, None) - ; meta= Location_span.empty } in - let iexpr pattern = Expr.{pattern; Fixed.meta= Typed.Meta.empty} in - let fnot e = - FunApp (StanLib (Operator.to_string PNot, FnPlain, AoS), [e]) |> iexpr in - let tparam_early_return = - let to_var fv = iexpr (Var (Flag_vars.to_string fv)) in - let v1 = to_var EmitTransformedParameters in - let v2 = to_var EmitGeneratedQuantities in - [compiler_if_return (fnot (EOr (v1, v2) |> iexpr))] in - let gq_early_return = - [ compiler_if_return - (fnot (Var (Flag_vars.to_string EmitGeneratedQuantities) |> iexpr)) ] - in - let generate_quantities = - param_gq @ txparam_decls @ tparam_early_return @ txparam_stmts - @ txparam_checks @ gq_early_return - @ migrate_checks_to_end_of_block gq_stmts in - let normalize_prog_name prog_name = - if String.length prog_name > 0 && not (Char.is_alpha prog_name.[0]) then - "_" ^ prog_name - else prog_name in - { functions_block= map (trans_fun_def ud_dists) functionblock - ; input_vars - ; prepare_data - ; log_prob - ; generate_quantities - ; transform_inits - ; output_vars - ; prog_name= normalize_prog_name !Typechecker.model_name - ; prog_path= filename } + ; meta= e.meta.loc } in + let assign = + { Stmt.Fixed.pattern= Assignment ((decl_id, UInt, []), e) + ; meta= e.meta.loc } in + let var = + Expr. + { Fixed.pattern= Var decl_id + ; meta= + Typed.Meta. + { type_= s.Ast.emeta.Ast.type_ + ; adlevel= s.emeta.ad_level + ; loc= s.emeta.loc } } in + ([decl; assign; check fn s var], var) in + let rec go n = function + | SizedType.(SInt | SReal | SComplex) as t -> ([], t) + | SVector (mem_pattern, s) -> + let fn = + match (declc.transform_action, tr) with + | Constrain, Transformation.Simplex -> + Internal_fun.FnValidateSizeSimplex + | Constrain, UnitVector -> FnValidateSizeUnitVector + | _ -> FnValidateSize in + let l, s = grab_size fn n s in + (l, SizedType.SVector (mem_pattern, s)) + | SRowVector (mem_pattern, s) -> + let l, s = grab_size FnValidateSize n s in + (l, SizedType.SRowVector (mem_pattern, s)) + | SComplexRowVector s -> + let l, s = grab_size FnValidateSize n s in + (l, SizedType.SComplexRowVector s) + | SComplexVector s -> + let l, s = grab_size FnValidateSize n s in + (l, SizedType.SComplexVector s) + | SMatrix (mem_pattern, r, c) -> + let l1, r = grab_size FnValidateSize n r in + let l2, c = grab_size FnValidateSize (n + 1) c in + let cf_cov = + match (declc.transform_action, tr) with + | Constrain, CholeskyCov -> + [ { Stmt.Fixed.pattern= + NRFunApp + ( StanLib ("check_greater_or_equal", FnPlain, AoS) + , Expr.Helpers. + [ str ("cholesky_factor_cov " ^ name) + ; str + "num rows (must be greater or equal to num \ + cols)"; r; c ] ) + ; meta= r.Expr.Fixed.meta.Expr.Typed.Meta.loc } ] + | _ -> [] in + (l1 @ l2 @ cf_cov, SizedType.SMatrix (mem_pattern, r, c)) + | SComplexMatrix (r, c) -> + let l1, r = grab_size FnValidateSize n r in + let l2, c = grab_size FnValidateSize (n + 1) c in + (l1 @ l2, SizedType.SComplexMatrix (r, c)) + | SArray (t, s) -> + let l, s = grab_size FnValidateSize n s in + let ll, t = go (n + 1) t in + (l @ ll, SizedType.SArray (t, s)) in + go 1 + + let trans_block ud_dists declc block prog = + let f stmt (accum1, accum2, accum3) = + match stmt with + | { Ast.stmt= + VarDecl + { decl_type= Sized type_ + ; identifier + ; transformation + ; initial_value + ; is_global= true } + ; smeta } -> + let decl_id = identifier.Ast.name in + let transform = Transformation.map trans_expr transformation in + let rhs = Option.map ~f:trans_expr initial_value in + let size, type_ = + trans_sizedtype_decl declc transform identifier.name type_ in + let decl_adtype = declc.dadlevel in + let decl_var = + Expr. + { Fixed.pattern= Var decl_id + ; meta= + Typed.Meta.create ~adlevel:declc.dadlevel ~loc:smeta.Ast.loc + ~type_:(SizedType.to_unsized type_) + () } in + let decl = + Stmt. + { Fixed.pattern= + Decl + { decl_adtype + ; decl_id + ; decl_type= Sized type_ + ; initialize= true } + ; meta= smeta.loc } in + let rhs_assignment = + Option.map + ~f:(fun e -> + Stmt.Fixed. + { pattern= Assignment ((decl_id, e.meta.type_, []), e) + ; meta= smeta.loc } ) + rhs + |> Option.to_list in + let outvar = + ( identifier.name + , Program. + { out_constrained_st= type_ + ; out_unconstrained_st= param_size transform type_ + ; out_block= block + ; out_trans= transform } ) in + let stmts = + if Utils.is_user_ident decl_id then + let constrain_checks = + match declc.transform_action with + | Constrain | Unconstrain -> + check_transform_shape decl_id decl_var smeta.loc transform + | Check -> + check_transform_shape decl_id decl_var smeta.loc transform + @ check_decl decl_var (Type.Sized type_) decl_id transform + smeta.loc declc.dadlevel + | IgnoreTransform -> [] in + (decl :: rhs_assignment) @ constrain_checks + else decl :: rhs_assignment in + (outvar :: accum1, size @ accum2, stmts @ accum3) + | stmt -> (accum1, accum2, trans_stmt ud_dists declc stmt @ accum3) in + Ast.get_stmts (get_block block prog) |> List.fold_right ~f ~init:([], [], []) + + let stmt_contains_check stmt = + let is_check = function + | Fun_kind.CompilerInternal (Internal_fun.FnCheck _) -> true + | _ -> false in + Stmt.Helpers.contains_fn_kind is_check stmt + + let migrate_checks_to_end_of_block stmts = + let checks, not_checks = List.partition_tf ~f:stmt_contains_check stmts in + not_checks @ checks + + let gather_data (p : Ast.typed_program) = + let data = Ast.get_stmts p.datablock in + List.filter_map data ~f:(function + | { stmt= + VarDecl + { decl_type= Sized sizedtype + ; transformation + ; identifier= {name; _} + ; _ } + ; _ } -> + Some + ( SizedType.map trans_expr sizedtype + , Transformation.map trans_expr transformation + , name ) + | _ -> None ) + + let trans_prog filename (p : Ast.typed_program) : Program.Typed.t = + let {Ast.functionblock; datablock; transformeddatablock; modelblock; _} = + p in + let map f list_op = + Option.value_map ~default:[] + ~f:(fun {Ast.stmts; _} -> List.concat_map ~f stmts) + list_op in + let grab_fundef_names_and_types = function + | {Ast.stmt= Ast.FunDef {funname; arguments= (_, type_, _) :: _; _}; _} -> + [(funname.name, type_)] + | _ -> [] in + let ud_dists = map grab_fundef_names_and_types functionblock in + let trans_stmt = trans_stmt ud_dists in + let get_name_size s = + match s.Ast.stmt with + | Ast.VarDecl {decl_type= Sized st; identifier; transformation; _} -> + [(identifier.name, trans_sizedtype st, transformation)] + | _ -> [] in + let input_vars = + map get_name_size datablock |> List.map ~f:(fun (n, st, _) -> (n, st)) + in + let declc = {transform_action= IgnoreTransform; dadlevel= DataOnly} in + let datab = + map (trans_stmt {declc with transform_action= Check}) datablock in + let _, _, param = + trans_block ud_dists + {transform_action= Constrain; dadlevel= AutoDiffable} + Parameters p in + (* Backends will add to transform_inits as needed *) + let transform_inits = [] in + let out_param, paramsizes, param_gq = + trans_block ud_dists {declc with transform_action= Constrain} Parameters p + in + let _, _, txparam = + trans_block ud_dists + {transform_action= Check; dadlevel= AutoDiffable} + TransformedParameters p in + let out_tparam, tparamsizes, txparam_gq = + trans_block ud_dists + {declc with transform_action= Check} + TransformedParameters p in + let out_gq, gq_sizes, gq_stmts = + trans_block ud_dists + {declc with transform_action= Check} + GeneratedQuantities p in + let output_vars = out_param @ out_tparam @ out_gq in + let prepare_data = + datab + @ ( map + (trans_stmt {declc with transform_action= Check}) + transformeddatablock + |> migrate_checks_to_end_of_block ) + @ paramsizes @ tparamsizes @ gq_sizes in + let modelb = + map (trans_stmt {declc with dadlevel= AutoDiffable}) modelblock in + let log_prob = + param + @ (txparam |> migrate_checks_to_end_of_block) + @ + match modelb with + | [] -> [] + | hd :: _ -> [{pattern= Block modelb; meta= hd.meta}] in + let txparam_decls, txparam_checks, txparam_stmts = + txparam_gq + |> List.partition3_map ~f:(function + | {pattern= Decl _; _} as d -> `Fst d + | s when stmt_contains_check s -> `Snd s + | s -> `Trd s ) in + let compiler_if_return cond = + Stmt.Fixed. + { pattern= + IfElse + (cond, {pattern= Return None; meta= Location_span.empty}, None) + ; meta= Location_span.empty } in + let iexpr pattern = Expr.{pattern; Fixed.meta= Typed.Meta.empty} in + let fnot e = + FunApp (StanLib (Operator.to_string PNot, FnPlain, AoS), [e]) |> iexpr + in + let tparam_early_return = + let to_var fv = iexpr (Var (Flag_vars.to_string fv)) in + let v1 = to_var EmitTransformedParameters in + let v2 = to_var EmitGeneratedQuantities in + [compiler_if_return (fnot (EOr (v1, v2) |> iexpr))] in + let gq_early_return = + [ compiler_if_return + (fnot (Var (Flag_vars.to_string EmitGeneratedQuantities) |> iexpr)) ] + in + let generate_quantities = + param_gq @ txparam_decls @ tparam_early_return @ txparam_stmts + @ txparam_checks @ gq_early_return + @ migrate_checks_to_end_of_block gq_stmts in + let normalize_prog_name prog_name = + if String.length prog_name > 0 && not (Char.is_alpha prog_name.[0]) then + "_" ^ prog_name + else prog_name in + { functions_block= map (trans_fun_def ud_dists) functionblock + ; input_vars + ; prepare_data + ; log_prob + ; generate_quantities + ; transform_inits + ; output_vars + ; prog_name= normalize_prog_name !Typechecking.model_name + ; prog_path= filename } +end diff --git a/src/frontend/Ast_to_Mir.mli b/src/frontend/Ast_to_Mir.mli index 1a7c528881..89b8374875 100644 --- a/src/frontend/Ast_to_Mir.mli +++ b/src/frontend/Ast_to_Mir.mli @@ -1,8 +1,15 @@ (** Translate from the AST to the MIR *) open Middle + +module type Ast_Mir_translator = sig + val gather_data : Ast.typed_program -> (Expr.Typed.t SizedType.t * Expr.Typed.t Transformation.t * string) list val trans_prog : string -> Ast.typed_program -> Program.Typed.t + +end + +module Make(StdLib:Std_library_utils.Library): Ast_Mir_translator diff --git a/src/frontend/Info.ml b/src/frontend/Info.ml index 5ded69f550..399d9f7477 100644 --- a/src/frontend/Info.ml +++ b/src/frontend/Info.ml @@ -45,48 +45,6 @@ let rec get_function_calls_expr (funs, distrs) expr = | _ -> (funs, distrs) in fold_expression get_function_calls_expr (fun acc _ -> acc) acc expr.expr -let rec get_function_calls_stmt ud_dists (funs, distrs) stmt = - let acc = - match stmt.stmt with - | NRFunApp (StanLib _, f, _) -> (Set.add funs f.name, distrs) - | Tilde {distribution; _} -> - let possible_names = - List.map ~f:(( ^ ) distribution.name) Utils.distribution_suffices - |> String.Set.of_list in - if List.exists ~f:(fun (n, _) -> Set.mem possible_names n) ud_dists then - (funs, distrs) - else - let suffix = - Stan_math_signatures.dist_name_suffix ud_dists distribution.name - in - let name = distribution.name ^ Utils.unnormalized_suffix suffix in - (funs, Set.add distrs name) - | _ -> (funs, distrs) in - fold_statement get_function_calls_expr - (get_function_calls_stmt ud_dists) - (fun acc _ -> acc) - (fun acc _ -> acc) - acc stmt.stmt - -let function_calls_json p = - let map f list_op = - Option.value_map ~default:[] - ~f:(fun {stmts; _} -> List.concat_map ~f stmts) - list_op in - let grab_fundef_names_and_types = function - | {Ast.stmt= Ast.FunDef {funname; arguments= (_, type_, _) :: _; _}; _} -> - [(funname.name, type_)] - | _ -> [] in - let ud_dists = map grab_fundef_names_and_types p.functionblock in - let funs, distrs = - fold_program - (get_function_calls_stmt ud_dists) - (String.Set.empty, String.Set.empty) - p in - let set_to_List s = - `List (Set.to_list s |> List.map ~f:(fun str -> `String str)) in - `Assoc [("functions", set_to_List funs); ("distributions", set_to_List distrs)] - let includes_json () = `Assoc [ ( "included_files" @@ -94,12 +52,62 @@ let includes_json () = ( List.rev !Preprocessor.included_files |> List.map ~f:(fun str -> `String str) ) ) ] -let info_json ast = - List.fold ~f:Util.combine ~init:(`Assoc []) - [ block_info_json "inputs" ast.datablock - ; block_info_json "parameters" ast.parametersblock - ; block_info_json "transformed parameters" ast.transformedparametersblock - ; block_info_json "generated quantities" ast.generatedquantitiesblock - ; function_calls_json ast; includes_json () ] +module type Information = sig + val info : Ast.typed_program -> string +end + +module Make (StdLib : Std_library_utils.Library) : Information = struct + let rec get_function_calls_stmt ud_dists (funs, distrs) stmt = + let acc = + match stmt.stmt with + | NRFunApp (StanLib _, f, _) -> (Set.add funs f.name, distrs) + | Tilde {distribution; _} -> + let possible_names = + List.map ~f:(( ^ ) distribution.name) Utils.distribution_suffices + |> String.Set.of_list in + if List.exists ~f:(fun (n, _) -> Set.mem possible_names n) ud_dists + then (funs, distrs) + else + let suffix = + Std_library_utils.dist_name_suffix + (module StdLib) + ud_dists distribution.name in + let name = distribution.name ^ Utils.unnormalized_suffix suffix in + (funs, Set.add distrs name) + | _ -> (funs, distrs) in + fold_statement get_function_calls_expr + (get_function_calls_stmt ud_dists) + (fun acc _ -> acc) + (fun acc _ -> acc) + acc stmt.stmt + + let function_calls_json p = + let map f list_op = + Option.value_map ~default:[] + ~f:(fun {stmts; _} -> List.concat_map ~f stmts) + list_op in + let grab_fundef_names_and_types = function + | {Ast.stmt= Ast.FunDef {funname; arguments= (_, type_, _) :: _; _}; _} -> + [(funname.name, type_)] + | _ -> [] in + let ud_dists = map grab_fundef_names_and_types p.functionblock in + let funs, distrs = + fold_program + (get_function_calls_stmt ud_dists) + (String.Set.empty, String.Set.empty) + p in + let set_to_List s = + `List (Set.to_list s |> List.map ~f:(fun str -> `String str)) in + `Assoc + [("functions", set_to_List funs); ("distributions", set_to_List distrs)] + + let info_json ast = + List.fold ~f:Util.combine ~init:(`Assoc []) + [ block_info_json "inputs" ast.datablock + ; block_info_json "parameters" ast.parametersblock + ; block_info_json "transformed parameters" ast.transformedparametersblock + ; block_info_json "generated quantities" ast.generatedquantitiesblock + ; function_calls_json ast; includes_json () ] -let info ast = pretty_to_string (info_json ast) + let info ast = pretty_to_string (info_json ast) +end diff --git a/src/frontend/Info.mli b/src/frontend/Info.mli index a21aa8f717..9eaf3e576c 100644 --- a/src/frontend/Info.mli +++ b/src/frontend/Info.mli @@ -10,10 +10,14 @@ - [type]: the base type of the variable (["int"] or ["real"]). - [dimensions]: the number of dimensions ([0] for a scalar, [1] for a vector or row vector, etc.). - + The JSON object also have the fields [stanlib_calls] and [distributions] containing the name of the standard library functions called and distributions used. *) -val info : Ast.typed_program -> string +module type Information = sig + val info : Ast.typed_program -> string +end + +module Make(StdLib:Std_library_utils.Library): Information diff --git a/src/frontend/Semantic_error.ml b/src/frontend/Semantic_error.ml index 38090618fe..42c36a5b59 100644 --- a/src/frontend/Semantic_error.ml +++ b/src/frontend/Semantic_error.ml @@ -14,7 +14,11 @@ module TypeError = struct | IntIntArrayOrRangeExpected of UnsizedType.t | IntOrRealContainerExpected of UnsizedType.t | ArrayVectorRowVectorMatrixExpected of UnsizedType.t - | IllTypedAssignment of Operator.t * UnsizedType.t * UnsizedType.t + | IllTypedAssignment of + Operator.t + * UnsizedType.t + * UnsizedType.t + * Std_library_utils.signature list | IllTypedTernaryIf of UnsizedType.t * UnsizedType.t * UnsizedType.t | IllTypedVariadicFn of string @@ -97,18 +101,18 @@ module TypeError = struct "Foreach-loop must be over array, vector, row_vector or matrix. \ Instead found expression of type %a." UnsizedType.pp ut - | IllTypedAssignment (Operator.Equals, lt, rt) -> + | IllTypedAssignment (Operator.Equals, lt, rt, _) -> Fmt.pf ppf "Ill-typed arguments supplied to assignment operator =: lhs has type \ %a and rhs has type %a" UnsizedType.pp lt UnsizedType.pp rt - | IllTypedAssignment (op, lt, rt) -> + | IllTypedAssignment (op, lt, rt, sigs) -> Fmt.pf ppf "@[Ill-typed arguments supplied to assignment operator %a=: lhs \ has type %a and rhs has type %a.@ Available signatures for given \ lhs:@]@ %a" Operator.pp op UnsizedType.pp lt UnsizedType.pp rt - SignatureMismatch.pp_math_lib_assignmentoperator_sigs (lt, op) + SignatureMismatch.pp_assignmentoperator_sigs (lt, sigs) | IllTypedTernaryIf (UInt, ut, _) when UnsizedType.is_fun_type ut -> Fmt.pf ppf "Ternary expression cannot have a function type: %a" UnsizedType.pp ut @@ -514,8 +518,8 @@ let int_or_real_container_expected loc ut = let array_vector_rowvector_matrix_expected loc ut = TypeError (loc, TypeError.ArrayVectorRowVectorMatrixExpected ut) -let illtyped_assignment loc assignop lt rt = - TypeError (loc, TypeError.IllTypedAssignment (assignop, lt, rt)) +let illtyped_assignment loc assignop lt rt sigs = + TypeError (loc, TypeError.IllTypedAssignment (assignop, lt, rt, sigs)) let illtyped_ternary_if loc predt lt rt = TypeError (loc, TypeError.IllTypedTernaryIf (predt, lt, rt)) diff --git a/src/frontend/Semantic_error.mli b/src/frontend/Semantic_error.mli index 874cbacc41..18787d79ee 100644 --- a/src/frontend/Semantic_error.mli +++ b/src/frontend/Semantic_error.mli @@ -25,7 +25,12 @@ val array_vector_rowvector_matrix_expected : Location_span.t -> UnsizedType.t -> t val illtyped_assignment : - Location_span.t -> Operator.t -> UnsizedType.t -> UnsizedType.t -> t + Location_span.t + -> Operator.t + -> UnsizedType.t + -> UnsizedType.t + -> Std_library_utils.signature list + -> t val illtyped_ternary_if : Location_span.t -> UnsizedType.t -> UnsizedType.t -> UnsizedType.t -> t diff --git a/src/frontend/SignatureMismatch.ml b/src/frontend/SignatureMismatch.ml index e1b678b046..11bac9a347 100644 --- a/src/frontend/SignatureMismatch.ml +++ b/src/frontend/SignatureMismatch.ml @@ -380,10 +380,8 @@ let pp_signature_mismatch ppf (name, arg_tys, (sigs, omitted)) = (list ~sep:cut pp_signature) sigs pp_omitted () -let pp_math_lib_assignmentoperator_sigs ppf (lt, op) = +let pp_assignmentoperator_sigs ppf (lt, errors) = let signatures = - let errors = - Stan_math_signatures.make_assignmentoperator_stan_math_signatures op in let errors = List.filter ~f:(fun (_, args, _) -> diff --git a/src/frontend/SignatureMismatch.mli b/src/frontend/SignatureMismatch.mli index 4871bf49f5..ab373f4222 100644 --- a/src/frontend/SignatureMismatch.mli +++ b/src/frontend/SignatureMismatch.mli @@ -80,8 +80,10 @@ val pp_signature_mismatch : * bool ) -> unit -val pp_math_lib_assignmentoperator_sigs : - Format.formatter -> UnsizedType.t * Operator.t -> unit +val pp_assignmentoperator_sigs : + Format.formatter + -> UnsizedType.t* Std_library_utils.signature list + -> unit val compare_errors : function_mismatch -> function_mismatch -> int val compare_match_results : match_result -> match_result -> int diff --git a/src/frontend/Std_library_utils.ml b/src/frontend/Std_library_utils.ml index 6022d1540e..58f724a909 100644 --- a/src/frontend/Std_library_utils.ml +++ b/src/frontend/Std_library_utils.ml @@ -18,12 +18,6 @@ type variadic_checker = -> Ast.typed_expression list -> Ast.typed_expression -let pp_math_sig ppf (rt, args, mem_pattern) = - UnsizedType.pp ppf (UFun (args, rt, FnPlain, mem_pattern)) - -let pp_math_sigs ppf sigs = (Fmt.list ~sep:Fmt.cut pp_math_sig) ppf sigs -let pretty_print_math_sigs = Fmt.str "@[@,%a@]" pp_math_sigs - type deprecation_info = {replacement: string; version: string; extra_message: string} [@@deriving sexp] @@ -42,6 +36,7 @@ module type Library = sig val get_signatures : string -> signature list val get_operator_signatures : Operator.t -> signature list + val get_assignment_operator_signatures : Operator.t -> signature list val is_not_overloadable : string -> bool val is_variadic_function_name : string -> bool val operator_to_function_names : Operator.t -> string list @@ -57,6 +52,7 @@ module NullLibrary : Library = struct let distribution_families : string list = [] let is_stdlib_function_name _ = false let get_signatures _ = [] + let get_assignment_operator_signatures _ = [] let get_operator_signatures _ = [] let is_not_overloadable _ = false let is_variadic_function_name _ = false @@ -65,3 +61,17 @@ module NullLibrary : Library = struct let deprecated_distributions = String.Map.empty let deprecated_functions = String.Map.empty end + +let pp_math_sig ppf (rt, args, mem_pattern) = + UnsizedType.pp ppf (UFun (args, rt, FnPlain, mem_pattern)) + +let pp_math_sigs ppf sigs = (Fmt.list ~sep:Fmt.cut pp_math_sig) ppf sigs +let pretty_print_math_sigs = Fmt.str "@[@,%a@]" pp_math_sigs + +let dist_name_suffix (module StdLib : Library) udf_names name = + let is_udf_name s = + List.exists ~f:(fun (n, _) -> String.equal s n) udf_names in + Utils.distribution_suffices + |> List.filter ~f:(fun sfx -> + StdLib.is_stdlib_function_name (name ^ sfx) || is_udf_name (name ^ sfx) ) + |> List.hd_exn diff --git a/src/frontend/Typechecker.ml b/src/frontend/Typechecking.ml similarity index 99% rename from src/frontend/Typechecker.ml rename to src/frontend/Typechecking.ml index b279c90fc5..9d981c020e 100644 --- a/src/frontend/Typechecker.ml +++ b/src/frontend/Typechecking.ml @@ -1003,9 +1003,9 @@ module Make (StdLibrary : Std_library_utils.Library) : Typechecker = struct | _ -> () let check_assignment_operator loc assop lhs rhs = - let err op = + let err op sigs = Semantic_error.illtyped_assignment loc op lhs.lmeta.type_ rhs.emeta.type_ - in + sigs in match assop with | Assign | ArrowAssign -> ( match @@ -1013,11 +1013,14 @@ module Make (StdLibrary : Std_library_utils.Library) : Typechecker = struct rhs.emeta.type_ with | Ok p -> Promotion.promote rhs p - | Error _ -> err Operator.Equals |> error ) + | Error _ -> err Operator.Equals [] |> error ) | OperatorAssign op -> ( let args = List.map ~f:arg_type [Ast.expr_of_lvalue lhs; rhs] in let return_type = assignmentoperator_return_type op args in - match return_type with Some Void -> rhs | _ -> err op |> error ) + match return_type with + | Some Void -> rhs + | _ -> + err op (StdLibrary.get_assignment_operator_signatures op) |> error ) let check_lvalue cf tenv = function | {lval= LVariable id; lmeta= ({loc} : located_meta)} -> @@ -1497,7 +1500,7 @@ module Make (StdLibrary : Std_library_utils.Library) : Typechecker = struct | Ok p -> Some (Promotion.promote rhs p) | Error _ -> Semantic_error.illtyped_assignment loc Equals lhs.lmeta.type_ - rhs.emeta.type_ + rhs.emeta.type_ [] |> error ) | None -> None diff --git a/src/frontend/Typechecker.mli b/src/frontend/Typechecking.mli similarity index 100% rename from src/frontend/Typechecker.mli rename to src/frontend/Typechecking.mli diff --git a/src/stanc/stanc.ml b/src/stanc/stanc.ml index 6583ae4f58..a67c90e220 100644 --- a/src/stanc/stanc.ml +++ b/src/stanc/stanc.ml @@ -7,6 +7,14 @@ open Stan_math_backend open Analysis_and_optimization open Middle +(* Initialize functor modules with the Stan Math Library *) +module CppLibrary = Std_library_utils.NullLibrary +module Typechecker = Typechecking.Make (CppLibrary) +module Deprecations = Deprecation_analysis.Make (CppLibrary) +module Canonicalizer = Canonicalize.Make (Deprecations) +module ModelInfo = Info.Make (CppLibrary) +module Ast2Mir = Ast_to_Mir.Make (CppLibrary) + (** The main program. *) let version = "%%NAME%%3 %%VERSION%%" @@ -142,7 +150,7 @@ let options = exit 0 ) , " Display stanc version number" ) ; ( "--name" - , Arg.Set_string Typechecker.model_name + , Arg.Set_string Typechecking.model_name , " Take a string to set the model name (default = \ \"$model_filename_model\")" ) ; ( "--O0" @@ -176,10 +184,10 @@ let options = , Arg.Set print_model_cpp , " If set, output the generated C++ Stan model class to stdout." ) ; ( "--allow-undefined" - , Arg.Clear Typechecker.check_that_all_functions_have_definition + , Arg.Clear Typechecking.check_that_all_functions_have_definition , " Do not fail if a function is declared but not defined" ) ; ( "--allow_undefined" - , Arg.Clear Typechecker.check_that_all_functions_have_definition + , Arg.Clear Typechecking.check_that_all_functions_have_definition , " Deprecated. Same as --allow-undefined. Will be removed in Stan 2.32.0" ) ; ( "--include-paths" @@ -278,30 +286,30 @@ let use_file filename = ~print_warnings:(not !canonicalize_settings.deprecations) ~bare_functions:!bare_functions in (* must be before typecheck to fix up deprecated syntax which gets rejected *) - let ast = Canonicalize.repair_syntax ast !canonicalize_settings in + let ast = Canonicalizer.repair_syntax ast !canonicalize_settings in Debugging.ast_logger ast ; let typed_ast = type_ast_or_exit ast in let canonical_ast = - Canonicalize.canonicalize_program typed_ast !canonicalize_settings in + Canonicalizer.canonicalize_program typed_ast !canonicalize_settings in if !pretty_print_program then print_or_write (Pretty_printing.pretty_print_typed_program ~bare_functions:!bare_functions ~line_length:!pretty_print_line_length ~inline_includes:!canonicalize_settings.inline_includes canonical_ast ) ; if !print_info_json then ( - print_endline (Info.info canonical_ast) ; + print_endline (ModelInfo.info canonical_ast) ; exit 0 ) ; let printed_filename = match !filename_for_msg with "" -> None | s -> Some s in if not !canonicalize_settings.deprecations then Warnings.pp_warnings Fmt.stderr ?printed_filename - (Deprecation_analysis.collect_warnings typed_ast) ; + (Deprecations.collect_warnings typed_ast) ; if !generate_data then print_endline - (Debug_data_generation.print_data_prog (Ast_to_Mir.gather_data typed_ast)) ; + (Debug_data_generation.print_data_prog (Ast2Mir.gather_data typed_ast)) ; Debugging.typed_ast_logger typed_ast ; if not !pretty_print_program then ( - let mir = Ast_to_Mir.trans_prog filename typed_ast in + let mir = Ast2Mir.trans_prog filename typed_ast in if !dump_mir then Sexp.pp_hum Format.std_formatter [%sexp (mir : Middle.Program.Typed.t)] ; if !dump_mir_pretty then Program.Typed.pp Format.std_formatter mir ; @@ -356,12 +364,12 @@ let main () = Stan_math_code_gen.standalone_functions := true ; bare_functions := true ) ; (* Just translate a stan program *) - if !Typechecker.model_name = "" then - Typechecker.model_name := + if !Typechecking.model_name = "" then + Typechecking.model_name := mangle (remove_dotstan List.(hd_exn (rev (String.split !model_file ~on:'/')))) ^ "_model" - else Typechecker.model_name := mangle !Typechecker.model_name ; + else Typechecking.model_name := mangle !Typechecking.model_name ; use_file !model_file let () = main () From 5f0ac98561945b6c165c30c1b5b61286de4d48dc Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Wed, 27 Apr 2022 15:25:04 -0400 Subject: [PATCH 04/14] Trivially make stan math signatures comply to the new api --- src/frontend/Ast_to_Mir.mli | 13 +- src/frontend/Info.mli | 2 +- src/frontend/SignatureMismatch.mli | 4 +- src/frontend/Std_library_utils.ml | 4 + src/frontend/dune | 10 +- src/stan_math_backend/Expression_gen.ml | 28 +- src/stan_math_backend/Function_gen.ml | 6 +- src/stan_math_backend/Stan_math_signatures.ml | 2483 ----------------- .../Stan_math_signatures.mli | 75 - src/stan_math_backend/dune | 2 +- src/stanc/stanc.ml | 10 +- 11 files changed, 35 insertions(+), 2602 deletions(-) delete mode 100644 src/stan_math_backend/Stan_math_signatures.ml delete mode 100644 src/stan_math_backend/Stan_math_signatures.mli diff --git a/src/frontend/Ast_to_Mir.mli b/src/frontend/Ast_to_Mir.mli index 89b8374875..c045b9a5b0 100644 --- a/src/frontend/Ast_to_Mir.mli +++ b/src/frontend/Ast_to_Mir.mli @@ -1,15 +1,12 @@ (** Translate from the AST to the MIR *) open Middle - module type Ast_Mir_translator = sig + val gather_data : + Ast.typed_program + -> (Expr.Typed.t SizedType.t * Expr.Typed.t Transformation.t * string) list -val gather_data : - Ast.typed_program - -> (Expr.Typed.t SizedType.t * Expr.Typed.t Transformation.t * string) list - -val trans_prog : string -> Ast.typed_program -> Program.Typed.t - + val trans_prog : string -> Ast.typed_program -> Program.Typed.t end -module Make(StdLib:Std_library_utils.Library): Ast_Mir_translator +module Make (StdLib : Std_library_utils.Library) : Ast_Mir_translator diff --git a/src/frontend/Info.mli b/src/frontend/Info.mli index 9eaf3e576c..f9a7b57cda 100644 --- a/src/frontend/Info.mli +++ b/src/frontend/Info.mli @@ -20,4 +20,4 @@ module type Information = sig val info : Ast.typed_program -> string end -module Make(StdLib:Std_library_utils.Library): Information +module Make (StdLib : Std_library_utils.Library) : Information diff --git a/src/frontend/SignatureMismatch.mli b/src/frontend/SignatureMismatch.mli index ab373f4222..b1565ab744 100644 --- a/src/frontend/SignatureMismatch.mli +++ b/src/frontend/SignatureMismatch.mli @@ -81,9 +81,7 @@ val pp_signature_mismatch : -> unit val pp_assignmentoperator_sigs : - Format.formatter - -> UnsizedType.t* Std_library_utils.signature list - -> unit + Format.formatter -> UnsizedType.t * Std_library_utils.signature list -> unit val compare_errors : function_mismatch -> function_mismatch -> int val compare_match_results : match_result -> match_result -> int diff --git a/src/frontend/Std_library_utils.ml b/src/frontend/Std_library_utils.ml index 58f724a909..09c2989dfe 100644 --- a/src/frontend/Std_library_utils.ml +++ b/src/frontend/Std_library_utils.ml @@ -46,6 +46,10 @@ module type Library = sig end module NullLibrary : Library = struct + (** A "standard library" for stan which contains no functions. + Useful only for testing + *) + let function_signatures : (string, signature list) Hashtbl.t = String.Table.create () diff --git a/src/frontend/dune b/src/frontend/dune index f5e0020097..0359a63c1d 100644 --- a/src/frontend/dune +++ b/src/frontend/dune @@ -1,15 +1,7 @@ (library (name frontend) (public_name stanc.frontend) - (libraries - core_kernel - re - menhirLib - fmt - middle - common - yojson - stan_math_backend) + (libraries core_kernel re menhirLib fmt middle common yojson) (inline_tests) (preprocess (pps ppx_jane ppx_deriving.fold ppx_deriving.map))) diff --git a/src/stan_math_backend/Expression_gen.ml b/src/stan_math_backend/Expression_gen.ml index 44fda9ddee..d7d511f5b5 100644 --- a/src/stan_math_backend/Expression_gen.ml +++ b/src/stan_math_backend/Expression_gen.ml @@ -138,11 +138,9 @@ let variadic_dae_functor_suffix = "_daefunctor__" let functor_suffix_select hof = match hof with - | x when Stan_math_signatures.is_reduce_sum_fn x -> reduce_sum_functor_suffix - | x when Stan_math_signatures.is_variadic_ode_fn x -> - variadic_ode_functor_suffix - | x when Stan_math_signatures.is_variadic_dae_fn x -> - variadic_dae_functor_suffix + | x when Stan_math_library.is_reduce_sum_fn x -> reduce_sum_functor_suffix + | x when Stan_math_library.is_variadic_ode_fn x -> variadic_ode_functor_suffix + | x when Stan_math_library.is_variadic_dae_fn x -> variadic_dae_functor_suffix | _ -> functor_suffix let constraint_to_string = function @@ -334,7 +332,7 @@ and gen_functionals fname suffix es mem_pattern = | ( x , {pattern= FunApp ((UserDefined (f, _) | StanLib (f, _, _)), _); _} :: grainsize :: container :: tl ) - when Stan_math_signatures.is_reduce_sum_fn x -> + when Stan_math_library.is_reduce_sum_fn x -> let chop_functor_suffix = String.chop_suffix_exn ~suffix:reduce_sum_functor_suffix in let propto_template = @@ -349,16 +347,16 @@ and gen_functionals fname suffix es mem_pattern = ( strf "%s<%s%s>" fname normalized_dist_functor propto_template , grainsize :: container :: msgs :: tl ) | x, f :: y0 :: t0 :: ts :: rel_tol :: abs_tol :: max_steps :: tl - when Stan_math_signatures.is_variadic_ode_fn x + when Stan_math_library.is_variadic_ode_fn x && String.is_suffix fname - ~suffix:Stan_math_signatures.ode_tolerances_suffix - && not (Stan_math_signatures.variadic_ode_adjoint_fn = x) -> + ~suffix:Stan_math_library.ode_tolerances_suffix + && not (Stan_math_library.variadic_ode_adjoint_fn = x) -> ( fname , f :: y0 :: t0 :: ts :: rel_tol :: abs_tol :: max_steps :: msgs :: tl ) | x, f :: y0 :: t0 :: ts :: tl - when Stan_math_signatures.is_variadic_ode_fn x - && not (Stan_math_signatures.variadic_ode_adjoint_fn = x) -> + when Stan_math_library.is_variadic_ode_fn x + && not (Stan_math_library.variadic_ode_adjoint_fn = x) -> (fname, f :: y0 :: t0 :: ts :: msgs :: tl) | ( x , f @@ -375,7 +373,7 @@ and gen_functionals fname suffix es mem_pattern = :: num_checkpoints :: interpolation_polynomial :: solver_f :: solver_b :: tl ) - when Stan_math_signatures.variadic_ode_adjoint_fn = x -> + when Stan_math_library.variadic_ode_adjoint_fn = x -> ( fname , f :: y0 :: t0 :: ts :: rel_tol :: abs_tol :: rel_tol_b :: abs_tol_b :: rel_tol_q :: abs_tol_q :: max_num_steps @@ -384,14 +382,14 @@ and gen_functionals fname suffix es mem_pattern = | ( x , f :: yy0 :: yp0 :: t0 :: ts :: rel_tol :: abs_tol :: max_steps :: tl ) - when Stan_math_signatures.is_variadic_dae_fn x + when Stan_math_library.is_variadic_dae_fn x && String.is_suffix fname - ~suffix:Stan_math_signatures.dae_tolerances_suffix -> + ~suffix:Stan_math_library.dae_tolerances_suffix -> ( fname , f :: yy0 :: yp0 :: t0 :: ts :: rel_tol :: abs_tol :: max_steps :: msgs :: tl ) | x, f :: yy0 :: yp0 :: t0 :: ts :: tl - when Stan_math_signatures.is_variadic_dae_fn x -> + when Stan_math_library.is_variadic_dae_fn x -> (fname, f :: yy0 :: yp0 :: t0 :: ts :: msgs :: tl) | ( "map_rect" , {pattern= FunApp ((UserDefined (f, _) | StanLib (f, _, _)), _); _} diff --git a/src/stan_math_backend/Function_gen.ml b/src/stan_math_backend/Function_gen.ml index 38bfc0274d..bbee5bffbb 100644 --- a/src/stan_math_backend/Function_gen.ml +++ b/src/stan_math_backend/Function_gen.ml @@ -406,11 +406,11 @@ let collect_functors_functions (p : Program.Numbered.t) = String.Table.create () in let forward_decls = Hash_set.Poly.create () in let reduce_sum_fns = - is_fun_used_with_variadic_fn Stan_math_signatures.is_reduce_sum_fn p in + is_fun_used_with_variadic_fn Stan_math_library.is_reduce_sum_fn p in let variadic_ode_fns = - is_fun_used_with_variadic_fn Stan_math_signatures.is_variadic_ode_fn p in + is_fun_used_with_variadic_fn Stan_math_library.is_variadic_ode_fn p in let variadic_dae_fns = - is_fun_used_with_variadic_fn Stan_math_signatures.is_variadic_dae_fn p in + is_fun_used_with_variadic_fn Stan_math_library.is_variadic_dae_fn p in let pp_fun_def_with_variadic_fn_list ppf fblock = (hovbox ~indent:2 pp_fun_def) ppf diff --git a/src/stan_math_backend/Stan_math_signatures.ml b/src/stan_math_backend/Stan_math_signatures.ml deleted file mode 100644 index 0bf5de07a2..0000000000 --- a/src/stan_math_backend/Stan_math_signatures.ml +++ /dev/null @@ -1,2483 +0,0 @@ -(** The signatures of the Stan Math library, which are used for type checking *) - -open Core_kernel -open Core_kernel.Poly -open Middle - -(** The "dimensionality" (bad name?) is supposed to help us represent the - vectorized nature of many Stan functions. It allows us to represent when - a function argument can be just a real or matrix, or some common forms of - vectorization over reals. This captures the most commonly used forms in our - previous signatures; there are a lot partially because we had a lot of - inconsistencies. -*) -type dimensionality = - | DInt - | DReal - | DVector - | DMatrix - | DIntArray - (* Vectorizable int *) - | DVInt - (* Vectorizable real *) - | DVReal - (* DEPRECATED; vectorizable ints or reals *) - | DIntAndReals - (* Vectorizable vectors - for multivariate functions *) - | DVectors - | DDeepVectorized - -(* all base types with up 8 levels of nested containers - - just used for element-wise vectorized unary functions now *) - -let rec bare_array_type (t, i) = - match i with 0 -> t | j -> UnsizedType.UArray (bare_array_type (t, j - 1)) - -let rec expand_arg = function - | DInt -> [UnsizedType.UInt] - | DReal -> [UReal] - | DVector -> [UVector] - | DMatrix -> [UMatrix] - | DIntArray -> [UArray UInt] - | DVInt -> [UInt; UArray UInt] - | DVReal -> [UReal; UArray UReal; UVector; URowVector] - | DIntAndReals -> expand_arg DVReal @ expand_arg DVInt - | DVectors -> [UVector; UArray UVector; URowVector; UArray URowVector] - | DDeepVectorized -> - let all_base = [UnsizedType.UInt; UReal; URowVector; UVector; UMatrix] in - List.( - concat_map all_base ~f:(fun a -> - map (range 0 8) ~f:(fun i -> bare_array_type (a, i)) )) - -type fkind = Lpmf | Lpdf | Rng | Cdf | Ccdf | UnaryVectorized -[@@deriving show {with_path= false}] - -type fun_arg = UnsizedType.autodifftype * UnsizedType.t - -type signature = - UnsizedType.returntype * fun_arg list * Common.Helpers.mem_pattern - -let is_primitive = function - | UnsizedType.UReal -> true - | UInt -> true - | _ -> false - -(** The signatures hash table *) -let (stan_math_signatures : (string, signature list) Hashtbl.t) = - String.Table.create () - -(** All of the signatures that are added by hand, rather than the ones - added "declaratively" *) -let (manual_stan_math_signatures : (string, signature list) Hashtbl.t) = - String.Table.create () - -(* XXX The correct word here isn't combination - what is it? *) -let all_combinations xx = - List.fold_right xx ~init:[[]] ~f:(fun x accum -> - List.concat_map accum ~f:(fun acc -> - List.map ~f:(fun arg -> arg :: acc) x ) ) - -let%expect_test "combinations " = - let a = all_combinations [[1; 2]; [3; 4]; [5; 6]] in - [%sexp (a : int list list)] |> Sexp.to_string_hum |> print_endline ; - [%expect - {| ((1 3 5) (2 3 5) (1 4 5) (2 4 5) (1 3 6) (2 3 6) (1 4 6) (2 4 6)) |}] - -let missing_math_functions = - String.Set.of_list - ["beta_proportion_cdf"; "loglogistic_lcdf"; "loglogistic_cdf_log"] - -let rng_return_type t lt = - if List.for_all ~f:is_primitive lt then t else UnsizedType.UArray t - -let add_unqualified (name, rt, uqargts, mem_pattern) = - Hashtbl.add_multi manual_stan_math_signatures ~key:name - ~data: - ( rt - , List.map ~f:(fun x -> (UnsizedType.AutoDiffable, x)) uqargts - , mem_pattern ) - -let rec ints_to_real unsized = - match unsized with - | UnsizedType.UInt -> UnsizedType.UReal - | UArray t -> UArray (ints_to_real t) - | x -> x - -let rec complex_to_real = function - | UnsizedType.UComplex -> UnsizedType.UReal - | UComplexVector -> UVector - | UComplexRowVector -> URowVector - | UComplexMatrix -> UMatrix - | UArray t -> UArray (complex_to_real t) - | x -> x - -let reduce_sum_allowed_dimensionalities = [1; 2; 3; 4; 5; 6; 7] - -let reduce_sum_slice_types = - let base_slice_type i = - [ bare_array_type (UnsizedType.UReal, i) - ; bare_array_type (UnsizedType.UInt, i) - ; bare_array_type (UnsizedType.UMatrix, i) - ; bare_array_type (UnsizedType.UVector, i) - ; bare_array_type (UnsizedType.URowVector, i) ] in - List.concat (List.map ~f:base_slice_type reduce_sum_allowed_dimensionalities) - -(* Variadic ODE *) -let variadic_ode_adjoint_ctl_tol_arg_types = - [ (UnsizedType.DataOnly, UnsizedType.UReal) - (* real relative_tolerance_forward *) - ; (DataOnly, UVector) (* vector absolute_tolerance_forward *) - ; (DataOnly, UReal) (* real relative_tolerance_backward *) - ; (DataOnly, UVector) (* real absolute_tolerance_backward *) - ; (DataOnly, UReal) (* real relative_tolerance_quadrature *) - ; (DataOnly, UReal) (* real absolute_tolerance_quadrature *) - ; (DataOnly, UInt) (* int max_num_steps *) - ; (DataOnly, UInt) (* int num_steps_between_checkpoints *) - ; (DataOnly, UInt) (* int interpolation_polynomial *) - ; (DataOnly, UInt) (* int solver_forward *); (DataOnly, UInt) - (* int solver_backward *) ] - -let variadic_ode_tol_arg_types = - [ (UnsizedType.DataOnly, UnsizedType.UReal); (DataOnly, UReal) - ; (DataOnly, UInt) ] - -let variadic_ode_mandatory_arg_types = - [ (UnsizedType.AutoDiffable, UnsizedType.UVector); (AutoDiffable, UReal) - ; (AutoDiffable, UArray UReal) ] - -let variadic_ode_mandatory_fun_args = - [ (UnsizedType.AutoDiffable, UnsizedType.UReal) - ; (UnsizedType.AutoDiffable, UnsizedType.UVector) ] - -let variadic_ode_fun_return_type = UnsizedType.UVector -let variadic_ode_return_type = UnsizedType.UArray UnsizedType.UVector - -let variadic_dae_tol_arg_types = - [ (UnsizedType.DataOnly, UnsizedType.UReal); (DataOnly, UReal) - ; (DataOnly, UInt) ] - -let variadic_dae_mandatory_arg_types = - [ (UnsizedType.AutoDiffable, UnsizedType.UVector); (* yy *) - (UnsizedType.AutoDiffable, UnsizedType.UVector); (* yp *) - (AutoDiffable, UReal); (AutoDiffable, UArray UReal) ] - -let variadic_dae_mandatory_fun_args = - [ (UnsizedType.AutoDiffable, UnsizedType.UReal) - ; (UnsizedType.AutoDiffable, UnsizedType.UVector) - ; (UnsizedType.AutoDiffable, UnsizedType.UVector) ] - -let variadic_dae_fun_return_type = UnsizedType.UVector -let variadic_dae_return_type = UnsizedType.UArray UnsizedType.UVector - -let mk_declarative_sig (fnkinds, name, args, mem_pattern) = - let is_glm = String.is_suffix ~suffix:"_glm" name in - let sfxes = function - | Lpmf when is_glm -> ["_lpmf"] - | Lpmf -> ["_lpmf"; "_log"] - | Lpdf when is_glm -> ["_lpdf"] - | Lpdf -> ["_lpdf"; "_log"] - | Rng -> ["_rng"] - | Cdf -> ["_cdf"; "_cdf_log"; "_lcdf"] - | Ccdf -> ["_ccdf_log"; "_lccdf"] - | UnaryVectorized -> [""] in - let add_ints = function DVReal -> DIntAndReals | x -> x in - let all_expanded args = all_combinations (List.map ~f:expand_arg args) in - let promoted_dim = function - | DInt | DIntArray | DVInt -> UnsizedType.UInt - (* XXX fix this up to work with more RNGs *) - | _ -> UReal in - let find_rt rt args = function - | Rng -> UnsizedType.ReturnType (rng_return_type rt args) - | UnaryVectorized -> ReturnType (ints_to_real (List.hd_exn args)) - | _ -> ReturnType UReal in - let create_from_fk_args fk arglists = - List.concat_map arglists ~f:(fun args -> - List.map (sfxes fk) ~f:(fun sfx -> - (name ^ sfx, find_rt UReal args fk, args, mem_pattern) ) ) in - let add_fnkind = function - | Rng -> - let rt, args = (List.hd_exn args, List.tl_exn args) in - let args = List.map ~f:add_ints args in - let rt = promoted_dim rt in - let name = name ^ "_rng" in - List.map (all_expanded args) ~f:(fun args -> - (name, find_rt rt args Rng, args, mem_pattern) ) - | UnaryVectorized -> create_from_fk_args UnaryVectorized (all_expanded args) - | fk -> create_from_fk_args fk (all_expanded args) in - List.concat_map fnkinds ~f:add_fnkind - |> List.filter ~f:(fun (n, _, _, _) -> not (Set.mem missing_math_functions n)) - |> List.map ~f:(fun (n, rt, args, support_soa) -> - ( n - , rt - , List.map ~f:(fun x -> (UnsizedType.AutoDiffable, x)) args - , support_soa ) ) - -let full_lpdf = [Lpdf; Rng; Ccdf; Cdf] -let full_lpmf = [Lpmf; Rng; Ccdf; Cdf] -let reduce_sum_functions = String.Set.of_list ["reduce_sum"; "reduce_sum_static"] -let variadic_ode_adjoint_fn = "ode_adjoint_tol_ctl" - -let variadic_ode_nonadjoint_fns = - String.Set.of_list - [ "ode_bdf_tol"; "ode_rk45_tol"; "ode_adams_tol"; "ode_bdf"; "ode_rk45" - ; "ode_adams"; "ode_ckrk"; "ode_ckrk_tol" ] - -let ode_tolerances_suffix = "_tol" -let is_reduce_sum_fn f = Set.mem reduce_sum_functions f -let is_variadic_ode_nonadjoint_fn f = Set.mem variadic_ode_nonadjoint_fns f - -let is_variadic_ode_fn f = - Set.mem variadic_ode_nonadjoint_fns f || f = variadic_ode_adjoint_fn - -let is_variadic_ode_nonadjoint_tol_fn f = - is_variadic_ode_nonadjoint_fn f - && String.is_suffix f ~suffix:ode_tolerances_suffix - -let variadic_dae_fns = String.Set.of_list ["dae_tol"; "dae"] -let dae_tolerances_suffix = "_tol" -let is_variadic_dae_fn f = Set.mem variadic_dae_fns f - -let is_variadic_dae_tol_fn f = - is_variadic_dae_fn f && String.is_suffix f ~suffix:dae_tolerances_suffix - -let distributions = - [ ( full_lpmf - , "beta_binomial" - , [DVInt; DVInt; DVReal; DVReal] - , Common.Helpers.SoA ); (full_lpdf, "beta", [DVReal; DVReal; DVReal], SoA) - ; ([Lpdf; Ccdf; Cdf], "beta_proportion", [DVReal; DVReal; DIntAndReals], SoA) - ; (full_lpmf, "bernoulli", [DVInt; DVReal], SoA) - ; ([Lpmf; Rng], "bernoulli_logit", [DVInt; DVReal], SoA) - ; ([Lpmf], "bernoulli_logit_glm", [DVInt; DMatrix; DReal; DVector], SoA) - ; (full_lpmf, "binomial", [DVInt; DVInt; DVReal], SoA) - ; ([Lpmf], "binomial_logit", [DVInt; DVInt; DVReal], SoA) - ; ([Lpmf], "categorical", [DVInt; DVector], AoS) - ; ([Lpmf], "categorical_logit", [DVInt; DVector], AoS) - ; ([Lpmf], "categorical_logit_glm", [DVInt; DMatrix; DVector; DMatrix], SoA) - ; (full_lpdf, "cauchy", [DVReal; DVReal; DVReal], SoA) - ; (full_lpdf, "chi_square", [DVReal; DVReal], SoA) - ; ([Lpdf], "dirichlet", [DVectors; DVectors], SoA) - ; (full_lpmf, "discrete_range", [DVInt; DVInt; DVInt], SoA) - ; (full_lpdf, "double_exponential", [DVReal; DVReal; DVReal], SoA) - ; (full_lpdf, "exp_mod_normal", [DVReal; DVReal; DVReal; DVReal], SoA) - ; (full_lpdf, "exponential", [DVReal; DVReal], SoA) - ; (full_lpdf, "frechet", [DVReal; DVReal; DVReal], SoA) - ; (full_lpdf, "gamma", [DVReal; DVReal; DVReal], SoA) - ; ( [Lpdf] - , "gaussian_dlm_obs" - , [DMatrix; DMatrix; DMatrix; DMatrix; DMatrix; DVector; DMatrix] - , AoS ); (full_lpdf, "gumbel", [DVReal; DVReal; DVReal], SoA) - ; ([Rng], "hmm_latent", [DIntArray; DMatrix; DMatrix; DVector], AoS) - ; ([Lpmf; Rng], "hypergeometric", [DInt; DInt; DInt; DInt], SoA) - ; (full_lpdf, "inv_chi_square", [DVReal; DVReal], SoA) - ; (full_lpdf, "inv_gamma", [DVReal; DVReal; DVReal], SoA) - ; ([Lpdf], "inv_wishart", [DMatrix; DReal; DMatrix], SoA) - ; ([Lpdf], "lkj_corr", [DMatrix; DReal], AoS) - ; ([Lpdf], "lkj_corr_cholesky", [DMatrix; DReal], AoS) - ; (full_lpdf, "logistic", [DVReal; DVReal; DVReal], SoA) - ; ([Lpdf; Rng; Cdf], "loglogistic", [DVReal; DVReal; DVReal], SoA) - ; (full_lpdf, "lognormal", [DVReal; DVReal; DVReal], SoA) - ; ([Lpdf], "multi_gp", [DMatrix; DMatrix; DVector], AoS) - ; ([Lpdf], "multi_gp_cholesky", [DMatrix; DMatrix; DVector], AoS) - ; ([Lpmf], "multinomial", [DIntArray; DVector], AoS) - ; ([Lpmf], "multinomial_logit", [DIntArray; DVector], AoS) - ; ([Lpdf], "multi_normal", [DVectors; DVectors; DMatrix], AoS) - ; ([Lpdf], "multi_normal_cholesky", [DVectors; DVectors; DMatrix], AoS) - ; ([Lpdf], "multi_normal_prec", [DVectors; DVectors; DMatrix], AoS) - ; ([Lpdf], "multi_student_t", [DVectors; DReal; DVectors; DMatrix], AoS) - ; (full_lpmf, "neg_binomial", [DVInt; DVReal; DVReal], SoA) - ; (full_lpmf, "neg_binomial_2", [DVInt; DVReal; DVReal], SoA) - ; ([Lpmf; Rng], "neg_binomial_2_log", [DVInt; DVReal; DVReal], SoA) - ; ( [Lpmf] - , "neg_binomial_2_log_glm" - , [DVInt; DMatrix; DReal; DVector; DReal] - , SoA ); (full_lpdf, "normal", [DVReal; DVReal; DVReal], SoA) - ; ([Lpdf], "normal_id_glm", [DVector; DMatrix; DReal; DVector; DReal], SoA) - ; ([Lpmf], "ordered_logistic", [DInt; DReal; DVector], SoA) - ; ([Lpmf], "ordered_logistic_glm", [DVInt; DMatrix; DVector; DVector], SoA) - ; ([Lpmf], "ordered_probit", [DInt; DReal; DVector], SoA) - ; (full_lpdf, "pareto", [DVReal; DVReal; DVReal], SoA) - ; (full_lpdf, "pareto_type_2", [DVReal; DVReal; DVReal; DVReal], SoA) - ; (full_lpmf, "poisson", [DVInt; DVReal], SoA) - ; ([Lpmf; Rng], "poisson_log", [DVInt; DVReal], SoA) - ; ([Lpmf], "poisson_log_glm", [DVInt; DMatrix; DReal; DVector], SoA) - ; (full_lpdf, "rayleigh", [DVReal; DVReal], SoA) - ; (full_lpdf, "scaled_inv_chi_square", [DVReal; DVReal; DVReal], SoA) - ; (full_lpdf, "skew_normal", [DVReal; DVReal; DVReal; DVReal], SoA) - ; (full_lpdf, "skew_double_exponential", [DVReal; DVReal; DVReal; DVReal], SoA) - ; (full_lpdf, "student_t", [DVReal; DVReal; DVReal; DVReal], SoA) - ; (full_lpdf, "std_normal", [DVReal], SoA) - ; (full_lpdf, "uniform", [DVReal; DVReal; DVReal], SoA) - ; (full_lpdf, "von_mises", [DVReal; DVReal; DVReal], SoA) - ; (full_lpdf, "weibull", [DVReal; DVReal; DVReal], SoA) - ; ([Lpdf], "wiener", [DVReal; DVReal; DVReal; DVReal; DVReal], SoA) - ; ([Lpdf], "wishart", [DMatrix; DReal; DMatrix], SoA) ] - -let math_sigs = - [ ([UnaryVectorized], "acos", [DDeepVectorized], Common.Helpers.SoA) - ; ([UnaryVectorized], "acosh", [DDeepVectorized], SoA) - ; ([UnaryVectorized], "asin", [DDeepVectorized], SoA) - ; ([UnaryVectorized], "asinh", [DDeepVectorized], SoA) - ; ([UnaryVectorized], "atan", [DDeepVectorized], SoA) - ; ([UnaryVectorized], "atanh", [DDeepVectorized], SoA) - ; ([UnaryVectorized], "cbrt", [DDeepVectorized], SoA) - ; ([UnaryVectorized], "ceil", [DDeepVectorized], SoA) - ; ([UnaryVectorized], "cos", [DDeepVectorized], SoA) - ; ([UnaryVectorized], "cosh", [DDeepVectorized], SoA) - ; ([UnaryVectorized], "digamma", [DDeepVectorized], SoA) - ; ([UnaryVectorized], "erf", [DDeepVectorized], SoA) - ; ([UnaryVectorized], "erfc", [DDeepVectorized], SoA) - ; ([UnaryVectorized], "exp", [DDeepVectorized], SoA) - ; ([UnaryVectorized], "exp2", [DDeepVectorized], SoA) - ; ([UnaryVectorized], "expm1", [DDeepVectorized], SoA) - ; ([UnaryVectorized], "fabs", [DDeepVectorized], SoA) - ; ([UnaryVectorized], "floor", [DDeepVectorized], SoA) - ; ([UnaryVectorized], "inv", [DDeepVectorized], SoA) - ; ([UnaryVectorized], "inv_cloglog", [DDeepVectorized], SoA) - ; ([UnaryVectorized], "inv_erfc", [DDeepVectorized], SoA) - ; ([UnaryVectorized], "inv_logit", [DDeepVectorized], SoA) - ; ([UnaryVectorized], "inv_Phi", [DDeepVectorized], SoA) - ; ([UnaryVectorized], "inv_sqrt", [DDeepVectorized], SoA) - ; ([UnaryVectorized], "inv_square", [DDeepVectorized], SoA) - ; ([UnaryVectorized], "lambert_w0", [DDeepVectorized], SoA) - ; ([UnaryVectorized], "lambert_wm1", [DDeepVectorized], SoA) - ; ([UnaryVectorized], "lgamma", [DDeepVectorized], SoA) - ; ([UnaryVectorized], "log", [DDeepVectorized], SoA) - ; ([UnaryVectorized], "log10", [DDeepVectorized], SoA) - ; ([UnaryVectorized], "log1m", [DDeepVectorized], SoA) - ; ([UnaryVectorized], "log1m_exp", [DDeepVectorized], SoA) - ; ([UnaryVectorized], "log1m_inv_logit", [DDeepVectorized], SoA) - ; ([UnaryVectorized], "log1p", [DDeepVectorized], SoA) - ; ([UnaryVectorized], "log1p_exp", [DDeepVectorized], SoA) - ; ([UnaryVectorized], "log2", [DDeepVectorized], SoA) - ; ([UnaryVectorized], "log_inv_logit", [DDeepVectorized], SoA) - ; ([UnaryVectorized], "logit", [DDeepVectorized], SoA) - ; ([UnaryVectorized], "Phi", [DDeepVectorized], SoA) - ; ([UnaryVectorized], "Phi_approx", [DDeepVectorized], SoA) - ; ([UnaryVectorized], "round", [DDeepVectorized], SoA) - ; ([UnaryVectorized], "sin", [DDeepVectorized], SoA) - ; ([UnaryVectorized], "sinh", [DDeepVectorized], SoA) - ; ([UnaryVectorized], "sqrt", [DDeepVectorized], SoA) - ; ([UnaryVectorized], "square", [DDeepVectorized], SoA) - ; ([UnaryVectorized], "step", [DReal], SoA) - ; ([UnaryVectorized], "tan", [DDeepVectorized], SoA) - ; ([UnaryVectorized], "tanh", [DDeepVectorized], SoA) - (* ; add_nullary ("target") *) - ; ([UnaryVectorized], "tgamma", [DDeepVectorized], SoA) - ; ([UnaryVectorized], "trunc", [DDeepVectorized], SoA) - ; ([UnaryVectorized], "trigamma", [DDeepVectorized], SoA) ] - -let all_declarative_sigs = distributions @ math_sigs - -let declarative_fnsigs = - List.concat_map ~f:mk_declarative_sig all_declarative_sigs - -let is_stan_math_function_name name = - let name = Utils.stdlib_distribution_name name in - Hashtbl.mem stan_math_signatures name - -let dist_name_suffix udf_names name = - let is_udf_name s = List.exists ~f:(fun (n, _) -> n = s) udf_names in - Utils.distribution_suffices - |> List.filter ~f:(fun sfx -> - is_stan_math_function_name (name ^ sfx) || is_udf_name (name ^ sfx) ) - |> List.hd_exn - -let operator_to_stan_math_fns op = - match op with - | Operator.Plus -> ["add"] - | PPlus -> ["plus"] - | Minus -> ["subtract"] - | PMinus -> ["minus"] - | Times -> ["multiply"] - | Divide -> ["mdivide_right"; "divide"] - | Modulo -> ["modulus"] - | IntDivide -> [] - | LDivide -> ["mdivide_left"] - | EltTimes -> ["elt_multiply"] - | EltDivide -> ["elt_divide"] - | Pow -> ["pow"] - | EltPow -> ["pow"] - | Or -> ["logical_or"] - | And -> ["logical_and"] - | Equals -> ["logical_eq"] - | NEquals -> ["logical_neq"] - | Less -> ["logical_lt"] - | Leq -> ["logical_lte"] - | Greater -> ["logical_gt"] - | Geq -> ["logical_gte"] - | PNot -> ["logical_negation"] - | Transpose -> ["transpose"] - -let get_sigs name = - let name = Utils.stdlib_distribution_name name in - Hashtbl.find_multi stan_math_signatures name |> List.sort ~compare - -let make_assignmentoperator_stan_math_signatures assop = - ( match assop with - | Operator.Divide -> ["divide"] - | assop -> operator_to_stan_math_fns assop ) - |> List.concat_map ~f:get_sigs - |> List.concat_map ~f:(function - | ReturnType rtype, [(ad1, lhs); (ad2, rhs)], _ - when rtype = lhs - && not - ( (assop = Operator.EltTimes || assop = Operator.EltDivide) - && UnsizedType.is_scalar_type rtype ) -> - if rhs = UReal then - [ (UnsizedType.Void, [(ad1, lhs); (ad2, UInt)], Common.Helpers.SoA) - ; (Void, [(ad1, lhs); (ad2, UReal)], SoA) ] - else [(Void, [(ad1, lhs); (ad2, rhs)], SoA)] - | _ -> [] ) - -let string_operator_to_stan_math_fns str = - match str with - | "Plus__" -> "add" - | "PPlus__" -> "plus" - | "Minus__" -> "subtract" - | "PMinus__" -> "minus" - | "Times__" -> "multiply" - | "Divide__" -> "divide" - | "Modulo__" -> "modulus" - | "IntDivide__" -> "divide" - | "LDivide__" -> "mdivide_left" - | "EltTimes__" -> "elt_multiply" - | "EltDivide__" -> "elt_divide" - | "Pow__" -> "pow" - | "EltPow__" -> "pow" - | "Or__" -> "logical_or" - | "And__" -> "logical_and" - | "Equals__" -> "logical_eq" - | "NEquals__" -> "logical_neq" - | "Less__" -> "logical_lt" - | "Leq__" -> "logical_lte" - | "Greater__" -> "logical_gt" - | "Geq__" -> "logical_gte" - | "PNot__" -> "logical_negation" - | "Transpose__" -> "transpose" - | _ -> str - -let pretty_print_all_math_sigs ppf () = - let open Fmt in - let pp_sig ppf (name, (rt, args, _)) = - pf ppf "%s(@[%a@]) => %a" name - (list ~sep:comma UnsizedType.pp) - (List.map ~f:snd args) UnsizedType.pp_returntype rt in - let pp_sigs_for_name ppf name = - (list ~sep:cut pp_sig) ppf - (List.map ~f:(fun t -> (name, t)) (get_sigs name)) in - pf ppf "@[%a@]" - (list ~sep:cut pp_sigs_for_name) - (List.sort ~compare (Hashtbl.keys stan_math_signatures)) - -let pretty_print_all_math_distributions ppf () = - let open Fmt in - let pp_dist ppf (kinds, name, _, _) = - pf ppf "@[%s: %a@]" name - (list ~sep:comma Fmt.string) - (List.map ~f:(Fn.compose String.lowercase show_fkind) kinds) in - pf ppf "@[%a@]" (list ~sep:cut pp_dist) distributions - -(* let int_divide_type = - UnsizedType. - ( ReturnType UInt - , [(AutoDiffable, UInt); (AutoDiffable, UInt)] - , Common.Helpers.AoS ) *) - -(* TODO turn into a get_sigs version - let pretty_print_math_lib_operator_sigs op = - if op = Operator.IntDivide then - [Fmt.str "@[@,%a@]" Std_library_utils.pp_math_sig int_divide_type] - else - operator_to_stan_math_fns op - |> List.map - ~f:(Fn.compose Std_library_utils.pretty_print_math_sigs get_sigs) *) - -(* -- Some helper definitions to populate stan_math_signatures -- *) -let bare_types = - [ UnsizedType.UInt; UReal; UComplex; UVector; URowVector; UMatrix - ; UComplexVector; UComplexRowVector; UComplexMatrix ] - -let vector_types = [UnsizedType.UReal; UArray UReal; UVector; URowVector] -let primitive_types = [UnsizedType.UInt; UReal] - -let complex_types = - [UnsizedType.UComplex; UComplexVector; UComplexRowVector; UComplexMatrix] - -let all_vector_types = - [UnsizedType.UReal; UArray UReal; UVector; URowVector; UInt; UArray UInt] - -let add_qualified (name, rt, argts, supports_soa) = - Hashtbl.add_multi stan_math_signatures ~key:name - ~data:(rt, argts, supports_soa) - -let add_nullary name = - add_unqualified (name, UnsizedType.ReturnType UReal, [], AoS) - -let add_binary name supports_soa = - add_unqualified - (name, ReturnType UReal, [UnsizedType.UReal; UReal], supports_soa) - -let add_binary_vec name supports_soa = - List.iter - ~f:(fun i -> - List.iter - ~f:(fun j -> - add_unqualified - (name, ReturnType (ints_to_real i), [i; j], supports_soa) ) - [UnsizedType.UInt; UReal] ) - [UnsizedType.UInt; UReal] ; - List.iter - ~f:(fun i -> - List.iter - ~f:(fun j -> - add_unqualified - ( name - , ReturnType (ints_to_real (bare_array_type (j, i))) - , [bare_array_type (j, i); bare_array_type (j, i)] - , supports_soa ) ) - [UnsizedType.UArray UInt; UArray UReal; UVector; URowVector; UMatrix] ) - (List.range 0 8) ; - List.iter - ~f:(fun i -> - List.iter - ~f:(fun j -> - List.iter - ~f:(fun k -> - add_unqualified - ( name - , ReturnType (ints_to_real (bare_array_type (k, j))) - , [bare_array_type (k, j); i] - , supports_soa ) ) - [UnsizedType.UArray UInt; UArray UReal; UVector; URowVector; UMatrix] - ) - (List.range 0 8) ) - [UnsizedType.UInt; UReal] ; - List.iter - ~f:(fun i -> - List.iter - ~f:(fun j -> - List.iter - ~f:(fun k -> - add_unqualified - ( name - , ReturnType (ints_to_real (bare_array_type (k, j))) - , [i; bare_array_type (k, j)] - , supports_soa ) ) - [UnsizedType.UArray UInt; UArray UReal; UVector; URowVector; UMatrix] - ) - (List.range 0 8) ) - [UnsizedType.UInt; UReal] - -let add_binary_vec_real_real name supports_soa = - add_binary name supports_soa ; - List.iter - ~f:(fun i -> - List.iter - ~f:(fun j -> - add_unqualified - ( name - , ReturnType (bare_array_type (j, i)) - , [bare_array_type (j, i); bare_array_type (j, i)] - , supports_soa ) ) - [UnsizedType.UArray UReal; UVector; URowVector; UMatrix] ) - (List.range 0 8) ; - List.iter - ~f:(fun i -> - List.iter - ~f:(fun j -> - List.iter - ~f:(fun k -> - add_unqualified - ( name - , ReturnType (bare_array_type (k, j)) - , [bare_array_type (k, j); i] - , supports_soa ) ) - [UnsizedType.UArray UReal; UVector; URowVector; UMatrix] ) - (List.range 0 8) ) - [UnsizedType.UReal] ; - List.iter - ~f:(fun i -> - List.iter - ~f:(fun j -> - List.iter - ~f:(fun k -> - add_unqualified - ( name - , ReturnType (bare_array_type (k, j)) - , [i; bare_array_type (k, j)] - , supports_soa ) ) - [UnsizedType.UArray UReal; UVector; URowVector; UMatrix] ) - (List.range 0 8) ) - [UnsizedType.UReal] - -let add_binary_vec_int_real name supports_soa = - List.iter - ~f:(fun i -> - List.iter - ~f:(fun j -> - add_unqualified - ( name - , ReturnType (bare_array_type (i, j)) - , [UInt; bare_array_type (i, j)] - , supports_soa ) ) - (List.range 0 8) ) - [UnsizedType.UArray UReal; UVector; URowVector; UMatrix] ; - List.iter - ~f:(fun i -> - List.iter - ~f:(fun j -> - add_unqualified - ( name - , ReturnType (bare_array_type (i, j)) - , [bare_array_type (UInt, j + 1); bare_array_type (i, j)] - , supports_soa ) ) - (List.range 0 8) ) - [UnsizedType.UArray UReal; UVector; URowVector] ; - List.iter - ~f:(fun i -> - add_unqualified - ( name - , ReturnType (bare_array_type (UMatrix, i)) - , [bare_array_type (UInt, i + 2); bare_array_type (UMatrix, i)] - , supports_soa ) ) - (List.range 0 8) ; - List.iter - ~f:(fun i -> - add_unqualified - ( name - , ReturnType (bare_array_type (UReal, i)) - , [bare_array_type (UInt, i); UReal] - , supports_soa ) ) - (List.range 0 8) - -let add_binary_vec_real_int name supports_soa = - List.iter - ~f:(fun i -> - List.iter - ~f:(fun j -> - add_unqualified - ( name - , ReturnType (bare_array_type (i, j)) - , [bare_array_type (i, j); UInt] - , supports_soa ) ) - (List.range 0 8) ) - [UnsizedType.UArray UReal; UVector; URowVector; UMatrix] ; - List.iter - ~f:(fun i -> - List.iter - ~f:(fun j -> - add_unqualified - ( name - , ReturnType (bare_array_type (i, j)) - , [bare_array_type (i, j); bare_array_type (UInt, j + 1)] - , supports_soa ) ) - (List.range 0 8) ) - [UnsizedType.UArray UReal; UVector; URowVector] ; - List.iter - ~f:(fun i -> - add_unqualified - ( name - , ReturnType (bare_array_type (UMatrix, i)) - , [bare_array_type (UMatrix, i); bare_array_type (UInt, i + 2)] - , supports_soa ) ) - (List.range 0 8) ; - List.iter - ~f:(fun i -> - add_unqualified - ( name - , ReturnType (bare_array_type (UReal, i)) - , [UReal; bare_array_type (UInt, i)] - , supports_soa ) ) - (List.range 0 8) - -let add_binary_vec_int_int name supports_soa = - List.iter - ~f:(fun i -> - add_unqualified - ( name - , ReturnType (bare_array_type (UInt, i)) - , [bare_array_type (UInt, i); UInt] - , supports_soa ) ) - (List.range 0 8) ; - List.iter - ~f:(fun i -> - add_unqualified - ( name - , ReturnType (bare_array_type (UInt, i)) - , [UInt; bare_array_type (UInt, i)] - , supports_soa ) ) - (List.range 1 8) ; - List.iter - ~f:(fun i -> - add_unqualified - ( name - , ReturnType (bare_array_type (UInt, i)) - , [bare_array_type (UInt, i); bare_array_type (UInt, i)] - , supports_soa ) ) - (List.range 1 8) - -let add_ternary name supports_soa = - add_unqualified (name, ReturnType UReal, [UReal; UReal; UReal], supports_soa) - -(*Adds functions that operate on matrix, double array and real types*) -let add_ternary_vec name supports_soa = - add_unqualified (name, ReturnType UReal, [UReal; UReal; UReal], supports_soa) ; - add_unqualified - (name, ReturnType UVector, [UVector; UReal; UReal], supports_soa) ; - add_unqualified - (name, ReturnType UVector, [UVector; UVector; UReal], supports_soa) ; - add_unqualified - (name, ReturnType UVector, [UVector; UReal; UVector], supports_soa) ; - add_unqualified - (name, ReturnType UVector, [UVector; UVector; UVector], supports_soa) ; - add_unqualified - (name, ReturnType UVector, [UReal; UVector; UReal], supports_soa) ; - add_unqualified - (name, ReturnType UVector, [UReal; UVector; UVector], supports_soa) ; - add_unqualified - (name, ReturnType UVector, [UReal; UReal; UVector], supports_soa) ; - add_unqualified - (name, ReturnType URowVector, [URowVector; UReal; UReal], supports_soa) ; - add_unqualified - (name, ReturnType URowVector, [URowVector; URowVector; UReal], supports_soa) ; - add_unqualified - (name, ReturnType URowVector, [URowVector; UReal; URowVector], supports_soa) ; - add_unqualified - ( name - , ReturnType URowVector - , [URowVector; URowVector; URowVector] - , supports_soa ) ; - add_unqualified - (name, ReturnType URowVector, [UReal; URowVector; UReal], supports_soa) ; - add_unqualified - (name, ReturnType URowVector, [UReal; URowVector; URowVector], supports_soa) ; - add_unqualified - (name, ReturnType URowVector, [UReal; UReal; URowVector], supports_soa) ; - add_unqualified - (name, ReturnType UMatrix, [UMatrix; UReal; UReal], supports_soa) ; - add_unqualified - (name, ReturnType UMatrix, [UMatrix; UMatrix; UReal], supports_soa) ; - add_unqualified - (name, ReturnType UMatrix, [UMatrix; UReal; UMatrix], supports_soa) ; - add_unqualified - (name, ReturnType UMatrix, [UMatrix; UMatrix; UMatrix], supports_soa) ; - add_unqualified - (name, ReturnType UMatrix, [UReal; UMatrix; UReal], supports_soa) ; - add_unqualified - (name, ReturnType UMatrix, [UReal; UMatrix; UMatrix], supports_soa) ; - add_unqualified - (name, ReturnType UMatrix, [UReal; UReal; UMatrix], supports_soa) - -let for_all_vector_types s = List.iter ~f:s all_vector_types -let for_vector_types s = List.iter ~f:s vector_types - -(* -- Start populating stan_math_signaturess -- *) -let () = - List.iter declarative_fnsigs ~f:(fun (key, rt, args, mem_pattern) -> - Hashtbl.add_multi stan_math_signatures ~key ~data:(rt, args, mem_pattern) ) ; - add_unqualified ("abs", ReturnType UInt, [UInt], SoA) ; - add_unqualified ("abs", ReturnType UReal, [UReal], SoA) ; - add_unqualified ("abs", ReturnType UReal, [UComplex], AoS) ; - add_unqualified ("acos", ReturnType UComplex, [UComplex], AoS) ; - add_unqualified ("acosh", ReturnType UComplex, [UComplex], AoS) ; - List.iter - ~f:(fun x -> add_unqualified ("add", ReturnType x, [x; x], SoA)) - bare_types ; - add_unqualified ("add", ReturnType UVector, [UVector; UReal], SoA) ; - add_unqualified ("add", ReturnType URowVector, [URowVector; UReal], SoA) ; - add_unqualified ("add", ReturnType UMatrix, [UMatrix; UReal], SoA) ; - add_unqualified ("add", ReturnType UVector, [UReal; UVector], SoA) ; - add_unqualified ("add", ReturnType URowVector, [UReal; URowVector], SoA) ; - add_unqualified ("add", ReturnType UMatrix, [UReal; UMatrix], SoA) ; - add_unqualified ("add_diag", ReturnType UMatrix, [UMatrix; UReal], AoS) ; - add_unqualified ("add_diag", ReturnType UMatrix, [UMatrix; UVector], AoS) ; - add_unqualified ("add_diag", ReturnType UMatrix, [UMatrix; URowVector], AoS) ; - add_unqualified - ("add_diag", ReturnType UComplexMatrix, [UComplexMatrix; UComplex], AoS) ; - add_unqualified - ( "add_diag" - , ReturnType UComplexMatrix - , [UComplexMatrix; UComplexVector] - , AoS ) ; - add_unqualified - ( "add_diag" - , ReturnType UComplexMatrix - , [UComplexMatrix; UComplexRowVector] - , AoS ) ; - add_qualified - ( "algebra_solver" - , ReturnType UVector - , [ ( AutoDiffable - , UFun - ( [ (AutoDiffable, UVector); (AutoDiffable, UVector) - ; (DataOnly, UArray UReal); (DataOnly, UArray UInt) ] - , ReturnType UVector - , FnPlain - , AoS ) ); (AutoDiffable, UVector); (AutoDiffable, UVector) - ; (DataOnly, UArray UReal); (DataOnly, UArray UInt) ] - , AoS ) ; - add_qualified - ( "algebra_solver" - , ReturnType UVector - , [ ( AutoDiffable - , UFun - ( [ (AutoDiffable, UVector); (AutoDiffable, UVector) - ; (DataOnly, UArray UReal); (DataOnly, UArray UInt) ] - , ReturnType UVector - , FnPlain - , Common.Helpers.AoS ) ); (AutoDiffable, UVector) - ; (AutoDiffable, UVector); (DataOnly, UArray UReal) - ; (DataOnly, UArray UInt); (DataOnly, UReal); (DataOnly, UReal) - ; (DataOnly, UReal) ] - , AoS ) ; - add_qualified - ( "algebra_solver_newton" - , ReturnType UVector - , [ ( AutoDiffable - , UFun - ( [ (AutoDiffable, UVector); (AutoDiffable, UVector) - ; (DataOnly, UArray UReal); (DataOnly, UArray UInt) ] - , ReturnType UVector - , FnPlain - , Common.Helpers.AoS ) ); (AutoDiffable, UVector) - ; (AutoDiffable, UVector); (DataOnly, UArray UReal) - ; (DataOnly, UArray UInt) ] - , AoS ) ; - add_qualified - ( "algebra_solver_newton" - , ReturnType UVector - , [ ( AutoDiffable - , UFun - ( [ (AutoDiffable, UVector); (AutoDiffable, UVector) - ; (DataOnly, UArray UReal); (DataOnly, UArray UInt) ] - , ReturnType UVector - , FnPlain - , Common.Helpers.AoS ) ); (AutoDiffable, UVector) - ; (AutoDiffable, UVector); (DataOnly, UArray UReal) - ; (DataOnly, UArray UInt); (DataOnly, UReal); (DataOnly, UReal) - ; (DataOnly, UReal) ] - , AoS ) ; - List.iter - ~f:(fun i -> - List.iter - ~f:(fun t -> - add_unqualified - ( "append_array" - , ReturnType (bare_array_type (t, i)) - , [bare_array_type (t, i); bare_array_type (t, i)] - , AoS ) ) - bare_types ) - (List.range 1 8) ; - add_unqualified ("arg", ReturnType UReal, [UComplex], AoS) ; - add_unqualified ("asin", ReturnType UComplex, [UComplex], AoS) ; - add_unqualified ("asinh", ReturnType UComplex, [UComplex], AoS) ; - add_unqualified ("atan", ReturnType UComplex, [UComplex], AoS) ; - add_unqualified ("atanh", ReturnType UComplex, [UComplex], AoS) ; - add_binary "atan2" AoS ; - add_unqualified - ( "bernoulli_logit_glm_lpmf" - , ReturnType UReal - , [UArray UInt; UMatrix; UVector; UVector] - , SoA ) ; - add_unqualified - ( "bernoulli_logit_glm_lpmf" - , ReturnType UReal - , [UInt; UMatrix; UVector; UVector] - , SoA ) ; - add_unqualified - ( "bernoulli_logit_glm_lpmf" - , ReturnType UReal - , [UArray UInt; URowVector; UReal; UVector] - , SoA ) ; - add_unqualified - ( "bernoulli_logit_glm_lpmf" - , ReturnType UReal - , [UArray UInt; URowVector; UVector; UVector] - , SoA ) ; - add_unqualified - ( "bernoulli_logit_glm_rng" - , ReturnType (UArray UInt) - , [UMatrix; UVector; UVector] - , AoS ) ; - add_unqualified - ( "bernoulli_logit_glm_rng" - , ReturnType (UArray UInt) - , [URowVector; UVector; UVector] - , AoS ) ; - add_binary_vec_int_real "bessel_first_kind" SoA ; - add_binary_vec_int_real "bessel_second_kind" SoA ; - add_binary_vec "beta" SoA ; - (* XXX For some reason beta_proportion_rng doesn't take ints as first arg *) - for_vector_types (fun t -> - for_all_vector_types (fun u -> - add_unqualified - ( "beta_proportion_rng" - , ReturnType (rng_return_type UReal [t; u]) - , [t; u] - , AoS ) ) ) ; - add_binary_vec_int_real "binary_log_loss" AoS ; - add_binary_vec "binomial_coefficient_log" AoS ; - add_unqualified - ("block", ReturnType UMatrix, [UMatrix; UInt; UInt; UInt; UInt], SoA) ; - add_unqualified - ( "block" - , ReturnType UComplexMatrix - , [UComplexMatrix; UInt; UInt; UInt; UInt] - , AoS ) ; - add_unqualified ("categorical_rng", ReturnType UInt, [UVector], AoS) ; - add_unqualified ("categorical_logit_rng", ReturnType UInt, [UVector], AoS) ; - add_unqualified - ( "categorical_logit_glm_lpmf" - , ReturnType UReal - , [UArray UInt; URowVector; UVector; UMatrix] - , SoA ) ; - add_unqualified - ( "categorical_logit_glm_lpmf" - , ReturnType UReal - , [UInt; URowVector; UVector; UMatrix] - , SoA ) ; - add_unqualified ("append_col", ReturnType UMatrix, [UMatrix; UMatrix], AoS) ; - add_unqualified ("append_col", ReturnType UMatrix, [UVector; UMatrix], AoS) ; - add_unqualified ("append_col", ReturnType UMatrix, [UMatrix; UVector], AoS) ; - add_unqualified ("append_col", ReturnType UMatrix, [UVector; UVector], AoS) ; - add_unqualified - ("append_col", ReturnType URowVector, [URowVector; URowVector], AoS) ; - add_unqualified ("append_col", ReturnType URowVector, [UReal; URowVector], AoS) ; - add_unqualified ("append_col", ReturnType URowVector, [URowVector; UReal], AoS) ; - add_unqualified - ( "append_col" - , ReturnType UComplexMatrix - , [UComplexMatrix; UComplexMatrix] - , AoS ) ; - add_unqualified - ( "append_col" - , ReturnType UComplexMatrix - , [UComplexVector; UComplexMatrix] - , AoS ) ; - add_unqualified - ( "append_col" - , ReturnType UComplexMatrix - , [UComplexMatrix; UComplexVector] - , AoS ) ; - add_unqualified - ( "append_col" - , ReturnType UComplexMatrix - , [UComplexVector; UComplexVector] - , AoS ) ; - add_unqualified - ( "append_col" - , ReturnType UComplexRowVector - , [UComplexRowVector; UComplexRowVector] - , AoS ) ; - add_unqualified - ( "append_col" - , ReturnType UComplexRowVector - , [UComplex; UComplexRowVector] - , AoS ) ; - add_unqualified - ( "append_col" - , ReturnType UComplexRowVector - , [UComplexRowVector; UComplex] - , AoS ) ; - add_unqualified ("chol2inv", ReturnType UMatrix, [UMatrix], AoS) ; - add_unqualified ("cholesky_decompose", ReturnType UMatrix, [UMatrix], SoA) ; - add_binary_vec_int_int "choose" AoS ; - add_unqualified ("col", ReturnType UVector, [UMatrix; UInt], AoS) ; - add_unqualified ("col", ReturnType UComplexVector, [UComplexMatrix; UInt], SoA) ; - add_unqualified ("cols", ReturnType UInt, [UVector], SoA) ; - add_unqualified ("cols", ReturnType UInt, [URowVector], SoA) ; - add_unqualified ("cols", ReturnType UInt, [UMatrix], SoA) ; - add_unqualified ("cols", ReturnType UInt, [UComplexVector], SoA) ; - add_unqualified ("cols", ReturnType UInt, [UComplexRowVector], SoA) ; - add_unqualified ("cols", ReturnType UInt, [UComplexMatrix], SoA) ; - add_unqualified - ("columns_dot_product", ReturnType URowVector, [UVector; UVector], AoS) ; - add_unqualified - ("columns_dot_product", ReturnType URowVector, [URowVector; URowVector], AoS) ; - add_unqualified - ("columns_dot_product", ReturnType URowVector, [UMatrix; UMatrix], SoA) ; - add_unqualified - ( "columns_dot_product" - , ReturnType UComplexRowVector - , [UComplexVector; UComplexVector] - , AoS ) ; - add_unqualified - ( "columns_dot_product" - , ReturnType UComplexRowVector - , [UComplexRowVector; UComplexRowVector] - , AoS ) ; - add_unqualified - ( "columns_dot_product" - , ReturnType UComplexRowVector - , [UComplexMatrix; UComplexMatrix] - , AoS ) ; - add_unqualified ("columns_dot_self", ReturnType URowVector, [UVector], AoS) ; - add_unqualified ("columns_dot_self", ReturnType URowVector, [URowVector], AoS) ; - add_unqualified ("columns_dot_self", ReturnType URowVector, [UMatrix], AoS) ; - add_unqualified - ("columns_dot_self", ReturnType UComplexRowVector, [UComplexVector], AoS) ; - add_unqualified - ("columns_dot_self", ReturnType UComplexRowVector, [UComplexRowVector], AoS) ; - add_unqualified - ("columns_dot_self", ReturnType UComplexRowVector, [UComplexMatrix], AoS) ; - add_unqualified ("conj", ReturnType UComplex, [UComplex], AoS) ; - add_unqualified ("cos", ReturnType UComplex, [UComplex], AoS) ; - add_unqualified ("cosh", ReturnType UComplex, [UComplex], AoS) ; - add_unqualified - ("cov_exp_quad", ReturnType UMatrix, [UArray UReal; UReal; UReal], AoS) ; - add_unqualified - ("cov_exp_quad", ReturnType UMatrix, [UArray UVector; UReal; UReal], AoS) ; - add_unqualified - ("cov_exp_quad", ReturnType UMatrix, [UArray URowVector; UReal; UReal], AoS) ; - add_unqualified - ( "cov_exp_quad" - , ReturnType UMatrix - , [UArray UReal; UArray UReal; UReal; UReal] - , AoS ) ; - add_unqualified - ( "cov_exp_quad" - , ReturnType UMatrix - , [UArray UVector; UArray UVector; UReal; UReal] - , AoS ) ; - add_unqualified - ( "cov_exp_quad" - , ReturnType UMatrix - , [UArray URowVector; UArray URowVector; UReal; UReal] - , AoS ) ; - add_unqualified ("crossprod", ReturnType UMatrix, [UMatrix], AoS) ; - add_unqualified - ( "csr_matrix_times_vector" - , ReturnType UVector - , [UInt; UInt; UVector; UArray UInt; UArray UInt; UVector] - , SoA ) ; - add_unqualified - ( "csr_to_dense_matrix" - , ReturnType UMatrix - , [UInt; UInt; UVector; UArray UInt; UArray UInt] - , AoS ) ; - add_unqualified ("csr_extract_w", ReturnType UVector, [UMatrix], AoS) ; - add_unqualified ("csr_extract_v", ReturnType (UArray UInt), [UMatrix], AoS) ; - add_unqualified ("csr_extract_u", ReturnType (UArray UInt), [UMatrix], AoS) ; - add_unqualified - ("cumulative_sum", ReturnType (UArray UInt), [UArray UInt], AoS) ; - add_unqualified - ("cumulative_sum", ReturnType (UArray UReal), [UArray UReal], AoS) ; - add_unqualified ("cumulative_sum", ReturnType UVector, [UVector], SoA) ; - add_unqualified ("cumulative_sum", ReturnType URowVector, [URowVector], SoA) ; - add_unqualified - ("cumulative_sum", ReturnType (UArray UComplex), [UArray UComplex], AoS) ; - add_unqualified - ("cumulative_sum", ReturnType UComplexVector, [UComplexVector], AoS) ; - add_unqualified - ("cumulative_sum", ReturnType UComplexRowVector, [UComplexRowVector], AoS) ; - add_unqualified ("determinant", ReturnType UReal, [UMatrix], SoA) ; - add_unqualified ("diag_matrix", ReturnType UMatrix, [UVector], AoS) ; - add_unqualified - ("diag_matrix", ReturnType UComplexMatrix, [UComplexVector], AoS) ; - add_unqualified - ("diag_post_multiply", ReturnType UMatrix, [UMatrix; UVector], SoA) ; - add_unqualified - ("diag_post_multiply", ReturnType UMatrix, [UMatrix; URowVector], SoA) ; - add_unqualified - ( "diag_post_multiply" - , ReturnType UComplexMatrix - , [UComplexMatrix; UComplexVector] - , AoS ) ; - add_unqualified - ( "diag_post_multiply" - , ReturnType UComplexMatrix - , [UComplexMatrix; UComplexRowVector] - , AoS ) ; - add_unqualified - ("diag_pre_multiply", ReturnType UMatrix, [UVector; UMatrix], SoA) ; - add_unqualified - ("diag_pre_multiply", ReturnType UMatrix, [URowVector; UMatrix], SoA) ; - add_unqualified - ( "diag_pre_multiply" - , ReturnType UComplexMatrix - , [UComplexVector; UComplexMatrix] - , AoS ) ; - add_unqualified - ( "diag_pre_multiply" - , ReturnType UComplexMatrix - , [UComplexRowVector; UComplexMatrix] - , AoS ) ; - add_unqualified ("diagonal", ReturnType UVector, [UMatrix], SoA) ; - add_unqualified ("diagonal", ReturnType UComplexVector, [UComplexMatrix], SoA) ; - add_unqualified ("dims", ReturnType (UArray UInt), [UComplex], AoS) ; - add_unqualified ("dims", ReturnType (UArray UInt), [UInt], SoA) ; - add_unqualified ("dims", ReturnType (UArray UInt), [UReal], SoA) ; - add_unqualified ("dims", ReturnType (UArray UInt), [UVector], SoA) ; - add_unqualified ("dims", ReturnType (UArray UInt), [URowVector], SoA) ; - add_unqualified ("dims", ReturnType (UArray UInt), [UMatrix], SoA) ; - List.iter - ~f:(fun i -> - List.iter - ~f:(fun t -> - add_unqualified - ("dims", ReturnType (UArray UInt), [bare_array_type (t, i + 1)], SoA) - ) - bare_types ) - (List.range 0 8) ; - add_unqualified ("dirichlet_rng", ReturnType UVector, [UVector], AoS) ; - add_unqualified ("distance", ReturnType UReal, [UVector; UVector], SoA) ; - add_unqualified ("distance", ReturnType UReal, [URowVector; URowVector], SoA) ; - add_unqualified ("distance", ReturnType UReal, [UVector; URowVector], SoA) ; - add_unqualified ("distance", ReturnType UReal, [URowVector; UVector], SoA) ; - add_unqualified ("divide", ReturnType UComplex, [UComplex; UComplex], AoS) ; - add_unqualified ("divide", ReturnType UInt, [UInt; UInt], SoA) ; - add_unqualified ("divide", ReturnType UReal, [UReal; UReal], SoA) ; - add_unqualified ("divide", ReturnType UVector, [UVector; UReal], SoA) ; - add_unqualified ("divide", ReturnType URowVector, [URowVector; UReal], SoA) ; - add_unqualified ("divide", ReturnType UMatrix, [UMatrix; UReal], SoA) ; - add_unqualified ("dot_product", ReturnType UReal, [UVector; UVector], SoA) ; - add_unqualified - ("dot_product", ReturnType UReal, [URowVector; URowVector], SoA) ; - add_unqualified ("dot_product", ReturnType UReal, [UVector; URowVector], SoA) ; - add_unqualified ("dot_product", ReturnType UReal, [URowVector; UVector], SoA) ; - add_unqualified - ("dot_product", ReturnType UReal, [UArray UReal; UArray UReal], SoA) ; - add_unqualified - ("dot_product", ReturnType UComplex, [UComplexVector; UComplexVector], AoS) ; - add_unqualified - ( "dot_product" - , ReturnType UComplex - , [UComplexRowVector; UComplexRowVector] - , AoS ) ; - add_unqualified - ( "dot_product" - , ReturnType UComplex - , [UComplexVector; UComplexRowVector] - , AoS ) ; - add_unqualified - ( "dot_product" - , ReturnType UComplex - , [UComplexRowVector; UComplexVector] - , AoS ) ; - add_unqualified - ("dot_product", ReturnType UComplex, [UArray UComplex; UArray UComplex], AoS) ; - add_unqualified ("dot_self", ReturnType UReal, [UVector], SoA) ; - add_unqualified ("dot_self", ReturnType UReal, [URowVector], SoA) ; - add_unqualified ("dot_self", ReturnType UComplex, [UComplexVector], AoS) ; - add_unqualified ("dot_self", ReturnType UComplex, [UComplexRowVector], AoS) ; - add_nullary "e" ; - add_unqualified ("eigenvalues_sym", ReturnType UVector, [UMatrix], AoS) ; - add_unqualified ("eigenvectors_sym", ReturnType UMatrix, [UMatrix], AoS) ; - add_unqualified ("generalized_inverse", ReturnType UMatrix, [UMatrix], SoA) ; - add_unqualified ("qr_Q", ReturnType UMatrix, [UMatrix], AoS) ; - add_unqualified ("qr_R", ReturnType UMatrix, [UMatrix], AoS) ; - add_unqualified ("qr_thin_Q", ReturnType UMatrix, [UMatrix], AoS) ; - add_unqualified ("qr_thin_R", ReturnType UMatrix, [UMatrix], AoS) ; - List.iter - ~f:(fun x -> add_unqualified ("elt_divide", ReturnType x, [x; x], SoA)) - bare_types ; - add_unqualified ("elt_divide", ReturnType UVector, [UVector; UReal], SoA) ; - add_unqualified ("elt_divide", ReturnType URowVector, [URowVector; UReal], SoA) ; - add_unqualified ("elt_divide", ReturnType UMatrix, [UMatrix; UReal], SoA) ; - add_unqualified ("elt_divide", ReturnType UVector, [UReal; UVector], SoA) ; - add_unqualified ("elt_divide", ReturnType URowVector, [UReal; URowVector], SoA) ; - add_unqualified ("elt_divide", ReturnType UMatrix, [UReal; UMatrix], SoA) ; - List.iter - ~f:(fun x -> add_unqualified ("elt_multiply", ReturnType x, [x; x], SoA)) - bare_types ; - add_unqualified ("exp", ReturnType UComplex, [UComplex], AoS) ; - add_binary_vec_int_int "falling_factorial" SoA ; - add_binary_vec_real_int "falling_factorial" SoA ; - add_binary_vec "fdim" AoS ; - add_ternary_vec "fma" SoA ; - add_binary_vec "fmax" AoS ; - add_binary_vec "fmin" AoS ; - add_binary_vec "fmod" AoS ; - add_binary_vec_real_real "gamma_p" AoS ; - add_binary_vec_real_real "gamma_q" AoS ; - add_unqualified - ( "gaussian_dlm_obs_log" - , ReturnType UReal - , [UMatrix; UMatrix; UMatrix; UVector; UMatrix; UVector; UMatrix] - , AoS ) ; - add_unqualified - ( "gaussian_dlm_obs_lpdf" - , ReturnType UReal - , [UMatrix; UMatrix; UMatrix; UVector; UMatrix; UVector; UMatrix] - , AoS ) ; - List.iter - ~f:(fun i -> - List.iter - ~f:(fun t -> - add_unqualified - ( "get_imag" - , ReturnType (bare_array_type (complex_to_real t, i)) - , [bare_array_type (t, i)] - , AoS ) ) - complex_types ) - (List.range 0 8) ; - List.iter - ~f:(fun i -> - List.iter - ~f:(fun t -> - add_unqualified - ( "get_real" - , ReturnType (bare_array_type (complex_to_real t, i)) - , [bare_array_type (t, i)] - , AoS ) ) - complex_types ) - (List.range 0 8) ; - add_unqualified - ("gp_dot_prod_cov", ReturnType UMatrix, [UArray UReal; UReal], AoS) ; - add_unqualified - ( "gp_dot_prod_cov" - , ReturnType UMatrix - , [UArray UReal; UArray UReal; UReal] - , AoS ) ; - add_unqualified - ("gp_dot_prod_cov", ReturnType UMatrix, [UArray UVector; UReal], AoS) ; - add_unqualified - ( "gp_dot_prod_cov" - , ReturnType UMatrix - , [UArray UVector; UArray UVector; UReal] - , AoS ) ; - add_unqualified - ("gp_exp_quad_cov", ReturnType UMatrix, [UArray UReal; UReal; UReal], AoS) ; - add_unqualified - ( "gp_exp_quad_cov" - , ReturnType UMatrix - , [UArray UReal; UArray UReal; UReal; UReal] - , AoS ) ; - add_unqualified - ("gp_exp_quad_cov", ReturnType UMatrix, [UArray UVector; UReal; UReal], AoS) ; - add_unqualified - ( "gp_exp_quad_cov" - , ReturnType UMatrix - , [UArray UVector; UArray UVector; UReal; UReal] - , AoS ) ; - add_unqualified - ( "gp_exp_quad_cov" - , ReturnType UMatrix - , [UArray UVector; UReal; UArray UReal] - , AoS ) ; - add_unqualified - ( "gp_exp_quad_cov" - , ReturnType UMatrix - , [UArray UVector; UArray UVector; UReal; UArray UReal] - , AoS ) ; - add_unqualified - ("gp_matern32_cov", ReturnType UMatrix, [UArray UReal; UReal; UReal], AoS) ; - add_unqualified - ( "gp_matern32_cov" - , ReturnType UMatrix - , [UArray UReal; UArray UReal; UReal; UReal] - , AoS ) ; - add_unqualified - ("gp_matern32_cov", ReturnType UMatrix, [UArray UVector; UReal; UReal], AoS) ; - add_unqualified - ( "gp_matern32_cov" - , ReturnType UMatrix - , [UArray UVector; UArray UVector; UReal; UReal] - , AoS ) ; - add_unqualified - ( "gp_matern32_cov" - , ReturnType UMatrix - , [UArray UVector; UReal; UArray UReal] - , AoS ) ; - add_unqualified - ( "gp_matern32_cov" - , ReturnType UMatrix - , [UArray UVector; UArray UVector; UReal; UArray UReal] - , AoS ) ; - add_unqualified - ("gp_matern52_cov", ReturnType UMatrix, [UArray UReal; UReal; UReal], AoS) ; - add_unqualified - ( "gp_matern52_cov" - , ReturnType UMatrix - , [UArray UReal; UArray UReal; UReal; UReal] - , AoS ) ; - add_unqualified - ("gp_matern52_cov", ReturnType UMatrix, [UArray UVector; UReal; UReal], AoS) ; - add_unqualified - ( "gp_matern52_cov" - , ReturnType UMatrix - , [UArray UVector; UArray UVector; UReal; UReal] - , AoS ) ; - add_unqualified - ( "gp_matern52_cov" - , ReturnType UMatrix - , [UArray UVector; UReal; UArray UReal] - , AoS ) ; - add_unqualified - ( "gp_matern52_cov" - , ReturnType UMatrix - , [UArray UVector; UArray UVector; UReal; UArray UReal] - , AoS ) ; - add_unqualified - ("gp_exponential_cov", ReturnType UMatrix, [UArray UReal; UReal; UReal], AoS) ; - add_unqualified - ( "gp_exponential_cov" - , ReturnType UMatrix - , [UArray UReal; UArray UReal; UReal; UReal] - , AoS ) ; - add_unqualified - ( "gp_exponential_cov" - , ReturnType UMatrix - , [UArray UVector; UReal; UReal] - , AoS ) ; - add_unqualified - ( "gp_exponential_cov" - , ReturnType UMatrix - , [UArray UVector; UArray UVector; UReal; UReal] - , AoS ) ; - add_unqualified - ( "gp_exponential_cov" - , ReturnType UMatrix - , [UArray UVector; UReal; UArray UReal] - , AoS ) ; - add_unqualified - ( "gp_exponential_cov" - , ReturnType UMatrix - , [UArray UVector; UArray UVector; UReal; UArray UReal] - , AoS ) ; - add_unqualified - ( "gp_periodic_cov" - , ReturnType UMatrix - , [UArray UReal; UReal; UReal; UReal] - , AoS ) ; - add_unqualified - ( "gp_periodic_cov" - , ReturnType UMatrix - , [UArray UReal; UArray UReal; UReal; UReal; UReal] - , AoS ) ; - add_unqualified - ( "gp_periodic_cov" - , ReturnType UMatrix - , [UArray UVector; UReal; UReal; UReal] - , AoS ) ; - add_unqualified - ( "gp_periodic_cov" - , ReturnType UMatrix - , [UArray UVector; UArray UVector; UReal; UReal; UReal] - , AoS ) ; - (* ; add_nullary ("get_lp") *) - add_unqualified ("head", ReturnType URowVector, [URowVector; UInt], SoA) ; - add_unqualified ("head", ReturnType UVector, [UVector; UInt], SoA) ; - add_unqualified - ("head", ReturnType UComplexRowVector, [UComplexRowVector; UInt], AoS) ; - add_unqualified - ("head", ReturnType UComplexVector, [UComplexVector; UInt], AoS) ; - List.iter - ~f:(fun t -> - List.iter - ~f:(fun j -> - add_unqualified - ( "head" - , ReturnType (bare_array_type (t, j)) - , [bare_array_type (t, j); UInt] - , SoA ) ) - (List.range 1 4) ) - bare_types ; - add_unqualified - ("hmm_marginal", ReturnType UReal, [UMatrix; UMatrix; UVector], AoS) ; - add_qualified - ( "hmm_hidden_state_prob" - , ReturnType UMatrix - , [(DataOnly, UMatrix); (DataOnly, UMatrix); (DataOnly, UVector)] - , AoS ) ; - add_binary_vec "hypot" AoS ; - add_unqualified ("identity_matrix", ReturnType UMatrix, [UInt], SoA) ; - add_unqualified ("if_else", ReturnType UInt, [UInt; UInt; UInt], SoA) ; - add_unqualified ("if_else", ReturnType UReal, [UInt; UReal; UReal], SoA) ; - add_unqualified ("inc_beta", ReturnType UReal, [UReal; UReal; UReal], SoA) ; - add_unqualified ("int_step", ReturnType UInt, [UReal], SoA) ; - add_unqualified ("int_step", ReturnType UInt, [UInt], SoA) ; - add_qualified - ( "integrate_1d" - , ReturnType UReal - , [ ( AutoDiffable - , UFun - ( [ (AutoDiffable, UReal); (AutoDiffable, UReal) - ; (AutoDiffable, UArray UReal); (DataOnly, UArray UReal) - ; (DataOnly, UArray UInt) ] - , ReturnType UReal - , FnPlain - , Common.Helpers.AoS ) ); (AutoDiffable, UReal) - ; (AutoDiffable, UReal); (AutoDiffable, UArray UReal) - ; (DataOnly, UArray UReal); (DataOnly, UArray UInt) ] - , AoS ) ; - add_qualified - ( "integrate_1d" - , ReturnType UReal - , [ ( AutoDiffable - , UFun - ( [ (AutoDiffable, UReal); (AutoDiffable, UReal) - ; (AutoDiffable, UArray UReal); (DataOnly, UArray UReal) - ; (DataOnly, UArray UInt) ] - , ReturnType UReal - , FnPlain - , Common.Helpers.AoS ) ); (AutoDiffable, UReal) - ; (AutoDiffable, UReal); (AutoDiffable, UArray UReal) - ; (DataOnly, UArray UReal); (DataOnly, UArray UInt); (DataOnly, UReal) ] - , AoS ) ; - add_qualified - ( "integrate_ode" - , ReturnType (UArray (UArray UReal)) - , [ ( AutoDiffable - , UFun - ( [ (AutoDiffable, UReal); (AutoDiffable, UArray UReal) - ; (AutoDiffable, UArray UReal); (DataOnly, UArray UReal) - ; (DataOnly, UArray UInt) ] - , ReturnType (UArray UReal) - , FnPlain - , Common.Helpers.AoS ) ); (AutoDiffable, UArray UReal) - ; (AutoDiffable, UReal); (AutoDiffable, UArray UReal) - ; (AutoDiffable, UArray UReal); (DataOnly, UArray UReal) - ; (DataOnly, UArray UInt) ] - , AoS ) ; - add_qualified - ( "integrate_ode_adams" - , ReturnType (UArray (UArray UReal)) - , [ ( AutoDiffable - , UFun - ( [ (AutoDiffable, UReal); (AutoDiffable, UArray UReal) - ; (AutoDiffable, UArray UReal); (DataOnly, UArray UReal) - ; (DataOnly, UArray UInt) ] - , ReturnType (UArray UReal) - , FnPlain - , Common.Helpers.AoS ) ); (AutoDiffable, UArray UReal) - ; (AutoDiffable, UReal); (AutoDiffable, UArray UReal) - ; (AutoDiffable, UArray UReal); (DataOnly, UArray UReal) - ; (DataOnly, UArray UInt) ] - , AoS ) ; - add_qualified - ( "integrate_ode_adams" - , ReturnType (UArray (UArray UReal)) - , [ ( AutoDiffable - , UFun - ( [ (AutoDiffable, UReal); (AutoDiffable, UArray UReal) - ; (AutoDiffable, UArray UReal); (DataOnly, UArray UReal) - ; (DataOnly, UArray UInt) ] - , ReturnType (UArray UReal) - , FnPlain - , Common.Helpers.AoS ) ); (AutoDiffable, UArray UReal) - ; (AutoDiffable, UReal); (AutoDiffable, UArray UReal) - ; (AutoDiffable, UArray UReal); (DataOnly, UArray UReal) - ; (DataOnly, UArray UInt); (DataOnly, UReal); (DataOnly, UReal) - ; (DataOnly, UReal) ] - , AoS ) ; - add_qualified - ( "integrate_ode_bdf" - , ReturnType (UArray (UArray UReal)) - , [ ( AutoDiffable - , UFun - ( [ (AutoDiffable, UReal); (AutoDiffable, UArray UReal) - ; (AutoDiffable, UArray UReal); (DataOnly, UArray UReal) - ; (DataOnly, UArray UInt) ] - , ReturnType (UArray UReal) - , FnPlain - , Common.Helpers.AoS ) ); (AutoDiffable, UArray UReal) - ; (AutoDiffable, UReal); (AutoDiffable, UArray UReal) - ; (AutoDiffable, UArray UReal); (DataOnly, UArray UReal) - ; (DataOnly, UArray UInt) ] - , AoS ) ; - add_qualified - ( "integrate_ode_bdf" - , ReturnType (UArray (UArray UReal)) - , [ ( AutoDiffable - , UFun - ( [ (AutoDiffable, UReal); (AutoDiffable, UArray UReal) - ; (AutoDiffable, UArray UReal); (DataOnly, UArray UReal) - ; (DataOnly, UArray UInt) ] - , ReturnType (UArray UReal) - , FnPlain - , Common.Helpers.AoS ) ); (AutoDiffable, UArray UReal) - ; (AutoDiffable, UReal); (AutoDiffable, UArray UReal) - ; (AutoDiffable, UArray UReal); (DataOnly, UArray UReal) - ; (DataOnly, UArray UInt); (DataOnly, UReal); (DataOnly, UReal) - ; (DataOnly, UReal) ] - , AoS ) ; - add_qualified - ( "integrate_ode_rk45" - , ReturnType (UArray (UArray UReal)) - , [ ( AutoDiffable - , UFun - ( [ (AutoDiffable, UReal); (AutoDiffable, UArray UReal) - ; (AutoDiffable, UArray UReal); (DataOnly, UArray UReal) - ; (DataOnly, UArray UInt) ] - , ReturnType (UArray UReal) - , FnPlain - , Common.Helpers.AoS ) ); (AutoDiffable, UArray UReal) - ; (AutoDiffable, UReal); (AutoDiffable, UArray UReal) - ; (AutoDiffable, UArray UReal); (DataOnly, UArray UReal) - ; (DataOnly, UArray UInt) ] - , AoS ) ; - add_qualified - ( "integrate_ode_rk45" - , ReturnType (UArray (UArray UReal)) - , [ ( AutoDiffable - , UFun - ( [ (AutoDiffable, UReal); (AutoDiffable, UArray UReal) - ; (AutoDiffable, UArray UReal); (DataOnly, UArray UReal) - ; (DataOnly, UArray UInt) ] - , ReturnType (UArray UReal) - , FnPlain - , Common.Helpers.AoS ) ); (AutoDiffable, UArray UReal) - ; (AutoDiffable, UReal); (AutoDiffable, UArray UReal) - ; (AutoDiffable, UArray UReal); (DataOnly, UArray UReal) - ; (DataOnly, UArray UInt); (DataOnly, UReal); (DataOnly, UReal) - ; (DataOnly, UReal) ] - , AoS ) ; - add_unqualified ("inv_wishart_rng", ReturnType UMatrix, [UReal; UMatrix], AoS) ; - add_unqualified ("inverse", ReturnType UMatrix, [UMatrix], SoA) ; - add_unqualified ("inverse_spd", ReturnType UMatrix, [UMatrix], AoS) ; - add_unqualified ("is_inf", ReturnType UInt, [UReal], SoA) ; - add_unqualified ("is_nan", ReturnType UInt, [UReal], SoA) ; - add_binary_vec "lbeta" AoS ; - add_binary_vec "lchoose" AoS ; - add_binary_vec_real_int "ldexp" AoS ; - add_qualified - ( "linspaced_int_array" - , ReturnType (UArray UInt) - , [(DataOnly, UInt); (DataOnly, UInt); (DataOnly, UInt)] - , SoA ) ; - add_qualified - ( "linspaced_array" - , ReturnType (UArray UReal) - , [(DataOnly, UInt); (DataOnly, UReal); (DataOnly, UReal)] - , SoA ) ; - add_qualified - ( "linspaced_row_vector" - , ReturnType URowVector - , [(DataOnly, UInt); (DataOnly, UReal); (DataOnly, UReal)] - , SoA ) ; - add_qualified - ( "linspaced_vector" - , ReturnType UVector - , [(DataOnly, UInt); (DataOnly, UReal); (DataOnly, UReal)] - , SoA ) ; - add_unqualified - ("lkj_corr_cholesky_rng", ReturnType UMatrix, [UInt; UReal], AoS) ; - add_unqualified ("lkj_corr_rng", ReturnType UMatrix, [UInt; UReal], AoS) ; - add_unqualified - ("lkj_cov_log", ReturnType UReal, [UMatrix; UVector; UVector; UReal], AoS) ; - add_binary_vec_int_real "lmgamma" AoS ; - add_binary_vec "lmultiply" SoA ; - add_unqualified ("log", ReturnType UComplex, [UComplex], AoS) ; - add_nullary "log10" ; - add_unqualified ("log10", ReturnType UComplex, [UComplex], AoS) ; - add_nullary "log2" ; - add_unqualified ("log_determinant", ReturnType UReal, [UMatrix], SoA) ; - add_binary_vec "log_diff_exp" AoS ; - add_binary_vec "log_falling_factorial" AoS ; - add_binary_vec "log_inv_logit_diff" AoS ; - add_ternary "log_mix" AoS ; - List.iter - ~f:(fun v1 -> - List.iter - ~f:(fun v2 -> - add_unqualified ("log_mix", ReturnType UReal, [v1; v2], AoS) ) - (List.tl_exn vector_types) ; - add_unqualified ("log_mix", ReturnType UReal, [v1; UArray UVector], AoS) ; - add_unqualified ("log_mix", ReturnType UReal, [v1; UArray URowVector], AoS) - ) - (List.tl_exn vector_types) ; - add_binary_vec "log_modified_bessel_first_kind" AoS ; - add_binary_vec "log_rising_factorial" AoS ; - add_unqualified ("log_softmax", ReturnType UVector, [UVector], SoA) ; - add_unqualified ("log_sum_exp", ReturnType UReal, [UArray UReal], SoA) ; - add_unqualified ("log_sum_exp", ReturnType UReal, [UVector], SoA) ; - add_unqualified ("log_sum_exp", ReturnType UReal, [URowVector], SoA) ; - add_unqualified ("log_sum_exp", ReturnType UReal, [UMatrix], SoA) ; - add_binary "log_sum_exp" SoA ; - let logical_binops = - [ "logical_or"; "logical_and"; "logical_eq"; "logical_neq"; "logical_lt" - ; "logical_lte"; "logical_gt"; "logical_gte" ] in - List.iter - ~f:(fun t1 -> - add_unqualified ("logical_negation", ReturnType UInt, [t1], SoA) ; - List.iter - ~f:(fun t2 -> - List.iter - ~f:(fun o -> add_unqualified (o, ReturnType UInt, [t1; t2], SoA)) - logical_binops ) - primitive_types ) - primitive_types ; - add_unqualified ("logical_eq", ReturnType UInt, [UComplex; UReal], SoA) ; - add_unqualified ("logical_eq", ReturnType UInt, [UComplex; UComplex], SoA) ; - add_unqualified ("logical_neq", ReturnType UInt, [UComplex; UReal], SoA) ; - add_unqualified ("logical_neq", ReturnType UInt, [UComplex; UComplex], SoA) ; - add_nullary "machine_precision" ; - add_qualified - ( "map_rect" - , ReturnType UVector - , [ ( AutoDiffable - , UFun - ( [ (AutoDiffable, UVector); (AutoDiffable, UVector) - ; (DataOnly, UArray UReal); (DataOnly, UArray UInt) ] - , ReturnType UVector - , FnPlain - , Common.Helpers.AoS ) ); (AutoDiffable, UVector) - ; (AutoDiffable, UArray UVector); (DataOnly, UArray (UArray UReal)) - ; (DataOnly, UArray (UArray UInt)) ] - , AoS ) ; - add_unqualified ("matrix_exp", ReturnType UMatrix, [UMatrix], AoS) ; - add_unqualified - ("matrix_exp_multiply", ReturnType UMatrix, [UMatrix; UMatrix], AoS) ; - add_unqualified ("matrix_power", ReturnType UMatrix, [UMatrix; UInt], SoA) ; - add_unqualified ("max", ReturnType UInt, [UArray UInt], AoS) ; - add_unqualified ("max", ReturnType UReal, [UArray UReal], AoS) ; - add_unqualified ("max", ReturnType UReal, [UVector], AoS) ; - add_unqualified ("max", ReturnType UReal, [URowVector], AoS) ; - add_unqualified ("max", ReturnType UReal, [UMatrix], AoS) ; - add_unqualified ("max", ReturnType UInt, [UInt; UInt], AoS) ; - add_unqualified ("mdivide_left", ReturnType UVector, [UMatrix; UVector], SoA) ; - add_unqualified ("mdivide_left", ReturnType UMatrix, [UMatrix; UMatrix], SoA) ; - add_unqualified - ("mdivide_left_spd", ReturnType UVector, [UMatrix; UVector], SoA) ; - add_unqualified - ("mdivide_left_spd", ReturnType UMatrix, [UMatrix; UMatrix], SoA) ; - add_unqualified - ("mdivide_left_tri_low", ReturnType UMatrix, [UMatrix; UMatrix], AoS) ; - add_unqualified - ("mdivide_left_tri_low", ReturnType UVector, [UMatrix; UVector], AoS) ; - add_unqualified - ("mdivide_right", ReturnType URowVector, [URowVector; UMatrix], AoS) ; - add_unqualified ("mdivide_right", ReturnType UMatrix, [UMatrix; UMatrix], AoS) ; - add_unqualified - ( "mdivide_right" - , ReturnType UComplexRowVector - , [UComplexRowVector; UComplexMatrix] - , AoS ) ; - add_unqualified - ( "mdivide_right" - , ReturnType UComplexMatrix - , [UComplexMatrix; UComplexMatrix] - , AoS ) ; - add_unqualified - ("mdivide_right_spd", ReturnType UMatrix, [UMatrix; UMatrix], AoS) ; - add_unqualified - ("mdivide_right_spd", ReturnType URowVector, [URowVector; UMatrix], AoS) ; - add_unqualified - ("mdivide_right_tri_low", ReturnType URowVector, [URowVector; UMatrix], AoS) ; - add_unqualified - ("mdivide_right_tri_low", ReturnType UMatrix, [UMatrix; UMatrix], AoS) ; - add_unqualified ("mean", ReturnType UReal, [UArray UReal], SoA) ; - add_unqualified ("mean", ReturnType UReal, [UVector], AoS) ; - add_unqualified ("mean", ReturnType UReal, [URowVector], AoS) ; - add_unqualified ("mean", ReturnType UReal, [UMatrix], AoS) ; - add_unqualified ("min", ReturnType UInt, [UArray UInt], AoS) ; - add_unqualified ("min", ReturnType UReal, [UArray UReal], AoS) ; - add_unqualified ("min", ReturnType UReal, [UVector], AoS) ; - add_unqualified ("min", ReturnType UReal, [URowVector], AoS) ; - add_unqualified ("min", ReturnType UReal, [UMatrix], AoS) ; - add_unqualified ("min", ReturnType UInt, [UInt; UInt], AoS) ; - List.iter - ~f:(fun x -> add_unqualified ("minus", ReturnType x, [x], SoA)) - bare_types ; - add_binary_vec_int_real "modified_bessel_first_kind" AoS ; - add_binary_vec_int_real "modified_bessel_second_kind" AoS ; - add_unqualified ("modulus", ReturnType UInt, [UInt; UInt], AoS) ; - add_unqualified - ("multi_normal_rng", ReturnType UVector, [UVector; UMatrix], AoS) ; - add_unqualified - ( "multi_normal_rng" - , ReturnType (UArray UVector) - , [UArray UVector; UMatrix] - , AoS ) ; - add_unqualified - ("multi_normal_rng", ReturnType UVector, [URowVector; UMatrix], AoS) ; - add_unqualified - ( "multi_normal_rng" - , ReturnType (UArray UVector) - , [UArray URowVector; UMatrix] - , AoS ) ; - add_unqualified - ("multi_normal_cholesky_rng", ReturnType UVector, [UVector; UMatrix], AoS) ; - add_unqualified - ( "multi_normal_cholesky_rng" - , ReturnType (UArray UVector) - , [UArray UVector; UMatrix] - , AoS ) ; - add_unqualified - ("multi_normal_cholesky_rng", ReturnType UVector, [URowVector; UMatrix], AoS) ; - add_unqualified - ( "multi_normal_cholesky_rng" - , ReturnType (UArray UVector) - , [UArray URowVector; UMatrix] - , AoS ) ; - add_unqualified - ("multi_student_t_rng", ReturnType UVector, [UReal; UVector; UMatrix], AoS) ; - add_unqualified - ( "multi_student_t_rng" - , ReturnType (UArray UVector) - , [UReal; UArray UVector; UMatrix] - , AoS ) ; - add_unqualified - ( "multi_student_t_rng" - , ReturnType UVector - , [UReal; URowVector; UMatrix] - , AoS ) ; - add_unqualified - ( "multi_student_t_rng" - , ReturnType (UArray UVector) - , [UReal; UArray URowVector; UMatrix] - , AoS ) ; - add_unqualified - ("multinomial_logit_rng", ReturnType (UArray UInt), [UVector; UInt], AoS) ; - add_unqualified - ("multinomial_rng", ReturnType (UArray UInt), [UVector; UInt], AoS) ; - add_unqualified ("multiply", ReturnType UComplex, [UComplex; UComplex], AoS) ; - add_unqualified ("multiply", ReturnType UInt, [UInt; UInt], SoA) ; - add_unqualified ("multiply", ReturnType UReal, [UReal; UReal], SoA) ; - add_unqualified ("multiply", ReturnType UVector, [UVector; UReal], SoA) ; - add_unqualified ("multiply", ReturnType URowVector, [URowVector; UReal], SoA) ; - add_unqualified ("multiply", ReturnType UMatrix, [UMatrix; UReal], SoA) ; - add_unqualified ("multiply", ReturnType UReal, [URowVector; UVector], SoA) ; - add_unqualified ("multiply", ReturnType UMatrix, [UVector; URowVector], SoA) ; - add_unqualified ("multiply", ReturnType UVector, [UMatrix; UVector], SoA) ; - add_unqualified ("multiply", ReturnType URowVector, [URowVector; UMatrix], SoA) ; - add_unqualified ("multiply", ReturnType UMatrix, [UMatrix; UMatrix], SoA) ; - add_unqualified ("multiply", ReturnType UVector, [UReal; UVector], SoA) ; - add_unqualified ("multiply", ReturnType URowVector, [UReal; URowVector], SoA) ; - add_unqualified ("multiply", ReturnType UMatrix, [UReal; UMatrix], SoA) ; - (* TODO more complex overloads *) - add_unqualified - ( "multiply" - , ReturnType UComplexMatrix - , [UComplexMatrix; UComplexMatrix] - , SoA ) ; - add_unqualified - ("multiply", ReturnType UComplexMatrix, [UComplexMatrix; UComplex], SoA) ; - add_unqualified - ("multiply", ReturnType UComplexMatrix, [UComplex; UComplexMatrix], SoA) ; - add_unqualified - ( "multiply" - , ReturnType UComplexMatrix - , [UComplexVector; UComplexRowVector] - , SoA ) ; - add_unqualified - ("multiply", ReturnType UComplex, [UComplexRowVector; UComplexVector], SoA) ; - add_unqualified - ( "multiply" - , ReturnType UComplexVector - , [UComplexMatrix; UComplexVector] - , SoA ) ; - add_unqualified - ("multiply", ReturnType UComplexVector, [UComplexVector; UComplex], SoA) ; - add_unqualified - ("multiply", ReturnType UComplexVector, [UComplex; UComplexVector], SoA) ; - add_unqualified - ( "multiply" - , ReturnType UComplexRowVector - , [UComplexRowVector; UComplex] - , SoA ) ; - add_unqualified - ( "multiply" - , ReturnType UComplexRowVector - , [UComplex; UComplexRowVector] - , SoA ) ; - add_unqualified - ( "multiply" - , ReturnType UComplexRowVector - , [UComplexRowVector; UComplexMatrix] - , SoA ) ; - add_binary_vec "multiply_log" SoA ; - add_unqualified - ("multiply_lower_tri_self_transpose", ReturnType UMatrix, [UMatrix], SoA) ; - add_unqualified - ( "neg_binomial_2_log_glm_lpmf" - , ReturnType UReal - , [UArray UInt; UMatrix; UVector; UVector; UReal] - , SoA ) ; - add_unqualified - ( "neg_binomial_2_log_glm_lpmf" - , ReturnType UReal - , [UInt; UMatrix; UVector; UVector; UReal] - , SoA ) ; - add_unqualified - ( "neg_binomial_2_log_glm_lpmf" - , ReturnType UReal - , [UArray UInt; URowVector; UReal; UVector; UReal] - , SoA ) ; - add_unqualified - ( "neg_binomial_2_log_glm_lpmf" - , ReturnType UReal - , [UArray UInt; URowVector; UVector; UVector; UReal] - , SoA ) ; - add_nullary "negative_infinity" ; - add_unqualified ("norm", ReturnType UReal, [UComplex], AoS) ; - add_unqualified - ( "normal_id_glm_lpdf" - , ReturnType UReal - , [UVector; UMatrix; UVector; UVector; UReal] - , SoA ) ; - add_unqualified - ( "normal_id_glm_lpdf" - , ReturnType UReal - , [UReal; UMatrix; UReal; UVector; UReal] - , SoA ) ; - add_unqualified - ( "normal_id_glm_lpdf" - , ReturnType UReal - , [UReal; UMatrix; UVector; UVector; UReal] - , SoA ) ; - add_unqualified - ( "normal_id_glm_lpdf" - , ReturnType UReal - , [UReal; UMatrix; UReal; UVector; UVector] - , SoA ) ; - add_unqualified - ( "normal_id_glm_lpdf" - , ReturnType UReal - , [UReal; UMatrix; UVector; UVector; UVector] - , SoA ) ; - add_unqualified - ( "normal_id_glm_lpdf" - , ReturnType UReal - , [UVector; URowVector; UReal; UVector; UVector] - , SoA ) ; - add_unqualified - ( "normal_id_glm_lpdf" - , ReturnType UReal - , [UVector; URowVector; UVector; UVector; UReal] - , SoA ) ; - add_unqualified - ( "normal_id_glm_lpdf" - , ReturnType UReal - , [UVector; URowVector; UVector; UVector; UVector] - , SoA ) ; - add_unqualified - ( "normal_id_glm_lpdf" - , ReturnType UReal - , [UVector; URowVector; UReal; UVector; UReal] - , SoA ) ; - add_unqualified - ( "normal_id_glm_lpdf" - , ReturnType UReal - , [UVector; UMatrix; UReal; UVector; UVector] - , SoA ) ; - add_unqualified - ( "normal_id_glm_lpdf" - , ReturnType UReal - , [UVector; UMatrix; UVector; UVector; UVector] - , SoA ) ; - add_nullary "not_a_number" ; - add_unqualified ("num_elements", ReturnType UInt, [UMatrix], SoA) ; - add_unqualified ("num_elements", ReturnType UInt, [UVector], SoA) ; - add_unqualified ("num_elements", ReturnType UInt, [URowVector], SoA) ; - add_unqualified ("num_elements", ReturnType UInt, [UComplexMatrix], SoA) ; - add_unqualified ("num_elements", ReturnType UInt, [UComplexVector], SoA) ; - add_unqualified ("num_elements", ReturnType UInt, [UComplexRowVector], SoA) ; - List.iter - ~f:(fun i -> - List.iter - ~f:(fun t -> - add_unqualified - ("num_elements", ReturnType UInt, [bare_array_type (t, i)], SoA) ) - bare_types ) - (List.range 1 10) ; - add_unqualified - ("one_hot_int_array", ReturnType (UArray UInt), [UInt; UInt], SoA) ; - add_unqualified ("one_hot_array", ReturnType (UArray UReal), [UInt; UInt], SoA) ; - add_unqualified - ("one_hot_row_vector", ReturnType URowVector, [UInt; UInt], SoA) ; - add_unqualified ("one_hot_vector", ReturnType UVector, [UInt; UInt], SoA) ; - add_unqualified ("ones_int_array", ReturnType (UArray UInt), [UInt], SoA) ; - add_unqualified ("ones_array", ReturnType (UArray UReal), [UInt], SoA) ; - add_unqualified ("ones_row_vector", ReturnType URowVector, [UInt], SoA) ; - add_unqualified ("ones_vector", ReturnType UVector, [UInt], SoA) ; - add_unqualified - ( "ordered_logistic_glm_lpmf" - , ReturnType UReal - , [UArray UInt; URowVector; UVector; UVector] - , SoA ) ; - add_unqualified - ( "ordered_logistic_glm_lpmf" - , ReturnType UReal - , [UInt; URowVector; UVector; UVector] - , SoA ) ; - add_unqualified - ( "ordered_logistic_log" - , ReturnType UReal - , [UArray UInt; UVector; UVector] - , SoA ) ; - add_unqualified - ( "ordered_logistic_log" - , ReturnType UReal - , [UArray UInt; UVector; UArray UVector] - , SoA ) ; - add_unqualified - ( "ordered_logistic_lpmf" - , ReturnType UReal - , [UArray UInt; UVector; UVector] - , SoA ) ; - add_unqualified - ( "ordered_logistic_lpmf" - , ReturnType UReal - , [UArray UInt; UVector; UArray UVector] - , SoA ) ; - add_unqualified - ("ordered_logistic_rng", ReturnType UInt, [UReal; UVector], AoS) ; - add_unqualified - ( "ordered_probit_log" - , ReturnType UReal - , [UArray UInt; UVector; UVector] - , AoS ) ; - add_unqualified - ( "ordered_probit_log" - , ReturnType UReal - , [UArray UInt; UVector; UArray UVector] - , AoS ) ; - add_unqualified - ("ordered_probit_lpmf", ReturnType UReal, [UArray UInt; UReal; UVector], AoS) ; - add_unqualified - ( "ordered_probit_lpmf" - , ReturnType UReal - , [UArray UInt; UReal; UArray UVector] - , AoS ) ; - add_unqualified - ( "ordered_probit_lpmf" - , ReturnType UReal - , [UArray UInt; UVector; UVector] - , AoS ) ; - add_unqualified - ( "ordered_probit_lpmf" - , ReturnType UReal - , [UArray UInt; UVector; UArray UVector] - , AoS ) ; - add_unqualified ("ordered_probit_rng", ReturnType UInt, [UReal; UVector], AoS) ; - add_binary_vec_real_real "owens_t" AoS ; - add_nullary "pi" ; - add_unqualified ("plus", ReturnType UComplex, [UComplex], AoS) ; - add_unqualified ("plus", ReturnType UInt, [UInt], SoA) ; - add_unqualified ("plus", ReturnType UReal, [UReal], SoA) ; - add_unqualified ("plus", ReturnType UVector, [UVector], SoA) ; - add_unqualified ("plus", ReturnType URowVector, [URowVector], SoA) ; - add_unqualified ("plus", ReturnType UMatrix, [UMatrix], SoA) ; - add_unqualified - ( "poisson_log_glm_lpmf" - , ReturnType UReal - , [UArray UInt; UMatrix; UVector; UVector] - , SoA ) ; - add_unqualified - ( "poisson_log_glm_lpmf" - , ReturnType UReal - , [UInt; UMatrix; UVector; UVector] - , SoA ) ; - add_unqualified - ( "poisson_log_glm_lpmf" - , ReturnType UReal - , [UArray UInt; URowVector; UReal; UVector] - , SoA ) ; - add_unqualified - ( "poisson_log_glm_lpmf" - , ReturnType UReal - , [UArray UInt; URowVector; UVector; UVector] - , SoA ) ; - add_unqualified ("polar", ReturnType UComplex, [UReal; UReal], AoS) ; - add_nullary "positive_infinity" ; - add_binary_vec "pow" AoS ; - add_unqualified ("pow", ReturnType UComplex, [UComplex; UReal], AoS) ; - add_unqualified ("pow", ReturnType UComplex, [UComplex; UComplex], AoS) ; - add_unqualified ("prod", ReturnType UInt, [UArray UInt], AoS) ; - add_unqualified ("prod", ReturnType UReal, [UArray UReal], AoS) ; - add_unqualified ("prod", ReturnType UReal, [UVector], AoS) ; - add_unqualified ("prod", ReturnType UReal, [URowVector], AoS) ; - add_unqualified ("prod", ReturnType UReal, [UMatrix], AoS) ; - add_unqualified ("prod", ReturnType UComplex, [UArray UComplex], AoS) ; - add_unqualified ("prod", ReturnType UComplex, [UComplexVector], AoS) ; - add_unqualified ("prod", ReturnType UComplex, [UComplexRowVector], AoS) ; - add_unqualified ("prod", ReturnType UComplex, [UComplexMatrix], AoS) ; - add_unqualified ("proj", ReturnType UComplex, [UComplex], AoS) ; - add_unqualified ("quad_form", ReturnType UReal, [UMatrix; UVector], SoA) ; - add_unqualified ("quad_form", ReturnType UMatrix, [UMatrix; UMatrix], SoA) ; - add_unqualified ("quad_form_sym", ReturnType UReal, [UMatrix; UVector], AoS) ; - add_unqualified ("quad_form_sym", ReturnType UMatrix, [UMatrix; UMatrix], AoS) ; - add_unqualified ("quad_form_diag", ReturnType UMatrix, [UMatrix; UVector], AoS) ; - add_unqualified - ("quad_form_diag", ReturnType UMatrix, [UMatrix; URowVector], AoS) ; - add_qualified - ( "quantile" - , ReturnType UReal - , [(DataOnly, UArray UReal); (DataOnly, UReal)] - , SoA ) ; - add_qualified - ( "quantile" - , ReturnType (UArray UReal) - , [(DataOnly, UArray UReal); (DataOnly, UArray UReal)] - , SoA ) ; - add_qualified - ("quantile", ReturnType UReal, [(DataOnly, UVector); (DataOnly, UReal)], SoA) ; - add_qualified - ( "quantile" - , ReturnType (UArray UReal) - , [(DataOnly, UVector); (DataOnly, UArray UReal)] - , SoA ) ; - add_qualified - ( "quantile" - , ReturnType UReal - , [(DataOnly, URowVector); (DataOnly, UReal)] - , SoA ) ; - add_qualified - ( "quantile" - , ReturnType (UArray UReal) - , [(DataOnly, URowVector); (DataOnly, UArray UReal)] - , SoA ) ; - add_unqualified ("rank", ReturnType UInt, [UArray UInt; UInt], AoS) ; - add_unqualified ("rank", ReturnType UInt, [UArray UReal; UInt], AoS) ; - add_unqualified ("rank", ReturnType UInt, [UVector; UInt], AoS) ; - add_unqualified ("rank", ReturnType UInt, [URowVector; UInt], AoS) ; - add_unqualified ("append_row", ReturnType UMatrix, [UMatrix; UMatrix], AoS) ; - add_unqualified ("append_row", ReturnType UMatrix, [URowVector; UMatrix], AoS) ; - add_unqualified ("append_row", ReturnType UMatrix, [UMatrix; URowVector], AoS) ; - add_unqualified - ("append_row", ReturnType UMatrix, [URowVector; URowVector], AoS) ; - add_unqualified ("append_row", ReturnType UVector, [UVector; UVector], AoS) ; - add_unqualified ("append_row", ReturnType UVector, [UReal; UVector], AoS) ; - add_unqualified ("append_row", ReturnType UVector, [UVector; UReal], AoS) ; - add_unqualified - ( "append_row" - , ReturnType UComplexMatrix - , [UComplexMatrix; UComplexMatrix] - , AoS ) ; - add_unqualified - ( "append_row" - , ReturnType UComplexMatrix - , [UComplexRowVector; UComplexMatrix] - , AoS ) ; - add_unqualified - ( "append_row" - , ReturnType UComplexMatrix - , [UComplexMatrix; UComplexRowVector] - , AoS ) ; - add_unqualified - ( "append_row" - , ReturnType UComplexMatrix - , [UComplexRowVector; UComplexRowVector] - , AoS ) ; - add_unqualified - ( "append_row" - , ReturnType UComplexVector - , [UComplexVector; UComplexVector] - , AoS ) ; - add_unqualified - ("append_row", ReturnType UComplexVector, [UComplex; UComplexVector], AoS) ; - add_unqualified - ("append_row", ReturnType UComplexVector, [UComplexVector; UComplex], AoS) ; - List.iter - ~f:(fun t -> - add_unqualified - ("rep_array", ReturnType (bare_array_type (t, 1)), [t; UInt], SoA) ; - add_unqualified - ("rep_array", ReturnType (bare_array_type (t, 2)), [t; UInt; UInt], SoA) ; - add_unqualified - ( "rep_array" - , ReturnType (bare_array_type (t, 3)) - , [t; UInt; UInt; UInt] - , SoA ) ; - List.iter - ~f:(fun j -> - add_unqualified - ( "rep_array" - , ReturnType (bare_array_type (t, j + 1)) - , [bare_array_type (t, j); UInt] - , SoA ) ; - add_unqualified - ( "rep_array" - , ReturnType (bare_array_type (t, j + 2)) - , [bare_array_type (t, j); UInt; UInt] - , SoA ) ; - add_unqualified - ( "rep_array" - , ReturnType (bare_array_type (t, j + 3)) - , [bare_array_type (t, j); UInt; UInt; UInt] - , SoA ) ) - (List.range 1 3) ) - bare_types ; - add_unqualified ("rep_matrix", ReturnType UMatrix, [UReal; UInt; UInt], SoA) ; - add_unqualified ("rep_matrix", ReturnType UMatrix, [UVector; UInt], AoS) ; - add_unqualified ("rep_matrix", ReturnType UMatrix, [URowVector; UInt], AoS) ; - add_unqualified - ("rep_matrix", ReturnType UComplexMatrix, [UComplex; UInt; UInt], AoS) ; - add_unqualified - ("rep_matrix", ReturnType UComplexMatrix, [UComplexVector; UInt], AoS) ; - add_unqualified - ("rep_matrix", ReturnType UComplexMatrix, [UComplexRowVector; UInt], AoS) ; - add_unqualified ("rep_row_vector", ReturnType URowVector, [UReal; UInt], SoA) ; - add_unqualified - ("rep_row_vector", ReturnType UComplexRowVector, [UComplex; UInt], AoS) ; - add_unqualified ("rep_vector", ReturnType UVector, [UReal; UInt], SoA) ; - add_unqualified - ("rep_vector", ReturnType UComplexVector, [UComplex; UInt], AoS) ; - add_unqualified ("reverse", ReturnType UVector, [UVector], SoA) ; - add_unqualified ("reverse", ReturnType URowVector, [URowVector], SoA) ; - List.iter - ~f:(fun i -> - List.iter - ~f:(fun t -> - add_unqualified - ( "reverse" - , ReturnType (bare_array_type (t, i)) - , [bare_array_type (t, i)] - , SoA ) ) - bare_types ) - (List.range 1 8) ; - add_unqualified ("reverse", ReturnType UComplexVector, [UComplexVector], SoA) ; - add_unqualified - ("reverse", ReturnType UComplexRowVector, [UComplexRowVector], SoA) ; - add_binary_vec_int_int "rising_factorial" AoS ; - add_binary_vec_real_int "rising_factorial" AoS ; - add_unqualified ("row", ReturnType URowVector, [UMatrix; UInt], SoA) ; - add_unqualified - ("row", ReturnType UComplexRowVector, [UComplexMatrix; UInt], AoS) ; - add_unqualified ("rows", ReturnType UInt, [UVector], SoA) ; - add_unqualified ("rows", ReturnType UInt, [URowVector], SoA) ; - add_unqualified ("rows", ReturnType UInt, [UMatrix], SoA) ; - add_unqualified ("rows", ReturnType UInt, [UComplexVector], SoA) ; - add_unqualified ("rows", ReturnType UInt, [UComplexRowVector], SoA) ; - add_unqualified ("rows", ReturnType UInt, [UComplexMatrix], SoA) ; - add_unqualified - ("rows_dot_product", ReturnType UVector, [UVector; UVector], AoS) ; - add_unqualified - ("rows_dot_product", ReturnType UVector, [URowVector; URowVector], AoS) ; - add_unqualified - ("rows_dot_product", ReturnType UVector, [UMatrix; UMatrix], SoA) ; - add_unqualified - ( "rows_dot_product" - , ReturnType UComplexVector - , [UComplexVector; UComplexVector] - , AoS ) ; - add_unqualified - ( "rows_dot_product" - , ReturnType UComplexVector - , [UComplexRowVector; UComplexRowVector] - , AoS ) ; - add_unqualified - ( "rows_dot_product" - , ReturnType UComplexVector - , [UComplexMatrix; UComplexMatrix] - , AoS ) ; - add_unqualified ("rows_dot_self", ReturnType UVector, [UVector], SoA) ; - add_unqualified ("rows_dot_self", ReturnType UVector, [URowVector], SoA) ; - add_unqualified ("rows_dot_self", ReturnType UVector, [UMatrix], SoA) ; - add_unqualified - ("rows_dot_self", ReturnType UComplexVector, [UComplexVector], AoS) ; - add_unqualified - ("rows_dot_self", ReturnType UComplexVector, [UComplexRowVector], AoS) ; - add_unqualified - ("rows_dot_self", ReturnType UComplexVector, [UComplexMatrix], AoS) ; - add_unqualified - ( "scale_matrix_exp_multiply" - , ReturnType UMatrix - , [UReal; UMatrix; UMatrix] - , AoS ) ; - add_unqualified ("sd", ReturnType UReal, [UArray UReal], SoA) ; - add_unqualified ("sd", ReturnType UReal, [UVector], SoA) ; - add_unqualified ("sd", ReturnType UReal, [URowVector], SoA) ; - add_unqualified ("sd", ReturnType UReal, [UMatrix], SoA) ; - add_unqualified - ("segment", ReturnType URowVector, [URowVector; UInt; UInt], SoA) ; - add_unqualified ("segment", ReturnType UVector, [UVector; UInt; UInt], SoA) ; - add_unqualified - ( "segment" - , ReturnType UComplexRowVector - , [UComplexRowVector; UInt; UInt] - , AoS ) ; - add_unqualified - ("segment", ReturnType UComplexVector, [UComplexVector; UInt; UInt], AoS) ; - List.iter - ~f:(fun t -> - List.iter - ~f:(fun j -> - add_unqualified - ( "segment" - , ReturnType (bare_array_type (t, j)) - , [bare_array_type (t, j); UInt; UInt] - , SoA ) ) - (List.range 1 4) ) - bare_types ; - add_unqualified ("sin", ReturnType UComplex, [UComplex], AoS) ; - add_unqualified ("sinh", ReturnType UComplex, [UComplex], AoS) ; - add_unqualified ("singular_values", ReturnType UVector, [UMatrix], SoA) ; - List.iter - ~f:(fun i -> - List.iter - ~f:(fun t -> - add_unqualified - ("size", ReturnType UInt, [bare_array_type (t, i)], SoA) ) - bare_types ) - (List.range 1 8) ; - List.iter - ~f:(fun t -> add_unqualified ("size", ReturnType UInt, [t], SoA)) - bare_types ; - add_unqualified ("softmax", ReturnType UVector, [UVector], SoA) ; - add_unqualified ("sort_asc", ReturnType (UArray UInt), [UArray UInt], AoS) ; - add_unqualified ("sort_asc", ReturnType (UArray UReal), [UArray UReal], AoS) ; - add_unqualified ("sort_asc", ReturnType UVector, [UVector], AoS) ; - add_unqualified ("sort_asc", ReturnType URowVector, [URowVector], AoS) ; - add_unqualified ("sort_desc", ReturnType (UArray UInt), [UArray UInt], AoS) ; - add_unqualified ("sort_desc", ReturnType (UArray UReal), [UArray UReal], AoS) ; - add_unqualified ("sort_desc", ReturnType UVector, [UVector], AoS) ; - add_unqualified ("sort_desc", ReturnType URowVector, [URowVector], AoS) ; - add_unqualified - ("sort_indices_asc", ReturnType (UArray UInt), [UArray UInt], AoS) ; - add_unqualified - ("sort_indices_asc", ReturnType (UArray UInt), [UArray UReal], AoS) ; - add_unqualified ("sort_indices_asc", ReturnType (UArray UInt), [UVector], AoS) ; - add_unqualified - ("sort_indices_asc", ReturnType (UArray UInt), [URowVector], AoS) ; - add_unqualified - ("sort_indices_desc", ReturnType (UArray UInt), [UArray UInt], AoS) ; - add_unqualified - ("sort_indices_desc", ReturnType (UArray UInt), [UArray UReal], AoS) ; - add_unqualified ("sort_indices_desc", ReturnType (UArray UInt), [UVector], AoS) ; - add_unqualified - ("sort_indices_desc", ReturnType (UArray UInt), [URowVector], AoS) ; - add_unqualified ("squared_distance", ReturnType UReal, [UReal; UReal], SoA) ; - add_unqualified ("squared_distance", ReturnType UReal, [UVector; UVector], SoA) ; - add_unqualified - ("squared_distance", ReturnType UReal, [URowVector; URowVector], SoA) ; - add_unqualified - ("squared_distance", ReturnType UReal, [UVector; URowVector], SoA) ; - add_unqualified - ("squared_distance", ReturnType UReal, [URowVector; UVector], SoA) ; - add_unqualified ("sqrt", ReturnType UComplex, [UComplex], AoS) ; - add_nullary "sqrt2" ; - add_unqualified - ("sub_col", ReturnType UVector, [UMatrix; UInt; UInt; UInt], SoA) ; - add_unqualified - ( "sub_col" - , ReturnType UComplexVector - , [UComplexMatrix; UInt; UInt; UInt] - , AoS ) ; - add_unqualified - ("sub_row", ReturnType URowVector, [UMatrix; UInt; UInt; UInt], SoA) ; - add_unqualified - ( "sub_row" - , ReturnType UComplexRowVector - , [UComplexMatrix; UInt; UInt; UInt] - , AoS ) ; - List.iter - ~f:(fun x -> add_unqualified ("subtract", ReturnType x, [x; x], SoA)) - bare_types ; - add_unqualified ("subtract", ReturnType UVector, [UVector; UReal], SoA) ; - add_unqualified ("subtract", ReturnType URowVector, [URowVector; UReal], SoA) ; - add_unqualified ("subtract", ReturnType UMatrix, [UMatrix; UReal], SoA) ; - add_unqualified ("subtract", ReturnType UVector, [UReal; UVector], SoA) ; - add_unqualified ("subtract", ReturnType URowVector, [UReal; URowVector], SoA) ; - add_unqualified ("subtract", ReturnType UMatrix, [UReal; UMatrix], SoA) ; - add_unqualified ("sum", ReturnType UInt, [UArray UInt], SoA) ; - add_unqualified ("sum", ReturnType UReal, [UArray UReal], SoA) ; - add_unqualified ("sum", ReturnType UReal, [UVector], SoA) ; - add_unqualified ("sum", ReturnType UReal, [URowVector], SoA) ; - add_unqualified ("sum", ReturnType UReal, [UMatrix], SoA) ; - add_unqualified ("sum", ReturnType UComplex, [UArray UComplex], SoA) ; - add_unqualified ("sum", ReturnType UComplex, [UComplexVector], SoA) ; - add_unqualified ("sum", ReturnType UComplex, [UComplexRowVector], SoA) ; - add_unqualified ("sum", ReturnType UComplex, [UComplexMatrix], SoA) ; - add_unqualified ("svd_U", ReturnType UMatrix, [UMatrix], SoA) ; - add_unqualified ("svd_V", ReturnType UMatrix, [UMatrix], SoA) ; - add_unqualified - ("symmetrize_from_lower_tri", ReturnType UMatrix, [UMatrix], AoS) ; - add_unqualified - ( "symmetrize_from_lower_tri" - , ReturnType UComplexMatrix - , [UComplexMatrix] - , AoS ) ; - add_unqualified ("tail", ReturnType URowVector, [URowVector; UInt], SoA) ; - add_unqualified ("tail", ReturnType UVector, [UVector; UInt], SoA) ; - add_unqualified - ("tail", ReturnType UComplexRowVector, [UComplexRowVector; UInt], AoS) ; - add_unqualified - ("tail", ReturnType UComplexVector, [UComplexVector; UInt], AoS) ; - List.iter - ~f:(fun t -> - List.iter - ~f:(fun j -> - add_unqualified - ( "tail" - , ReturnType (bare_array_type (t, j)) - , [bare_array_type (t, j); UInt] - , SoA ) ) - (List.range 1 4) ) - bare_types ; - add_unqualified ("tan", ReturnType UComplex, [UComplex], AoS) ; - add_unqualified ("tanh", ReturnType UComplex, [UComplex], AoS) ; - add_unqualified ("tcrossprod", ReturnType UMatrix, [UMatrix], SoA) ; - add_unqualified ("to_array_1d", ReturnType (UArray UReal), [UMatrix], AoS) ; - add_unqualified ("to_array_1d", ReturnType (UArray UReal), [UVector], AoS) ; - add_unqualified ("to_array_1d", ReturnType (UArray UReal), [URowVector], AoS) ; - add_unqualified - ("to_array_1d", ReturnType (UArray UComplex), [UComplexMatrix], AoS) ; - add_unqualified - ("to_array_1d", ReturnType (UArray UComplex), [UComplexVector], AoS) ; - add_unqualified - ("to_array_1d", ReturnType (UArray UComplex), [UComplexRowVector], AoS) ; - List.iter - ~f:(fun i -> - add_unqualified - ( "to_array_1d" - , ReturnType (UArray UReal) - , [bare_array_type (UReal, i)] - , AoS ) ; - add_unqualified - ( "to_array_1d" - , ReturnType (UArray UInt) - , [bare_array_type (UInt, i)] - , AoS ) ) - (List.range 1 10) ; - add_unqualified - ("to_array_2d", ReturnType (bare_array_type (UReal, 2)), [UMatrix], AoS) ; - add_unqualified - ( "to_array_2d" - , ReturnType (bare_array_type (UComplex, 2)) - , [UComplexMatrix] - , AoS ) ; - add_unqualified ("to_complex", ReturnType UComplex, [], AoS) ; - add_unqualified ("to_complex", ReturnType UComplex, [UReal; UReal], AoS) ; - add_unqualified ("to_complex", ReturnType UComplex, [UReal], AoS) ; - add_unqualified ("to_matrix", ReturnType UMatrix, [UMatrix], AoS) ; - add_unqualified ("to_matrix", ReturnType UMatrix, [UMatrix; UInt; UInt], AoS) ; - add_unqualified - ("to_matrix", ReturnType UMatrix, [UMatrix; UInt; UInt; UInt], AoS) ; - add_unqualified ("to_matrix", ReturnType UMatrix, [UVector], AoS) ; - add_unqualified ("to_matrix", ReturnType UMatrix, [UVector; UInt; UInt], AoS) ; - add_unqualified - ("to_matrix", ReturnType UMatrix, [UVector; UInt; UInt; UInt], AoS) ; - add_unqualified ("to_matrix", ReturnType UMatrix, [URowVector], AoS) ; - add_unqualified ("to_matrix", ReturnType UMatrix, [UArray URowVector], AoS) ; - add_unqualified - ("to_matrix", ReturnType UMatrix, [URowVector; UInt; UInt], AoS) ; - add_unqualified - ("to_matrix", ReturnType UMatrix, [URowVector; UInt; UInt; UInt], AoS) ; - add_unqualified - ("to_matrix", ReturnType UMatrix, [UArray UReal; UInt; UInt], AoS) ; - add_unqualified - ("to_matrix", ReturnType UMatrix, [UArray UReal; UInt; UInt; UInt], AoS) ; - add_unqualified - ("to_matrix", ReturnType UMatrix, [UArray UInt; UInt; UInt], AoS) ; - add_unqualified - ("to_matrix", ReturnType UMatrix, [UArray UInt; UInt; UInt; UInt], AoS) ; - add_unqualified - ("to_matrix", ReturnType UMatrix, [bare_array_type (UReal, 2)], AoS) ; - add_unqualified - ("to_matrix", ReturnType UMatrix, [bare_array_type (UInt, 2)], AoS) ; - add_unqualified ("to_matrix", ReturnType UComplexMatrix, [UComplexMatrix], AoS) ; - add_unqualified - ("to_matrix", ReturnType UComplexMatrix, [UComplexMatrix; UInt; UInt], AoS) ; - add_unqualified - ( "to_matrix" - , ReturnType UComplexMatrix - , [UComplexMatrix; UInt; UInt; UInt] - , AoS ) ; - add_unqualified ("to_matrix", ReturnType UComplexMatrix, [UComplexVector], AoS) ; - add_unqualified - ("to_matrix", ReturnType UComplexMatrix, [UComplexVector; UInt; UInt], AoS) ; - add_unqualified - ( "to_matrix" - , ReturnType UComplexMatrix - , [UComplexVector; UInt; UInt; UInt] - , AoS ) ; - add_unqualified - ("to_matrix", ReturnType UComplexMatrix, [UComplexRowVector], AoS) ; - add_unqualified - ("to_matrix", ReturnType UComplexMatrix, [UArray UComplexRowVector], AoS) ; - add_unqualified - ( "to_matrix" - , ReturnType UComplexMatrix - , [UComplexRowVector; UInt; UInt] - , AoS ) ; - add_unqualified - ( "to_matrix" - , ReturnType UComplexMatrix - , [UComplexRowVector; UInt; UInt; UInt] - , AoS ) ; - add_unqualified - ("to_matrix", ReturnType UComplexMatrix, [UArray UComplex; UInt; UInt], AoS) ; - add_unqualified - ( "to_matrix" - , ReturnType UComplexMatrix - , [UArray UComplex; UInt; UInt; UInt] - , AoS ) ; - add_unqualified - ( "to_matrix" - , ReturnType UComplexMatrix - , [bare_array_type (UComplex, 2)] - , AoS ) ; - add_unqualified ("to_row_vector", ReturnType URowVector, [UMatrix], AoS) ; - add_unqualified ("to_row_vector", ReturnType URowVector, [UVector], AoS) ; - add_unqualified ("to_row_vector", ReturnType URowVector, [URowVector], AoS) ; - add_unqualified ("to_row_vector", ReturnType URowVector, [UArray UReal], AoS) ; - add_unqualified ("to_row_vector", ReturnType URowVector, [UArray UInt], AoS) ; - add_unqualified - ("to_row_vector", ReturnType UComplexRowVector, [UComplexMatrix], AoS) ; - add_unqualified - ("to_row_vector", ReturnType UComplexRowVector, [UComplexVector], AoS) ; - add_unqualified - ("to_row_vector", ReturnType UComplexRowVector, [UComplexRowVector], AoS) ; - add_unqualified - ("to_row_vector", ReturnType UComplexRowVector, [UArray UComplex], AoS) ; - add_unqualified ("to_vector", ReturnType UVector, [UMatrix], SoA) ; - add_unqualified ("to_vector", ReturnType UVector, [UVector], SoA) ; - add_unqualified ("to_vector", ReturnType UVector, [URowVector], SoA) ; - add_unqualified ("to_vector", ReturnType UVector, [UArray UReal], AoS) ; - add_unqualified ("to_vector", ReturnType UVector, [UArray UInt], AoS) ; - add_unqualified ("to_vector", ReturnType UComplexVector, [UComplexMatrix], AoS) ; - add_unqualified ("to_vector", ReturnType UComplexVector, [UComplexVector], AoS) ; - add_unqualified - ("to_vector", ReturnType UComplexVector, [UComplexRowVector], AoS) ; - add_unqualified - ("to_vector", ReturnType UComplexVector, [UArray UComplex], AoS) ; - add_unqualified ("trace", ReturnType UReal, [UMatrix], SoA) ; - add_unqualified ("trace", ReturnType UComplex, [UComplexMatrix], AoS) ; - add_unqualified - ("trace_gen_quad_form", ReturnType UReal, [UMatrix; UMatrix; UMatrix], SoA) ; - add_unqualified ("trace_quad_form", ReturnType UReal, [UMatrix; UVector], SoA) ; - add_unqualified ("trace_quad_form", ReturnType UReal, [UMatrix; UMatrix], SoA) ; - add_unqualified ("transpose", ReturnType URowVector, [UVector], SoA) ; - add_unqualified ("transpose", ReturnType UVector, [URowVector], SoA) ; - add_unqualified ("transpose", ReturnType UMatrix, [UMatrix], SoA) ; - add_unqualified - ("transpose", ReturnType UComplexRowVector, [UComplexVector], SoA) ; - add_unqualified - ("transpose", ReturnType UComplexVector, [UComplexRowVector], SoA) ; - add_unqualified ("transpose", ReturnType UComplexMatrix, [UComplexMatrix], SoA) ; - add_unqualified ("uniform_simplex", ReturnType UVector, [UInt], SoA) ; - add_unqualified ("variance", ReturnType UReal, [UArray UReal], SoA) ; - add_unqualified ("variance", ReturnType UReal, [UVector], SoA) ; - add_unqualified ("variance", ReturnType UReal, [URowVector], SoA) ; - add_unqualified ("variance", ReturnType UReal, [UMatrix], SoA) ; - add_unqualified ("wishart_rng", ReturnType UMatrix, [UReal; UMatrix], AoS) ; - add_unqualified ("zeros_int_array", ReturnType (UArray UInt), [UInt], SoA) ; - add_unqualified ("zeros_array", ReturnType (UArray UReal), [UInt], SoA) ; - add_unqualified ("zeros_row_vector", ReturnType URowVector, [UInt], SoA) ; - add_unqualified ("zeros_vector", ReturnType UVector, [UInt], SoA) ; - (* Now add all the manually added stuff to the main hashtable used - for type-checking *) - Hashtbl.iteri manual_stan_math_signatures ~f:(fun ~key ~data -> - List.iter data ~f:(fun data -> - Hashtbl.add_multi stan_math_signatures ~key ~data ) ) - -let%expect_test "dist name suffix" = - dist_name_suffix [] "normal" |> print_endline ; - [%expect {| _lpdf |}] - -let%expect_test "declarative distributions" = - let special_suffixes = - String.Set.of_list - Utils.(["lpmf"; "lpdf"; "log"] @ cumulative_distribution_suffices_w_rng) - in - let d = - distributions - |> List.map ~f:(function _, n, _, _ -> n) - |> String.Set.of_list in - Hashtbl.keys stan_math_signatures - |> List.filter ~f:(fun name -> - match Utils.split_distribution_suffix name with - | Some (name, suffix) - when Set.mem special_suffixes suffix && not (Set.mem d name) -> - true - | _ -> false ) - |> Fmt.str "@[%a@]" Fmt.(list ~sep:cut string) - |> print_endline ; - [%expect {| - binomial_coefficient_log - multiply_log - lkj_cov_log |}] diff --git a/src/stan_math_backend/Stan_math_signatures.mli b/src/stan_math_backend/Stan_math_signatures.mli deleted file mode 100644 index 181333aca6..0000000000 --- a/src/stan_math_backend/Stan_math_signatures.mli +++ /dev/null @@ -1,75 +0,0 @@ -(** This module stores a table of all signatures from the Stan - math C++ library which are exposed to Stan, and some helper - functions for dealing with those signatures. -*) - -open Core_kernel -open Middle - -(** Function arguments are represented by their type an autodiff - type. This is [AutoDiffable] for everything except arguments - marked with the data keyword *) -type fun_arg = UnsizedType.autodifftype * UnsizedType.t - -(** Signatures consist of a return type, a list of arguments, and a flag - for whether or not those arguments can be Struct of Arrays objects *) -type signature = - UnsizedType.returntype * fun_arg list * Common.Helpers.mem_pattern - -val stan_math_signatures : (string, signature list) Hashtbl.t -(** Mapping from names to signature(s) of functions *) - -val is_stan_math_function_name : string -> bool -(** Equivalent to [Hashtbl.mem stan_math_signatures s]*) - -val pretty_print_all_math_sigs : unit Fmt.t -val pretty_print_all_math_distributions : unit Fmt.t - -type dimensionality - -type fkind = Lpmf | Lpdf | Rng | Cdf | Ccdf | UnaryVectorized -[@@deriving show {with_path= false}] - -val distributions : - (fkind list * string * dimensionality list * Common.Helpers.mem_pattern) list -(** The distribution {e families} exposed by the math library *) - -val dist_name_suffix : (string * 'a) list -> string -> string - -(** Helpers for dealing with operators as signatures *) - -val operator_to_stan_math_fns : Operator.t -> string list -val string_operator_to_stan_math_fns : string -> string -val make_assignmentoperator_stan_math_signatures : Operator.t -> signature list - -(** Special functions for the variadic signatures exposed *) - -(* TODO: We should think of a better encapsulization for these, - this doesn't scale well. -*) - -(* reduce_sum helpers *) -val is_reduce_sum_fn : string -> bool -val reduce_sum_slice_types : UnsizedType.t list - -(* variadic ODE helpers *) -val is_variadic_ode_fn : string -> bool -val is_variadic_ode_nonadjoint_tol_fn : string -> bool -val ode_tolerances_suffix : string -val variadic_ode_adjoint_fn : string -val variadic_ode_mandatory_arg_types : fun_arg list -val variadic_ode_mandatory_fun_args : fun_arg list -val variadic_ode_tol_arg_types : fun_arg list -val variadic_ode_adjoint_ctl_tol_arg_types : fun_arg list -val variadic_ode_fun_return_type : UnsizedType.t -val variadic_ode_return_type : UnsizedType.t - -(* variadic DAE helpers *) -val is_variadic_dae_fn : string -> bool -val is_variadic_dae_tol_fn : string -> bool -val dae_tolerances_suffix : string -val variadic_dae_mandatory_arg_types : fun_arg list -val variadic_dae_mandatory_fun_args : fun_arg list -val variadic_dae_tol_arg_types : fun_arg list -val variadic_dae_fun_return_type : UnsizedType.t -val variadic_dae_return_type : UnsizedType.t diff --git a/src/stan_math_backend/dune b/src/stan_math_backend/dune index 07581843c4..1821f261d1 100644 --- a/src/stan_math_backend/dune +++ b/src/stan_math_backend/dune @@ -1,7 +1,7 @@ (library (name stan_math_backend) (public_name stanc.stan_math_backend) - (libraries core_kernel re fmt middle yojson) + (libraries core_kernel re fmt frontend middle yojson) (private_modules mangle cpp_Json diff --git a/src/stanc/stanc.ml b/src/stanc/stanc.ml index a67c90e220..35151e9ed9 100644 --- a/src/stanc/stanc.ml +++ b/src/stanc/stanc.ml @@ -8,7 +8,9 @@ open Analysis_and_optimization open Middle (* Initialize functor modules with the Stan Math Library *) -module CppLibrary = Std_library_utils.NullLibrary +module CppLibrary : Std_library_utils.Library = + Stan_math_backend.Stan_math_library + module Typechecker = Typechecking.Make (CppLibrary) module Deprecations = Deprecation_analysis.Make (CppLibrary) module Canonicalizer = Canonicalize.Make (Deprecations) @@ -352,11 +354,11 @@ let main () = print_deprecated_arg_warning ; (* Deal with multiple modalities *) if !dump_stan_math_sigs then ( - Stan_math_signatures.pretty_print_all_math_sigs Format.std_formatter () ; + Stan_math_library.pretty_print_all_math_sigs Format.std_formatter () ; exit 0 ) ; if !dump_stan_math_distributions then ( - Stan_math_signatures.pretty_print_all_math_distributions - Format.std_formatter () ; + Stan_math_library.pretty_print_all_math_distributions Format.std_formatter + () ; exit 0 ) ; if !model_file = "" then model_file_err () ; (* if we only have functions, always compile as standalone *) From 1a3757f6047cab40c836fdf5ef43867ee9b3a0b7 Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Wed, 27 Apr 2022 15:53:43 -0400 Subject: [PATCH 05/14] Parametrize optimizer, unit tests passing --- .../Debug_data_generation.ml | 5 +- .../Monotone_framework.ml | 4 +- src/analysis_and_optimization/Optimize.ml | 2418 ++++++++-------- src/analysis_and_optimization/Optimize.mli | 120 +- .../Partial_evaluation.ml | 1195 ++++++++ .../Partial_evaluation.mli | 8 + .../Partial_evaluator.ml | 1152 -------- .../Pedantic_analysis.ml | 5 +- src/analysis_and_optimization/dune | 2 +- src/frontend/Ast_to_Mir.ml | 8 - src/stan_math_backend/Stan_math_library.ml | 2507 +++++++++++++++++ src/stan_math_backend/Stan_math_library.mli | 41 + src/stanc/stanc.ml | 3 +- test/unit/Debug_data_generation_tests.ml | 21 +- test/unit/Desugar_test.ml | 3 + test/unit/Optimize.ml | 6 +- test/unit/Test_utils.ml | 12 +- 17 files changed, 5093 insertions(+), 2417 deletions(-) create mode 100644 src/analysis_and_optimization/Partial_evaluation.ml create mode 100644 src/analysis_and_optimization/Partial_evaluation.mli delete mode 100644 src/analysis_and_optimization/Partial_evaluator.ml create mode 100644 src/stan_math_backend/Stan_math_library.ml create mode 100644 src/stan_math_backend/Stan_math_library.mli diff --git a/src/analysis_and_optimization/Debug_data_generation.ml b/src/analysis_and_optimization/Debug_data_generation.ml index e3166aa0e4..9e5dbcbb2d 100644 --- a/src/analysis_and_optimization/Debug_data_generation.ml +++ b/src/analysis_and_optimization/Debug_data_generation.ml @@ -1,6 +1,9 @@ open Core_kernel open Middle +module Partial_evaluator = + Partial_evaluation.Make (Frontend.Std_library_utils.NullLibrary) + let rec transpose = function | [] :: _ -> [] | rows -> @@ -30,7 +33,7 @@ let rec vect_to_mat l m = let eval_expr m e = let e = Mir_utils.subst_expr m e in - let e = Partial_evaluator.eval_expr e in + let e = Partial_evaluator.try_eval_expr e in let rec strip_promotions (e : Middle.Expr.Typed.t) = match e.pattern with Promotion (e, _, _) -> strip_promotions e | _ -> e in diff --git a/src/analysis_and_optimization/Monotone_framework.ml b/src/analysis_and_optimization/Monotone_framework.ml index 442781b210..381fa736d5 100644 --- a/src/analysis_and_optimization/Monotone_framework.ml +++ b/src/analysis_and_optimization/Monotone_framework.ml @@ -309,7 +309,9 @@ let minimal_variables_lattice initial_variables = end ) (* The transfer function for a constant propagation analysis *) -let constant_propagation_transfer ?(preserve_stability = false) +let constant_propagation_transfer + (module Partial_evaluator : Partial_evaluation.PartialEvaluator) + ?(preserve_stability = false) (flowgraph_to_mir : (int, Stmt.Located.Non_recursive.t) Map.Poly.t) = ( module struct type labels = int diff --git a/src/analysis_and_optimization/Optimize.ml b/src/analysis_and_optimization/Optimize.ml index aa05c06f17..a3c8bc5380 100644 --- a/src/analysis_and_optimization/Optimize.ml +++ b/src/analysis_and_optimization/Optimize.ml @@ -6,1059 +6,1180 @@ open Common open Middle open Mir_utils -(** +type optimization_settings = + { function_inlining: bool + ; static_loop_unrolling: bool + ; one_step_loop_unrolling: bool + ; list_collapsing: bool + ; block_fixing: bool + ; allow_uninitialized_decls: bool + ; constant_propagation: bool + ; expression_propagation: bool + ; copy_propagation: bool + ; dead_code_elimination: bool + ; partial_evaluation: bool + ; lazy_code_motion: bool + ; optimize_ad_levels: bool + ; preserve_stability: bool + ; optimize_soa: bool } + +let settings_const b = + { function_inlining= b + ; static_loop_unrolling= b + ; one_step_loop_unrolling= b + ; list_collapsing= b + ; block_fixing= b + ; allow_uninitialized_decls= b + ; constant_propagation= b + ; expression_propagation= b + ; copy_propagation= b + ; dead_code_elimination= b + ; partial_evaluation= b + ; lazy_code_motion= b + ; optimize_ad_levels= b + ; preserve_stability= not b + ; optimize_soa= b } + +let all_optimizations : optimization_settings = settings_const true +let no_optimizations : optimization_settings = settings_const false + +type optimization_level = O0 | O1 | Oexperimental + +let level_optimizations (lvl : optimization_level) : optimization_settings = + match lvl with + | O0 -> no_optimizations + | O1 -> + { function_inlining= true + ; static_loop_unrolling= false + ; one_step_loop_unrolling= false + ; list_collapsing= true + ; block_fixing= true + ; constant_propagation= true + ; expression_propagation= false + ; copy_propagation= true + ; dead_code_elimination= true + ; partial_evaluation= true + ; lazy_code_motion= false + ; allow_uninitialized_decls= true + ; optimize_ad_levels= false + ; preserve_stability= false + ; optimize_soa= true } + | Oexperimental -> all_optimizations + +module type Optimizer = sig + val function_inlining : Program.Typed.t -> Program.Typed.t + val static_loop_unrolling : Program.Typed.t -> Program.Typed.t + val one_step_loop_unrolling : Program.Typed.t -> Program.Typed.t + val list_collapsing : Program.Typed.t -> Program.Typed.t + val block_fixing : Program.Typed.t -> Program.Typed.t + + val constant_propagation : + ?preserve_stability:bool -> Program.Typed.t -> Program.Typed.t + + val expression_propagation : + ?preserve_stability:bool -> Program.Typed.t -> Program.Typed.t + + val copy_propagation : Program.Typed.t -> Program.Typed.t + val dead_code_elimination : Program.Typed.t -> Program.Typed.t + val partial_evaluation : Program.Typed.t -> Program.Typed.t + + val lazy_code_motion : + ?preserve_stability:bool -> Program.Typed.t -> Program.Typed.t + + val optimize_ad_levels : Program.Typed.t -> Program.Typed.t + val allow_uninitialized_decls : Program.Typed.t -> Program.Typed.t + + val optimization_suite : + ?settings:optimization_settings -> Program.Typed.t -> Program.Typed.t +end + +module Make (StdLib : Frontend.Std_library_utils.Library) : Optimizer = struct + module Mem = Mem_pattern.Make (StdLib) + module Partial_evaluator = Partial_evaluation.Make (StdLib) + + (** Apply the transformation to each function body and to the rest of the program as one block. *) -let transform_program (mir : Program.Typed.t) - (transform : Stmt.Located.t -> Stmt.Located.t) : Program.Typed.t = - let packed_prog_body = - transform - { pattern= + let transform_program (mir : Program.Typed.t) + (transform : Stmt.Located.t -> Stmt.Located.t) : Program.Typed.t = + let packed_prog_body = + transform + { pattern= + SList + (List.map + ~f:(fun x -> + Stmt.Fixed.{pattern= SList x; meta= Location_span.empty} ) + [ mir.prepare_data; mir.transform_inits; mir.log_prob + ; mir.generate_quantities ] ) + ; meta= Location_span.empty } in + let transformed_prog_body = transform packed_prog_body in + let transformed_functions = + List.map mir.functions_block ~f:(fun fs -> + {fs with fdbody= Option.map ~f:transform fs.fdbody} ) in + match transformed_prog_body with + | { pattern= SList - (List.map - ~f:(fun x -> - Stmt.Fixed.{pattern= SList x; meta= Location_span.empty} ) - [ mir.prepare_data; mir.transform_inits; mir.log_prob - ; mir.generate_quantities ] ) - ; meta= Location_span.empty } in - let transformed_prog_body = transform packed_prog_body in - let transformed_functions = - List.map mir.functions_block ~f:(fun fs -> - {fs with fdbody= Option.map ~f:transform fs.fdbody} ) in - match transformed_prog_body with - | { pattern= - SList - [ {pattern= SList prepare_data'; _} - ; {pattern= SList transform_inits'; _}; {pattern= SList log_prob'; _} - ; {pattern= SList generate_quantities'; _} ] - ; _ } -> - { mir with - functions_block= transformed_functions - ; prepare_data= prepare_data' - ; transform_inits= transform_inits' - ; log_prob= log_prob' - ; generate_quantities= generate_quantities' } - | _ -> - raise (Failure "Something went wrong with program transformation packing!") - -(** - Apply the transformation to each function body and to each program block separately. -*) -let transform_program_blockwise (mir : Program.Typed.t) - (transform : - Stmt.Located.t Program.fun_def option -> Stmt.Located.t -> Stmt.Located.t - ) : Program.Typed.t = - let transform' fd s = - match transform fd {pattern= SList s; meta= Location_span.empty} with - | {pattern= SList l; _} -> l + [ {pattern= SList prepare_data'; _} + ; {pattern= SList transform_inits'; _}; {pattern= SList log_prob'; _} + ; {pattern= SList generate_quantities'; _} ] + ; _ } -> + { mir with + functions_block= transformed_functions + ; prepare_data= prepare_data' + ; transform_inits= transform_inits' + ; log_prob= log_prob' + ; generate_quantities= generate_quantities' } | _ -> raise (Failure "Something went wrong with program transformation packing!") - in - let transformed_functions = - List.map mir.functions_block ~f:(fun fs -> - {fs with fdbody= Option.map ~f:(transform (Some fs)) fs.fdbody} ) in - { mir with - functions_block= transformed_functions - ; prepare_data= transform' None mir.prepare_data - ; transform_inits= transform' None mir.transform_inits - ; log_prob= transform' None mir.log_prob - ; generate_quantities= transform' None mir.generate_quantities } -let map_no_loc l = - List.map ~f:(fun s -> Stmt.Fixed.{pattern= s; meta= Location_span.empty}) l + (** + Apply the transformation to each function body and to each program block separately. +*) + let transform_program_blockwise (mir : Program.Typed.t) + (transform : + Stmt.Located.t Program.fun_def option + -> Stmt.Located.t + -> Stmt.Located.t ) : Program.Typed.t = + let transform' fd s = + match transform fd {pattern= SList s; meta= Location_span.empty} with + | {pattern= SList l; _} -> l + | _ -> + raise + (Failure "Something went wrong with program transformation packing!") + in + let transformed_functions = + List.map mir.functions_block ~f:(fun fs -> + {fs with fdbody= Option.map ~f:(transform (Some fs)) fs.fdbody} ) + in + { mir with + functions_block= transformed_functions + ; prepare_data= transform' None mir.prepare_data + ; transform_inits= transform' None mir.transform_inits + ; log_prob= transform' None mir.log_prob + ; generate_quantities= transform' None mir.generate_quantities } + + let map_no_loc l = + List.map ~f:(fun s -> Stmt.Fixed.{pattern= s; meta= Location_span.empty}) l -let slist_no_loc l = Stmt.Fixed.Pattern.SList (map_no_loc l) -let block_no_loc l = Stmt.Fixed.Pattern.Block (map_no_loc l) + let slist_no_loc l = Stmt.Fixed.Pattern.SList (map_no_loc l) + let block_no_loc l = Stmt.Fixed.Pattern.Block (map_no_loc l) -let slist_concat_no_loc l stmt = - match l with [] -> stmt | l -> slist_no_loc (l @ [stmt]) + let slist_concat_no_loc l stmt = + match l with [] -> stmt | l -> slist_no_loc (l @ [stmt]) -let gen_inline_var (name : string) (id_var : string) = - Gensym.generate ~prefix:("inline_" ^ name ^ "_" ^ id_var ^ "_") () + let gen_inline_var (name : string) (id_var : string) = + Gensym.generate ~prefix:("inline_" ^ name ^ "_" ^ id_var ^ "_") () -let replace_fresh_local_vars (fname : string) stmt = - let f (m : (string, string) Core_kernel.Map.Poly.t) = function - | Stmt.Fixed.Pattern.Decl {decl_adtype; decl_type; decl_id; initialize} -> - let new_name = - match Map.Poly.find m decl_id with - | Some existing -> existing - | None -> gen_inline_var fname decl_id in - ( Stmt.Fixed.Pattern.Decl - {decl_adtype; decl_id= new_name; decl_type; initialize} - , Map.Poly.set m ~key:decl_id ~data:new_name ) - | Stmt.Fixed.Pattern.For {loopvar; lower; upper; body} -> - let new_name = - match Map.Poly.find m loopvar with - | Some existing -> existing - | None -> gen_inline_var fname loopvar in - ( Stmt.Fixed.Pattern.For {loopvar= new_name; lower; upper; body} - , Map.Poly.set m ~key:loopvar ~data:new_name ) - | Assignment ((var_name, ut, l), e) -> - let var_name = - match Map.Poly.find m var_name with - | None -> var_name - | Some var_name -> var_name in - (Stmt.Fixed.Pattern.Assignment ((var_name, ut, l), e), m) - | x -> (x, m) in - let s, m = map_rec_state_stmt_loc f Map.Poly.empty stmt in - name_subst_stmt m s + let replace_fresh_local_vars (fname : string) stmt = + let f (m : (string, string) Core_kernel.Map.Poly.t) = function + | Stmt.Fixed.Pattern.Decl {decl_adtype; decl_type; decl_id; initialize} -> + let new_name = + match Map.Poly.find m decl_id with + | Some existing -> existing + | None -> gen_inline_var fname decl_id in + ( Stmt.Fixed.Pattern.Decl + {decl_adtype; decl_id= new_name; decl_type; initialize} + , Map.Poly.set m ~key:decl_id ~data:new_name ) + | Stmt.Fixed.Pattern.For {loopvar; lower; upper; body} -> + let new_name = + match Map.Poly.find m loopvar with + | Some existing -> existing + | None -> gen_inline_var fname loopvar in + ( Stmt.Fixed.Pattern.For {loopvar= new_name; lower; upper; body} + , Map.Poly.set m ~key:loopvar ~data:new_name ) + | Assignment ((var_name, ut, l), e) -> + let var_name = + match Map.Poly.find m var_name with + | None -> var_name + | Some var_name -> var_name in + (Stmt.Fixed.Pattern.Assignment ((var_name, ut, l), e), m) + | x -> (x, m) in + let s, m = map_rec_state_stmt_loc f Map.Poly.empty stmt in + name_subst_stmt m s -let subst_args_stmt args es = - let m = Map.Poly.of_alist_exn (List.zip_exn args es) in - subst_stmt m + let subst_args_stmt args es = + let m = Map.Poly.of_alist_exn (List.zip_exn args es) in + subst_stmt m -(** + (** * Count the number of returns that happen in a statement *) -let rec count_returns Stmt.Fixed.{pattern; _} : int = - Stmt.Fixed.Pattern.fold - (fun acc _ -> acc) - (fun acc -> function - | Stmt.Fixed.{pattern= Return _; _} -> acc + 1 - | stmt -> acc + count_returns stmt ) - 0 pattern + let rec count_returns Stmt.Fixed.{pattern; _} : int = + Stmt.Fixed.Pattern.fold + (fun acc _ -> acc) + (fun acc -> function + | Stmt.Fixed.{pattern= Return _; _} -> acc + 1 + | stmt -> acc + count_returns stmt ) + 0 pattern -(* The strategy here is to wrap the function body in a dummy loop, then replace - returns with breaks. One issue is early return from internal loops - in - those cases, a break would only break out of the inner loop. The solution is - a flag variable to indicate whether a 'return' break has been called, and - then to check if that flag is set after each loop. Then, if a 'return' break - is called from an inner loop, there's a cascade of breaks all the way out of - the dummy loop. *) -let handle_early_returns (fname : string) opt_var stmt = - let returned = gen_inline_var fname "early_ret_check" in - let generate_inner_breaks num_returns stmt_pattern = - match stmt_pattern with - | Stmt.Fixed.Pattern.Return opt_ret -> ( - match (opt_var, opt_ret) with - | None, None when num_returns > 1 -> Stmt.Fixed.Pattern.Break - | None, None -> Stmt.Fixed.Pattern.Block [] - | Some name, Some e when num_returns > 1 -> - SList - [ Stmt.Fixed. + (* The strategy here is to wrap the function body in a dummy loop, then replace + returns with breaks. One issue is early return from internal loops - in + those cases, a break would only break out of the inner loop. The solution is + a flag variable to indicate whether a 'return' break has been called, and + then to check if that flag is set after each loop. Then, if a 'return' break + is called from an inner loop, there's a cascade of breaks all the way out of + the dummy loop. *) + let handle_early_returns (fname : string) opt_var stmt = + let returned = gen_inline_var fname "early_ret_check" in + let generate_inner_breaks num_returns stmt_pattern = + match stmt_pattern with + | Stmt.Fixed.Pattern.Return opt_ret -> ( + match (opt_var, opt_ret) with + | None, None when num_returns > 1 -> Stmt.Fixed.Pattern.Break + | None, None -> Stmt.Fixed.Pattern.Block [] + | Some name, Some e when num_returns > 1 -> + SList + [ Stmt.Fixed. + { pattern= + Assignment + ( (returned, UInt, []) + , Expr.Fixed. + { pattern= Lit (Int, "1") + ; meta= + Expr.Typed.Meta. + { type_= UInt + ; adlevel= DataOnly + ; loc= Location_span.empty } } ) + ; meta= Location_span.empty } + ; Stmt.Fixed. + { pattern= Assignment ((name, Expr.Typed.type_of e, []), e) + ; meta= Location_span.empty } + ; {pattern= Break; meta= Location_span.empty} ] + | Some name, Some e -> Assignment ((name, Expr.Typed.type_of e, []), e) + | Some _, None -> + Common.FatalError.fatal_error_msg + [%message + ( "Function should return a value but found an empty return \ + statement." + : string )] + | None, Some _ -> + Common.FatalError.fatal_error_msg + [%message + ( "Expected a void function but found a non-empty return \ + statement." + : string )] ) + | Stmt.Fixed.Pattern.For _ as loop when num_returns > 1 -> + Stmt.Fixed.Pattern.SList + [ Stmt.Fixed.{pattern= loop; meta= Location_span.empty} + ; Stmt.Fixed. { pattern= - Assignment - ( (returned, UInt, []) - , Expr.Fixed. - { pattern= Lit (Int, "1") + IfElse + ( Expr.Fixed. + { pattern= Var returned ; meta= Expr.Typed.Meta. { type_= UInt ; adlevel= DataOnly - ; loc= Location_span.empty } } ) - ; meta= Location_span.empty } - ; Stmt.Fixed. - { pattern= Assignment ((name, Expr.Typed.type_of e, []), e) - ; meta= Location_span.empty } - ; {pattern= Break; meta= Location_span.empty} ] - | Some name, Some e -> Assignment ((name, Expr.Typed.type_of e, []), e) - | Some _, None -> - Common.FatalError.fatal_error_msg - [%message - ( "Function should return a value but found an empty return \ - statement." - : string )] - | None, Some _ -> - Common.FatalError.fatal_error_msg - [%message - ( "Expected a void function but found a non-empty return \ - statement." - : string )] ) - | Stmt.Fixed.Pattern.For _ as loop when num_returns > 1 -> - Stmt.Fixed.Pattern.SList - [ Stmt.Fixed.{pattern= loop; meta= Location_span.empty} - ; Stmt.Fixed. - { pattern= - IfElse - ( Expr.Fixed. - { pattern= Var returned + ; loc= Location_span.empty } } + , {pattern= Break; meta= Location_span.empty} + , None ) + ; meta= Location_span.empty } ] + | x -> x in + let num_returns = count_returns stmt in + if num_returns > 1 then + Stmt.Fixed.Pattern.SList + [ Stmt.Fixed. + { pattern= + Decl + { decl_adtype= DataOnly + ; decl_id= returned + ; decl_type= Sized SInt + ; initialize= true } + ; meta= Location_span.empty } + ; Stmt.Fixed. + { pattern= + Assignment + ( (returned, UInt, []) + , Expr.Fixed. + { pattern= Lit (Int, "0") + ; meta= + Expr.Typed.Meta. + { type_= UInt + ; adlevel= DataOnly + ; loc= Location_span.empty } } ) + ; meta= Location_span.empty } + ; Stmt.Fixed. + { pattern= + Stmt.Fixed.Pattern.For + { loopvar= gen_inline_var fname "iterator" + ; lower= + Expr.Fixed. + { pattern= Lit (Int, "1") ; meta= Expr.Typed.Meta. { type_= UInt ; adlevel= DataOnly ; loc= Location_span.empty } } - , {pattern= Break; meta= Location_span.empty} - , None ) - ; meta= Location_span.empty } ] - | x -> x in - let num_returns = count_returns stmt in - if num_returns > 1 then - Stmt.Fixed.Pattern.SList - [ Stmt.Fixed. - { pattern= - Decl - { decl_adtype= DataOnly - ; decl_id= returned - ; decl_type= Sized SInt - ; initialize= true } - ; meta= Location_span.empty } - ; Stmt.Fixed. - { pattern= - Assignment - ( (returned, UInt, []) - , Expr.Fixed. - { pattern= Lit (Int, "0") - ; meta= - Expr.Typed.Meta. - { type_= UInt - ; adlevel= DataOnly - ; loc= Location_span.empty } } ) - ; meta= Location_span.empty } - ; Stmt.Fixed. - { pattern= - Stmt.Fixed.Pattern.For - { loopvar= gen_inline_var fname "iterator" - ; lower= - Expr.Fixed. + ; upper= { pattern= Lit (Int, "1") ; meta= - Expr.Typed.Meta. - { type_= UInt - ; adlevel= DataOnly - ; loc= Location_span.empty } } - ; upper= - { pattern= Lit (Int, "1") - ; meta= - { type_= UInt - ; adlevel= DataOnly - ; loc= Location_span.empty } } - ; body= - map_rec_stmt_loc (generate_inner_breaks num_returns) stmt } - ; meta= Location_span.empty } ] - else (map_rec_stmt_loc (generate_inner_breaks num_returns) stmt).pattern + { type_= UInt + ; adlevel= DataOnly + ; loc= Location_span.empty } } + ; body= + map_rec_stmt_loc (generate_inner_breaks num_returns) stmt + } + ; meta= Location_span.empty } ] + else (map_rec_stmt_loc (generate_inner_breaks num_returns) stmt).pattern -let inline_list f es = - let dse_list = List.map ~f es in - (* function arguments are evaluated from right to left in C++, so we need to reverse *) - let d_list = - List.concat (List.rev (List.map ~f:(function x, _, _ -> x) dse_list)) in - let s_list = - List.concat (List.rev (List.map ~f:(function _, x, _ -> x) dse_list)) in - let es = List.map ~f:(function _, _, x -> x) dse_list in - (d_list, s_list, es) + let inline_list f es = + let dse_list = List.map ~f es in + (* function arguments are evaluated from right to left in C++, so we need to reverse *) + let d_list = + List.concat (List.rev (List.map ~f:(function x, _, _ -> x) dse_list)) + in + let s_list = + List.concat (List.rev (List.map ~f:(function _, x, _ -> x) dse_list)) + in + let es = List.map ~f:(function _, _, x -> x) dse_list in + (d_list, s_list, es) -(* Triple is (declaration list, statement list, return expression) *) -let rec inline_function_expression propto adt fim (Expr.Fixed.{pattern; _} as e) - = - match pattern with - | Var _ -> ([], [], e) - | Lit (_, _) -> ([], [], e) - | Promotion (expr, ut, ad) -> - let d, sl, expr' = inline_function_expression propto adt fim expr in - (d, sl, {e with pattern= Promotion (expr', ut, ad)}) - | FunApp (kind, es) -> ( - let d_list, s_list, es = - inline_list (inline_function_expression propto adt fim) es in - match kind with - | CompilerInternal _ -> - (d_list, s_list, {e with pattern= FunApp (kind, es)}) - | UserDefined (fname, suffix) | StanLib (fname, suffix, _) -> ( - let suffix, fname' = - match suffix with - | FnLpdf propto' when propto' && propto -> - ( Fun_kind.FnLpdf true - , Utils.with_unnormalized_suffix fname - |> Option.value ~default:fname ) - | FnLpdf _ -> (Fun_kind.FnLpdf false, fname) - | _ -> (suffix, fname) in - match Map.find fim fname' with - | None -> - let fun_kind = - match kind with - | Fun_kind.UserDefined _ -> Fun_kind.UserDefined (fname, suffix) - | _ -> StanLib (fname, suffix, AoS) in - (d_list, s_list, {e with pattern= FunApp (fun_kind, es)}) - | Some (rt, args, body) -> - let inline_return_name = gen_inline_var fname "return" in - let handle = - handle_early_returns fname (Some inline_return_name) in - let d_list2, s_list2, (e : Expr.Typed.t) = - let decl_type = - Option.map ~f:Mir_utils.unsafe_unsized_to_sized_type rt - |> Option.value_exn in - ( [ Stmt.Fixed.Pattern.Decl - { decl_adtype= adt - ; decl_id= inline_return_name - ; decl_type - ; initialize= false } ] - (* We should minimize the code that's having its variables - replaced to avoid conflict with the (two) new dummy - variables introduced by inlining *) - , [ handle - (replace_fresh_local_vars fname - (subst_args_stmt args es body) ) ] - , { pattern= Var inline_return_name - ; meta= - Expr.Typed.Meta. - { type_= Type.to_unsized decl_type - ; adlevel= adt - ; loc= Location_span.empty } } ) in - let d_list = d_list @ d_list2 in - let s_list = s_list @ s_list2 in - (d_list, s_list, e) ) ) - | TernaryIf (e1, e2, e3) -> - let dl1, sl1, e1 = inline_function_expression propto adt fim e1 in - let dl2, sl2, e2 = inline_function_expression propto adt fim e2 in - let dl3, sl3, e3 = inline_function_expression propto adt fim e3 in - ( dl1 @ dl2 @ dl3 - , sl1 - @ [ Stmt.Fixed.( + (* Triple is (declaration list, statement list, return expression) *) + let rec inline_function_expression propto adt fim + (Expr.Fixed.{pattern; _} as e) = + match pattern with + | Var _ -> ([], [], e) + | Lit (_, _) -> ([], [], e) + | Promotion (expr, ut, ad) -> + let d, sl, expr' = inline_function_expression propto adt fim expr in + (d, sl, {e with pattern= Promotion (expr', ut, ad)}) + | FunApp (kind, es) -> ( + let d_list, s_list, es = + inline_list (inline_function_expression propto adt fim) es in + match kind with + | CompilerInternal _ -> + (d_list, s_list, {e with pattern= FunApp (kind, es)}) + | UserDefined (fname, suffix) | StanLib (fname, suffix, _) -> ( + let suffix, fname' = + match suffix with + | FnLpdf propto' when propto' && propto -> + ( Fun_kind.FnLpdf true + , Utils.with_unnormalized_suffix fname + |> Option.value ~default:fname ) + | FnLpdf _ -> (Fun_kind.FnLpdf false, fname) + | _ -> (suffix, fname) in + match Map.find fim fname' with + | None -> + let fun_kind = + match kind with + | Fun_kind.UserDefined _ -> + Fun_kind.UserDefined (fname, suffix) + | _ -> StanLib (fname, suffix, AoS) in + (d_list, s_list, {e with pattern= FunApp (fun_kind, es)}) + | Some (rt, args, body) -> + let inline_return_name = gen_inline_var fname "return" in + let handle = + handle_early_returns fname (Some inline_return_name) in + let d_list2, s_list2, (e : Expr.Typed.t) = + let decl_type = + Option.map ~f:Mir_utils.unsafe_unsized_to_sized_type rt + |> Option.value_exn in + ( [ Stmt.Fixed.Pattern.Decl + { decl_adtype= adt + ; decl_id= inline_return_name + ; decl_type + ; initialize= false } ] + (* We should minimize the code that's having its variables + replaced to avoid conflict with the (two) new dummy + variables introduced by inlining *) + , [ handle + (replace_fresh_local_vars fname + (subst_args_stmt args es body) ) ] + , { pattern= Var inline_return_name + ; meta= + Expr.Typed.Meta. + { type_= Type.to_unsized decl_type + ; adlevel= adt + ; loc= Location_span.empty } } ) in + let d_list = d_list @ d_list2 in + let s_list = s_list @ s_list2 in + (d_list, s_list, e) ) ) + | TernaryIf (e1, e2, e3) -> + let dl1, sl1, e1 = inline_function_expression propto adt fim e1 in + let dl2, sl2, e2 = inline_function_expression propto adt fim e2 in + let dl3, sl3, e3 = inline_function_expression propto adt fim e3 in + ( dl1 @ dl2 @ dl3 + , sl1 + @ [ Stmt.Fixed.( + Pattern.IfElse + ( e1 + , {pattern= block_no_loc sl2; meta= Location_span.empty} + , Some {pattern= block_no_loc sl3; meta= Location_span.empty} + )) ] + , {e with pattern= TernaryIf (e1, e2, e3)} ) + | Indexed (e', i_list) -> + let dl, sl, e' = inline_function_expression propto adt fim e' in + let d_list, s_list, i_list = + inline_list (inline_function_index propto adt fim) i_list in + (d_list @ dl, s_list @ sl, {e with pattern= Indexed (e', i_list)}) + | EAnd (e1, e2) -> + let dl1, sl1, e1 = inline_function_expression propto adt fim e1 in + let dl2, sl2, e2 = inline_function_expression propto adt fim e2 in + let sl2 = + [ Stmt.Fixed.( Pattern.IfElse ( e1 - , {pattern= block_no_loc sl2; meta= Location_span.empty} - , Some {pattern= block_no_loc sl3; meta= Location_span.empty} )) - ] - , {e with pattern= TernaryIf (e1, e2, e3)} ) - | Indexed (e', i_list) -> - let dl, sl, e' = inline_function_expression propto adt fim e' in - let d_list, s_list, i_list = - inline_list (inline_function_index propto adt fim) i_list in - (d_list @ dl, s_list @ sl, {e with pattern= Indexed (e', i_list)}) - | EAnd (e1, e2) -> - let dl1, sl1, e1 = inline_function_expression propto adt fim e1 in - let dl2, sl2, e2 = inline_function_expression propto adt fim e2 in - let sl2 = - [ Stmt.Fixed.( - Pattern.IfElse - ( e1 - , {pattern= Block (map_no_loc sl2); meta= Location_span.empty} - , None )) ] in - (dl1 @ dl2, sl1 @ sl2, {e with pattern= EAnd (e1, e2)}) - | EOr (e1, e2) -> - let dl1, sl1, e1 = inline_function_expression propto adt fim e1 in - let dl2, sl2, e2 = inline_function_expression propto adt fim e2 in - let sl2 = - [ Stmt.Fixed.( - Pattern.IfElse - ( e1 - , {pattern= Skip; meta= Location_span.empty} - , Some {pattern= Block (map_no_loc sl2); meta= Location_span.empty} - )) ] in - (dl1 @ dl2, sl1 @ sl2, {e with pattern= EOr (e1, e2)}) + , {pattern= Block (map_no_loc sl2); meta= Location_span.empty} + , None )) ] in + (dl1 @ dl2, sl1 @ sl2, {e with pattern= EAnd (e1, e2)}) + | EOr (e1, e2) -> + let dl1, sl1, e1 = inline_function_expression propto adt fim e1 in + let dl2, sl2, e2 = inline_function_expression propto adt fim e2 in + let sl2 = + [ Stmt.Fixed.( + Pattern.IfElse + ( e1 + , {pattern= Skip; meta= Location_span.empty} + , Some + {pattern= Block (map_no_loc sl2); meta= Location_span.empty} + )) ] in + (dl1 @ dl2, sl1 @ sl2, {e with pattern= EOr (e1, e2)}) -and inline_function_index propto adt fim i = - match i with - | All -> ([], [], All) - | Single e -> - let dl, sl, e = inline_function_expression propto adt fim e in - (dl, sl, Single e) - | Upfrom e -> - let dl, sl, e = inline_function_expression propto adt fim e in - (dl, sl, Upfrom e) - | Between (e1, e2) -> - let dl1, sl1, e1 = inline_function_expression propto adt fim e1 in - let dl2, sl2, e2 = inline_function_expression propto adt fim e2 in - (dl1 @ dl2, sl1 @ sl2, Between (e1, e2)) - | MultiIndex e -> - let dl, sl, e = inline_function_expression propto adt fim e in - (dl, sl, MultiIndex e) + and inline_function_index propto adt fim i = + match i with + | All -> ([], [], All) + | Single e -> + let dl, sl, e = inline_function_expression propto adt fim e in + (dl, sl, Single e) + | Upfrom e -> + let dl, sl, e = inline_function_expression propto adt fim e in + (dl, sl, Upfrom e) + | Between (e1, e2) -> + let dl1, sl1, e1 = inline_function_expression propto adt fim e1 in + let dl2, sl2, e2 = inline_function_expression propto adt fim e2 in + (dl1 @ dl2, sl1 @ sl2, Between (e1, e2)) + | MultiIndex e -> + let dl, sl, e = inline_function_expression propto adt fim e in + (dl, sl, MultiIndex e) -let rec inline_function_statement propto adt fim Stmt.Fixed.{pattern; meta} = - Stmt.Fixed. - { pattern= - ( match pattern with - | Assignment ((assignee, ut, idx_lst), rhs) -> - let dl1, sl1, new_idx_lst = - inline_list (inline_function_index propto adt fim) idx_lst in - let dl2, sl2, new_rhs = - inline_function_expression propto adt fim rhs in - slist_concat_no_loc - (dl2 @ dl1 @ sl2 @ sl1) - (Assignment ((assignee, ut, new_idx_lst), new_rhs)) - | TargetPE e -> - let d, s, e = inline_function_expression propto adt fim e in - slist_concat_no_loc (d @ s) (TargetPE e) - | NRFunApp (kind, exprs) -> - let d_list, s_list, es = - inline_list (inline_function_expression propto adt fim) exprs - in - slist_concat_no_loc (d_list @ s_list) - ( match kind with - | CompilerInternal _ -> NRFunApp (kind, es) - | UserDefined (s, _) | StanLib (s, _, _) -> ( - match Map.find fim s with - | None -> NRFunApp (kind, es) - | Some (_, args, b) -> - let b = replace_fresh_local_vars s b in - let b = handle_early_returns s None b in - (subst_args_stmt args es - {pattern= b; meta= Location_span.empty} ) - .pattern ) ) - | Return e -> ( - match e with - | None -> Return None - | Some expr -> + let rec inline_function_statement propto adt fim Stmt.Fixed.{pattern; meta} = + Stmt.Fixed. + { pattern= + ( match pattern with + | Assignment ((assignee, ut, idx_lst), rhs) -> + let dl1, sl1, new_idx_lst = + inline_list (inline_function_index propto adt fim) idx_lst in + let dl2, sl2, new_rhs = + inline_function_expression propto adt fim rhs in + slist_concat_no_loc + (dl2 @ dl1 @ sl2 @ sl1) + (Assignment ((assignee, ut, new_idx_lst), new_rhs)) + | TargetPE e -> + let d, s, e = inline_function_expression propto adt fim e in + slist_concat_no_loc (d @ s) (TargetPE e) + | NRFunApp (kind, exprs) -> + let d_list, s_list, es = + inline_list (inline_function_expression propto adt fim) exprs + in + slist_concat_no_loc (d_list @ s_list) + ( match kind with + | CompilerInternal _ -> NRFunApp (kind, es) + | UserDefined (s, _) | StanLib (s, _, _) -> ( + match Map.find fim s with + | None -> NRFunApp (kind, es) + | Some (_, args, b) -> + let b = replace_fresh_local_vars s b in + let b = handle_early_returns s None b in + (subst_args_stmt args es + {pattern= b; meta= Location_span.empty} ) + .pattern ) ) + | Return e -> ( + match e with + | None -> Return None + | Some expr -> + let d, s, e = inline_function_expression propto adt fim expr in + slist_concat_no_loc (d @ s) (Return (Some e)) ) + | IfElse (expr, s1, s2) -> let d, s, e = inline_function_expression propto adt fim expr in - slist_concat_no_loc (d @ s) (Return (Some e)) ) - | IfElse (expr, s1, s2) -> - let d, s, e = inline_function_expression propto adt fim expr in - slist_concat_no_loc (d @ s) - (IfElse - ( e - , inline_function_statement propto adt fim s1 - , Option.map ~f:(inline_function_statement propto adt fim) s2 - ) ) - | While (expr, stmt) -> - let d', s', e = inline_function_expression propto adt fim expr in - slist_concat_no_loc (d' @ s') - (While - ( e - , match s' with - | [] -> inline_function_statement propto adt fim stmt - | _ -> - { pattern= - Block - ( [inline_function_statement propto adt fim stmt] - @ map_no_loc s' ) - ; meta= Location_span.empty } ) ) - | For {loopvar; lower; upper; body} -> - let d_lower, s_lower, lower = - inline_function_expression propto adt fim lower in - let d_upper, s_upper, upper = - inline_function_expression propto adt fim upper in - slist_concat_no_loc - (d_lower @ d_upper @ s_lower @ s_upper) - (For - { loopvar - ; lower - ; upper - ; body= - ( match s_upper with - | [] -> inline_function_statement propto adt fim body + slist_concat_no_loc (d @ s) + (IfElse + ( e + , inline_function_statement propto adt fim s1 + , Option.map ~f:(inline_function_statement propto adt fim) s2 + ) ) + | While (expr, stmt) -> + let d', s', e = inline_function_expression propto adt fim expr in + slist_concat_no_loc (d' @ s') + (While + ( e + , match s' with + | [] -> inline_function_statement propto adt fim stmt | _ -> { pattern= Block - ( [inline_function_statement propto adt fim body] - @ map_no_loc s_upper ) - ; meta= Location_span.empty } ) } ) - | Profile (name, l) -> - Profile - (name, List.map l ~f:(inline_function_statement propto adt fim)) - | Block l -> - Block (List.map l ~f:(inline_function_statement propto adt fim)) - | SList l -> - SList (List.map l ~f:(inline_function_statement propto adt fim)) - | Decl r -> Decl r - | Skip -> Skip - | Break -> Break - | Continue -> Continue ) - ; meta } + ( [inline_function_statement propto adt fim stmt] + @ map_no_loc s' ) + ; meta= Location_span.empty } ) ) + | For {loopvar; lower; upper; body} -> + let d_lower, s_lower, lower = + inline_function_expression propto adt fim lower in + let d_upper, s_upper, upper = + inline_function_expression propto adt fim upper in + slist_concat_no_loc + (d_lower @ d_upper @ s_lower @ s_upper) + (For + { loopvar + ; lower + ; upper + ; body= + ( match s_upper with + | [] -> inline_function_statement propto adt fim body + | _ -> + { pattern= + Block + ( [ inline_function_statement propto adt fim + body ] + @ map_no_loc s_upper ) + ; meta= Location_span.empty } ) } ) + | Profile (name, l) -> + Profile + (name, List.map l ~f:(inline_function_statement propto adt fim)) + | Block l -> + Block (List.map l ~f:(inline_function_statement propto adt fim)) + | SList l -> + SList (List.map l ~f:(inline_function_statement propto adt fim)) + | Decl r -> Decl r + | Skip -> Skip + | Break -> Break + | Continue -> Continue ) + ; meta } -let create_function_inline_map adt l = - (* We only add the first definition for each function to the inline map. - This will make sure we do not inline recursive functions. - We also don't want to add any function declaration (as opposed to - definitions), because that would replace the function call with a Skip. - *) - let f (accum, visited) Program.{fdname; fdargs; fdbody; fdrt; _} = - (* If we see a function more than once, - remove it to prevent inlining of overloaded functions + let create_function_inline_map adt l = + (* We only add the first definition for each function to the inline map. + This will make sure we do not inline recursive functions. + We also don't want to add any function declaration (as opposed to + definitions), because that would replace the function call with a Skip. *) - if Set.mem visited fdname then (Map.remove accum fdname, visited) - else - let accum' = - match fdbody with - | None -> accum - | Some fdbody -> ( - let create_data propto = - ( Option.map ~f:(fun x -> Type.Unsized x) fdrt - , List.map ~f:(fun (_, name, _) -> name) fdargs - , inline_function_statement propto adt accum fdbody ) in - match Middle.Utils.with_unnormalized_suffix fdname with - | None -> ( - let data = create_data true in - match Map.add accum ~key:fdname ~data with - | `Ok m -> m - | `Duplicate -> accum ) - | Some fdname' -> - let data = create_data false in - let data' = create_data true in - let m = - Map.Poly.of_alist_exn [(fdname, data); (fdname', data')] in - Map.merge_skewed accum m ~combine:(fun ~key:_ f _ -> f) ) in - let visited' = Set.add visited fdname in - (accum', visited') in - let accum, _ = List.fold l ~init:(Map.Poly.empty, Set.Poly.empty) ~f in - accum + let f (accum, visited) Program.{fdname; fdargs; fdbody; fdrt; _} = + (* If we see a function more than once, + remove it to prevent inlining of overloaded functions + *) + if Set.mem visited fdname then (Map.remove accum fdname, visited) + else + let accum' = + match fdbody with + | None -> accum + | Some fdbody -> ( + let create_data propto = + ( Option.map ~f:(fun x -> Type.Unsized x) fdrt + , List.map ~f:(fun (_, name, _) -> name) fdargs + , inline_function_statement propto adt accum fdbody ) in + match Middle.Utils.with_unnormalized_suffix fdname with + | None -> ( + let data = create_data true in + match Map.add accum ~key:fdname ~data with + | `Ok m -> m + | `Duplicate -> accum ) + | Some fdname' -> + let data = create_data false in + let data' = create_data true in + let m = + Map.Poly.of_alist_exn [(fdname, data); (fdname', data')] + in + Map.merge_skewed accum m ~combine:(fun ~key:_ f _ -> f) ) + in + let visited' = Set.add visited fdname in + (accum', visited') in + let accum, _ = List.fold l ~init:(Map.Poly.empty, Set.Poly.empty) ~f in + accum -let function_inlining (mir : Program.Typed.t) = - let dataonly_inline_map = - create_function_inline_map UnsizedType.DataOnly mir.functions_block in - let autodiff_inline_map = - create_function_inline_map UnsizedType.AutoDiffable mir.functions_block - in - let dataonly_inline_function_statements = - List.map - ~f: - (inline_function_statement true UnsizedType.DataOnly dataonly_inline_map) - in - let autodiffable_inline_function_statements = - List.map - ~f: - (inline_function_statement true UnsizedType.AutoDiffable - autodiff_inline_map ) in - { mir with - prepare_data= dataonly_inline_function_statements mir.prepare_data - ; transform_inits= autodiffable_inline_function_statements mir.transform_inits - ; log_prob= autodiffable_inline_function_statements mir.log_prob - ; generate_quantities= - dataonly_inline_function_statements mir.generate_quantities } + let function_inlining (mir : Program.Typed.t) = + let dataonly_inline_map = + create_function_inline_map UnsizedType.DataOnly mir.functions_block in + let autodiff_inline_map = + create_function_inline_map UnsizedType.AutoDiffable mir.functions_block + in + let dataonly_inline_function_statements = + List.map + ~f: + (inline_function_statement true UnsizedType.DataOnly + dataonly_inline_map ) in + let autodiffable_inline_function_statements = + List.map + ~f: + (inline_function_statement true UnsizedType.AutoDiffable + autodiff_inline_map ) in + { mir with + prepare_data= dataonly_inline_function_statements mir.prepare_data + ; transform_inits= + autodiffable_inline_function_statements mir.transform_inits + ; log_prob= autodiffable_inline_function_statements mir.log_prob + ; generate_quantities= + dataonly_inline_function_statements mir.generate_quantities } -let rec contains_top_break_or_continue Stmt.Fixed.{pattern; _} = - match pattern with - | Break | Continue -> true - | Assignment (_, _) - |TargetPE _ - |NRFunApp (_, _) - |Return _ | Decl _ - |While (_, _) - |For _ | Skip -> - false - | Profile (_, l) | Block l | SList l -> - List.exists l ~f:contains_top_break_or_continue - | IfElse (_, b1, b2) -> ( - contains_top_break_or_continue b1 - || - match b2 with None -> false | Some b -> contains_top_break_or_continue b ) + let rec contains_top_break_or_continue Stmt.Fixed.{pattern; _} = + match pattern with + | Break | Continue -> true + | Assignment (_, _) + |TargetPE _ + |NRFunApp (_, _) + |Return _ | Decl _ + |While (_, _) + |For _ | Skip -> + false + | Profile (_, l) | Block l | SList l -> + List.exists l ~f:contains_top_break_or_continue + | IfElse (_, b1, b2) -> ( + contains_top_break_or_continue b1 + || + match b2 with + | None -> false + | Some b -> contains_top_break_or_continue b ) -let unroll_static_limit = 32 + let unroll_static_limit = 32 -let unroll_static_loops_statement _ = - let f stmt = - match stmt with - | Stmt.Fixed.Pattern.For {loopvar; lower; upper; body} -> ( - let lower = Partial_evaluator.try_eval_expr lower in - let upper = Partial_evaluator.try_eval_expr upper in - match - (contains_top_break_or_continue body, lower.pattern, upper.pattern) - with - | false, Lit (Int, low_str), Lit (Int, up_str) -> - let low = Int.of_string low_str in - let up = Int.of_string up_str in - if up - low > unroll_static_limit then stmt - else - let range = - List.map - ~f:(fun i -> - Expr.Fixed. - { pattern= Lit (Int, Int.to_string i) - ; meta= - Expr.Typed.Meta. - { type_= UInt - ; loc= Location_span.empty - ; adlevel= DataOnly } } ) - (List.range ~start:`inclusive ~stop:`inclusive low up) in - let stmts = - List.map - ~f:(fun i -> - subst_args_stmt [loopvar] [i] - {pattern= body.pattern; meta= Location_span.empty} ) - range in - Stmt.Fixed.Pattern.SList stmts - | _ -> stmt ) - | _ -> stmt in - top_down_map_rec_stmt_loc f + let unroll_static_loops_statement _ = + let f stmt = + match stmt with + | Stmt.Fixed.Pattern.For {loopvar; lower; upper; body} -> ( + let lower = Partial_evaluator.try_eval_expr lower in + let upper = Partial_evaluator.try_eval_expr upper in + match + (contains_top_break_or_continue body, lower.pattern, upper.pattern) + with + | false, Lit (Int, low_str), Lit (Int, up_str) -> + let low = Int.of_string low_str in + let up = Int.of_string up_str in + if up - low > unroll_static_limit then stmt + else + let range = + List.map + ~f:(fun i -> + Expr.Fixed. + { pattern= Lit (Int, Int.to_string i) + ; meta= + Expr.Typed.Meta. + { type_= UInt + ; loc= Location_span.empty + ; adlevel= DataOnly } } ) + (List.range ~start:`inclusive ~stop:`inclusive low up) in + let stmts = + List.map + ~f:(fun i -> + subst_args_stmt [loopvar] [i] + {pattern= body.pattern; meta= Location_span.empty} ) + range in + Stmt.Fixed.Pattern.SList stmts + | _ -> stmt ) + | _ -> stmt in + top_down_map_rec_stmt_loc f -let static_loop_unrolling mir = - transform_program_blockwise mir unroll_static_loops_statement + let static_loop_unrolling mir = + transform_program_blockwise mir unroll_static_loops_statement -let unroll_loop_one_step_statement _ = - let f stmt = - match stmt with - | Stmt.Fixed.Pattern.For {loopvar; lower; upper; body} -> - if contains_top_break_or_continue body then stmt - else - IfElse - ( Expr.Fixed. - { lower with - pattern= - FunApp (StanLib ("Geq__", FnPlain, AoS), [upper; lower]) } - , { pattern= - (let body_unrolled = - subst_args_stmt [loopvar] [lower] - {pattern= body.pattern; meta= Location_span.empty} in - let (body' : Stmt.Located.t) = - { pattern= - Stmt.Fixed.Pattern.For - { loopvar - ; upper - ; body - ; lower= - { lower with - pattern= - FunApp - ( StanLib ("Plus__", FnPlain, AoS) - , [lower; Expr.Helpers.loop_bottom] ) } } - ; meta= Location_span.empty } in - match body_unrolled.pattern with - | Block stmts -> Block (stmts @ [body']) - | _ -> Stmt.Fixed.Pattern.Block [body_unrolled; body'] ) - ; meta= Location_span.empty } - , None ) - | While (e, body) -> - if contains_top_break_or_continue body then stmt - else - IfElse - ( e - , { pattern= Block [body; {body with pattern= While (e, body)}] - ; meta= Location_span.empty } - , None ) - | _ -> stmt in - map_rec_stmt_loc f + let unroll_loop_one_step_statement _ = + let f stmt = + match stmt with + | Stmt.Fixed.Pattern.For {loopvar; lower; upper; body} -> + if contains_top_break_or_continue body then stmt + else + IfElse + ( Expr.Fixed. + { lower with + pattern= + FunApp (StanLib ("Geq__", FnPlain, AoS), [upper; lower]) + } + , { pattern= + (let body_unrolled = + subst_args_stmt [loopvar] [lower] + {pattern= body.pattern; meta= Location_span.empty} + in + let (body' : Stmt.Located.t) = + { pattern= + Stmt.Fixed.Pattern.For + { loopvar + ; upper + ; body + ; lower= + { lower with + pattern= + FunApp + ( StanLib ("Plus__", FnPlain, AoS) + , [lower; Expr.Helpers.loop_bottom] ) } + } + ; meta= Location_span.empty } in + match body_unrolled.pattern with + | Block stmts -> Block (stmts @ [body']) + | _ -> Stmt.Fixed.Pattern.Block [body_unrolled; body'] ) + ; meta= Location_span.empty } + , None ) + | While (e, body) -> + if contains_top_break_or_continue body then stmt + else + IfElse + ( e + , { pattern= Block [body; {body with pattern= While (e, body)}] + ; meta= Location_span.empty } + , None ) + | _ -> stmt in + map_rec_stmt_loc f -let one_step_loop_unrolling mir = - transform_program_blockwise mir unroll_loop_one_step_statement + let one_step_loop_unrolling mir = + transform_program_blockwise mir unroll_loop_one_step_statement -let collapse_lists_statement _ = - let rec collapse_lists l = - match l with - | [] -> [] - | Stmt.Fixed.{pattern= SList l'; _} :: rest -> l' @ collapse_lists rest - | x :: rest -> x :: collapse_lists rest in - let f = function - | Stmt.Fixed.Pattern.Block l -> Stmt.Fixed.Pattern.Block (collapse_lists l) - | SList l -> SList (collapse_lists l) - | x -> x in - map_rec_stmt_loc f + let collapse_lists_statement _ = + let rec collapse_lists l = + match l with + | [] -> [] + | Stmt.Fixed.{pattern= SList l'; _} :: rest -> l' @ collapse_lists rest + | x :: rest -> x :: collapse_lists rest in + let f = function + | Stmt.Fixed.Pattern.Block l -> + Stmt.Fixed.Pattern.Block (collapse_lists l) + | SList l -> SList (collapse_lists l) + | x -> x in + map_rec_stmt_loc f -let list_collapsing (mir : Program.Typed.t) = - transform_program_blockwise mir collapse_lists_statement + let list_collapsing (mir : Program.Typed.t) = + transform_program_blockwise mir collapse_lists_statement -let propagation - (propagation_transfer : - (int, Stmt.Located.Non_recursive.t) Map.Poly.t - -> (module Monotone_framework_sigs.TRANSFER_FUNCTION - with type labels = int - and type properties = (string, Middle.Expr.Typed.t) Map.Poly.t - option ) ) (mir : Program.Typed.t) = - let transform stmt = - let flowgraph, flowgraph_to_mir = - Monotone_framework.forward_flowgraph_of_stmt stmt in - let (module Flowgraph) = flowgraph in - let values = - Monotone_framework.propagation_mfp mir - (module Flowgraph) - flowgraph_to_mir propagation_transfer in - let propagate_stmt = - map_rec_stmt_loc_num flowgraph_to_mir (fun i -> - subst_stmt_base - (Option.value ~default:Map.Poly.empty (Map.find_exn values i).entry) ) - in - propagate_stmt (Map.find_exn flowgraph_to_mir 1) in - transform_program mir transform + let propagation + (propagation_transfer : + (int, Stmt.Located.Non_recursive.t) Map.Poly.t + -> (module Monotone_framework_sigs.TRANSFER_FUNCTION + with type labels = int + and type properties = (string, Middle.Expr.Typed.t) Map.Poly.t + option ) ) (mir : Program.Typed.t) = + let transform stmt = + let flowgraph, flowgraph_to_mir = + Monotone_framework.forward_flowgraph_of_stmt stmt in + let (module Flowgraph) = flowgraph in + let values = + Monotone_framework.propagation_mfp mir + (module Flowgraph) + flowgraph_to_mir propagation_transfer in + let propagate_stmt = + map_rec_stmt_loc_num flowgraph_to_mir (fun i -> + subst_stmt_base + (Option.value ~default:Map.Poly.empty + (Map.find_exn values i).entry ) ) in + propagate_stmt (Map.find_exn flowgraph_to_mir 1) in + transform_program mir transform -let constant_propagation ?(preserve_stability = false) = - propagation - (Monotone_framework.constant_propagation_transfer ~preserve_stability) + let constant_propagation ?(preserve_stability = false) = + propagation + (Monotone_framework.constant_propagation_transfer + (module Partial_evaluator) + ~preserve_stability ) -let rec expr_any pred (e : Expr.Typed.t) = - match e.pattern with - | Indexed (e, is) -> expr_any pred e || List.exists ~f:(idx_any pred) is - | _ -> pred e || Expr.Fixed.Pattern.fold (accum_any pred) false e.pattern + let rec expr_any pred (e : Expr.Typed.t) = + match e.pattern with + | Indexed (e, is) -> expr_any pred e || List.exists ~f:(idx_any pred) is + | _ -> pred e || Expr.Fixed.Pattern.fold (accum_any pred) false e.pattern -and idx_any pred (i : Expr.Typed.t Index.t) = - Index.fold (accum_any pred) false i + and idx_any pred (i : Expr.Typed.t Index.t) = + Index.fold (accum_any pred) false i -and accum_any pred b e = b || expr_any pred e + and accum_any pred b e = b || expr_any pred e -let can_side_effect_top_expr (e : Expr.Typed.t) = - match e.pattern with - | FunApp ((UserDefined (_, FnTarget) | StanLib (_, FnTarget, _)), _) -> true - | FunApp (CompilerInternal internal_fn, _) -> - Internal_fun.can_side_effect internal_fn - | _ -> false + let can_side_effect_top_expr (e : Expr.Typed.t) = + match e.pattern with + | FunApp ((UserDefined (_, FnTarget) | StanLib (_, FnTarget, _)), _) -> true + | FunApp (CompilerInternal internal_fn, _) -> + Internal_fun.can_side_effect internal_fn + | _ -> false -let cannot_duplicate_expr ?(preserve_stability = false) (e : Expr.Typed.t) = - let pred e = - can_side_effect_top_expr e - || ( match e.pattern with - | FunApp ((UserDefined (_, FnRng) | StanLib (_, FnRng, _)), _) -> true - | _ -> false ) - || (preserve_stability && UnsizedType.is_autodiffable e.meta.type_) in - expr_any pred e + let cannot_duplicate_expr ?(preserve_stability = false) (e : Expr.Typed.t) = + let pred e = + can_side_effect_top_expr e + || ( match e.pattern with + | FunApp ((UserDefined (_, FnRng) | StanLib (_, FnRng, _)), _) -> true + | _ -> false ) + || (preserve_stability && UnsizedType.is_autodiffable e.meta.type_) in + expr_any pred e -let cannot_remove_expr (e : Expr.Typed.t) = expr_any can_side_effect_top_expr e + let cannot_remove_expr (e : Expr.Typed.t) = + expr_any can_side_effect_top_expr e -let expression_propagation ?(preserve_stability = false) mir = - propagation - (Monotone_framework.expression_propagation_transfer ~preserve_stability - (cannot_duplicate_expr ~preserve_stability) ) - mir + let expression_propagation ?(preserve_stability = false) mir = + propagation + (Monotone_framework.expression_propagation_transfer ~preserve_stability + (cannot_duplicate_expr ~preserve_stability) ) + mir -let copy_propagation mir = - let globals = Monotone_framework.globals mir in - propagation (Monotone_framework.copy_propagation_transfer globals) mir + let copy_propagation mir = + let globals = Monotone_framework.globals mir in + propagation (Monotone_framework.copy_propagation_transfer globals) mir -let is_skip_break_continue s = - match s with Stmt.Fixed.Pattern.Skip | Break | Continue -> true | _ -> false + let is_skip_break_continue s = + match s with + | Stmt.Fixed.Pattern.Skip | Break | Continue -> true + | _ -> false -(* TODO: could also implement partial dead code elimination *) -let dead_code_elimination (mir : Program.Typed.t) = - (* TODO: think about whether we should treat function bodies as local scopes in the statement - from the POV of a live variables analysis. - (Obviously, this shouldn't be the case for the purposes of reaching definitions, - constant propagation, expressions analyses. But I do think that's the right way to - go about live variables. *) - let transform s = - let rev_flowgraph, flowgraph_to_mir = - Monotone_framework.inverse_flowgraph_of_stmt s in - let (module Rev_Flowgraph) = rev_flowgraph in - let live_variables = - Monotone_framework.live_variables_mfp mir - (module Rev_Flowgraph) - flowgraph_to_mir in - let dead_code_elim_stmt_base i stmt = - (* NOTE: entry in the reverse flowgraph, so exit in the forward flowgraph *) - let live_variables_s = - (Map.find_exn live_variables i).Monotone_framework_sigs.entry in - match stmt with - | Stmt.Fixed.Pattern.Assignment ((x, _, []), rhs) -> - if Set.Poly.mem live_variables_s x || cannot_remove_expr rhs then stmt - else Skip - | Assignment ((x, _, is), rhs) -> - if - Set.Poly.mem live_variables_s x - || cannot_remove_expr rhs - || List.exists ~f:(idx_any cannot_remove_expr) is - then stmt - else Skip - (* NOTE: we never get rid of declarations as we might not be able to - remove an assignment to a variable - due to side effects. *) - (* TODO: maybe we should revisit that. *) - | Decl _ | TargetPE _ - |NRFunApp (_, _) - |Break | Continue | Return _ | Skip -> - stmt - | IfElse (e, b1, b2) -> ( - if - (* TODO: check if e has side effects, like print, reject, then don't optimize? *) - (not (cannot_remove_expr e)) - && b1.Stmt.Fixed.pattern = Skip - && ( Option.map ~f:(fun Stmt.Fixed.{pattern; _} -> pattern) b2 - = Some Skip - || Option.map ~f:(fun Stmt.Fixed.{pattern; _} -> pattern) b2 - = None ) - then Skip - else - match e.pattern with - | Lit (Int, "0") | Lit (Real, "0.0") -> ( - match b2 with Some x -> x.pattern | None -> Skip ) - | Lit (_, _) -> b1.pattern - | _ -> IfElse (e, b1, b2) ) - | While (e, b) -> ( - if (not (cannot_remove_expr e)) && b.pattern = Break then Skip - else - match e.pattern with - | Lit (Int, "0") | Lit (Real, "0.0") -> Skip - | _ -> While (e, b) ) - | For {loopvar; lower; upper; body} -> - if - (not (cannot_remove_expr lower)) - && (not (cannot_remove_expr upper)) - && is_skip_break_continue body.pattern - then Skip - else For {loopvar; lower; upper; body} - | Profile (name, l) -> - let l' = List.filter ~f:(fun x -> x.pattern <> Skip) l in - if List.length l' = 0 then Skip else Profile (name, l') - | Block l -> - let l' = List.filter ~f:(fun x -> x.pattern <> Skip) l in - if List.length l' = 0 then Skip else Block l' - | SList l -> - let l' = List.filter ~f:(fun x -> x.pattern <> Skip) l in - SList l' in - let dead_code_elim_stmt = - map_rec_stmt_loc_num flowgraph_to_mir dead_code_elim_stmt_base in - dead_code_elim_stmt (Map.find_exn flowgraph_to_mir 1) in - transform_program mir transform + (* TODO: could also implement partial dead code elimination *) + let dead_code_elimination (mir : Program.Typed.t) = + (* TODO: think about whether we should treat function bodies as local scopes in the statement + from the POV of a live variables analysis. + (Obviously, this shouldn't be the case for the purposes of reaching definitions, + constant propagation, expressions analyses. But I do think that's the right way to + go about live variables. *) + let transform s = + let rev_flowgraph, flowgraph_to_mir = + Monotone_framework.inverse_flowgraph_of_stmt s in + let (module Rev_Flowgraph) = rev_flowgraph in + let live_variables = + Monotone_framework.live_variables_mfp mir + (module Rev_Flowgraph) + flowgraph_to_mir in + let dead_code_elim_stmt_base i stmt = + (* NOTE: entry in the reverse flowgraph, so exit in the forward flowgraph *) + let live_variables_s = + (Map.find_exn live_variables i).Monotone_framework_sigs.entry in + match stmt with + | Stmt.Fixed.Pattern.Assignment ((x, _, []), rhs) -> + if Set.Poly.mem live_variables_s x || cannot_remove_expr rhs then + stmt + else Skip + | Assignment ((x, _, is), rhs) -> + if + Set.Poly.mem live_variables_s x + || cannot_remove_expr rhs + || List.exists ~f:(idx_any cannot_remove_expr) is + then stmt + else Skip + (* NOTE: we never get rid of declarations as we might not be able to + remove an assignment to a variable + due to side effects. *) + (* TODO: maybe we should revisit that. *) + | Decl _ | TargetPE _ + |NRFunApp (_, _) + |Break | Continue | Return _ | Skip -> + stmt + | IfElse (e, b1, b2) -> ( + if + (* TODO: check if e has side effects, like print, reject, then don't optimize? *) + (not (cannot_remove_expr e)) + && b1.Stmt.Fixed.pattern = Skip + && ( Option.map ~f:(fun Stmt.Fixed.{pattern; _} -> pattern) b2 + = Some Skip + || Option.map ~f:(fun Stmt.Fixed.{pattern; _} -> pattern) b2 + = None ) + then Skip + else + match e.pattern with + | Lit (Int, "0") | Lit (Real, "0.0") -> ( + match b2 with Some x -> x.pattern | None -> Skip ) + | Lit (_, _) -> b1.pattern + | _ -> IfElse (e, b1, b2) ) + | While (e, b) -> ( + if (not (cannot_remove_expr e)) && b.pattern = Break then Skip + else + match e.pattern with + | Lit (Int, "0") | Lit (Real, "0.0") -> Skip + | _ -> While (e, b) ) + | For {loopvar; lower; upper; body} -> + if + (not (cannot_remove_expr lower)) + && (not (cannot_remove_expr upper)) + && is_skip_break_continue body.pattern + then Skip + else For {loopvar; lower; upper; body} + | Profile (name, l) -> + let l' = List.filter ~f:(fun x -> x.pattern <> Skip) l in + if List.length l' = 0 then Skip else Profile (name, l') + | Block l -> + let l' = List.filter ~f:(fun x -> x.pattern <> Skip) l in + if List.length l' = 0 then Skip else Block l' + | SList l -> + let l' = List.filter ~f:(fun x -> x.pattern <> Skip) l in + SList l' in + let dead_code_elim_stmt = + map_rec_stmt_loc_num flowgraph_to_mir dead_code_elim_stmt_base in + dead_code_elim_stmt (Map.find_exn flowgraph_to_mir 1) in + transform_program mir transform -let partial_evaluation = Partial_evaluator.eval_prog + let partial_evaluation = Partial_evaluator.eval_prog -(** + (** * Given a name and Stmt, search the statement for the first assignment * where that name is the assignee. *) -let rec find_assignment_idx (name : string) Stmt.Fixed.{pattern; _} = - match pattern with - | Stmt.Fixed.Pattern.Assignment - ((assign_name, lhs_ut, idx_lst), (rhs : 'a Expr.Fixed.t)) - when name = assign_name - && (not (Set.Poly.mem (expr_var_names_set rhs) assign_name)) - && not - ( rhs.meta.adlevel = UnsizedType.DataOnly - && UnsizedType.is_array lhs_ut ) -> - Some idx_lst - | _ -> None + let rec find_assignment_idx (name : string) Stmt.Fixed.{pattern; _} = + match pattern with + | Stmt.Fixed.Pattern.Assignment + ((assign_name, lhs_ut, idx_lst), (rhs : 'a Expr.Fixed.t)) + when name = assign_name + && (not (Set.Poly.mem (expr_var_names_set rhs) assign_name)) + && not + ( rhs.meta.adlevel = UnsizedType.DataOnly + && UnsizedType.is_array lhs_ut ) -> + Some idx_lst + | _ -> None -(** + (** * Given a list of Stmts, find Decls whose objects are fully assigned to * in their first assignment and mark them as not needing to be * initialized. *) -and unenforce_initialize (lst : Stmt.Located.t list) = - let rec unenforce_initialize_patt (Stmt.Fixed.{pattern; _} as stmt) sub_lst = - match pattern with - | Stmt.Fixed.Pattern.Decl ({decl_id; _} as decl_pat) -> ( - match List.hd sub_lst with - | Some next_stmt -> ( - match find_assignment_idx decl_id next_stmt with - | Some [] | Some [Index.All] | Some [Index.All; Index.All] -> - { stmt with - pattern= Stmt.Fixed.Pattern.Decl {decl_pat with initialize= false} - } - | None | Some _ -> stmt ) - | None -> stmt ) - | Block block_lst -> - {stmt with pattern= Block (unenforce_initialize block_lst)} - | SList s_lst -> {stmt with pattern= SList (unenforce_initialize s_lst)} - (*[] here because we do not want to check out of scope*) - | While (expr, stmt) -> - {stmt with pattern= While (expr, unenforce_initialize_patt stmt [])} - | For ({body; _} as pat) -> - { stmt with - pattern= For {pat with body= unenforce_initialize_patt body []} } - | Profile ((pname : string), stmts) -> - {stmt with pattern= Profile (pname, unenforce_initialize stmts)} - | IfElse ((expr : 'a Expr.Fixed.t), true_stmt, op_false_stmt) -> - let mod_false_stmt = - Option.map ~f:(fun x -> unenforce_initialize_patt x []) op_false_stmt - in - { stmt with - pattern= - IfElse (expr, unenforce_initialize_patt true_stmt [], mod_false_stmt) - } - | _ -> stmt in - match List.hd lst with - | Some stmt -> ( - match List.tl lst with - | Some sub_lst -> - List.cons - (unenforce_initialize_patt stmt sub_lst) - (unenforce_initialize sub_lst) - | None -> lst ) - | None -> lst + and unenforce_initialize (lst : Stmt.Located.t list) = + let rec unenforce_initialize_patt (Stmt.Fixed.{pattern; _} as stmt) sub_lst + = + match pattern with + | Stmt.Fixed.Pattern.Decl ({decl_id; _} as decl_pat) -> ( + match List.hd sub_lst with + | Some next_stmt -> ( + match find_assignment_idx decl_id next_stmt with + | Some [] | Some [Index.All] | Some [Index.All; Index.All] -> + { stmt with + pattern= + Stmt.Fixed.Pattern.Decl {decl_pat with initialize= false} } + | None | Some _ -> stmt ) + | None -> stmt ) + | Block block_lst -> + {stmt with pattern= Block (unenforce_initialize block_lst)} + | SList s_lst -> {stmt with pattern= SList (unenforce_initialize s_lst)} + (*[] here because we do not want to check out of scope*) + | While (expr, stmt) -> + {stmt with pattern= While (expr, unenforce_initialize_patt stmt [])} + | For ({body; _} as pat) -> + { stmt with + pattern= For {pat with body= unenforce_initialize_patt body []} } + | Profile ((pname : string), stmts) -> + {stmt with pattern= Profile (pname, unenforce_initialize stmts)} + | IfElse ((expr : 'a Expr.Fixed.t), true_stmt, op_false_stmt) -> + let mod_false_stmt = + Option.map + ~f:(fun x -> unenforce_initialize_patt x []) + op_false_stmt in + { stmt with + pattern= + IfElse + (expr, unenforce_initialize_patt true_stmt [], mod_false_stmt) + } + | _ -> stmt in + match List.hd lst with + | Some stmt -> ( + match List.tl lst with + | Some sub_lst -> + List.cons + (unenforce_initialize_patt stmt sub_lst) + (unenforce_initialize sub_lst) + | None -> lst ) + | None -> lst -(** + (** * Take the Mir and perform a transform that requires searching * across the list inside of each piece of the Mir. * @param mir The mir * @param transformer a function that takes in and returns a list of * Stmts. *) -let transform_mir_blocks (mir : (Expr.Typed.t, Stmt.Located.t) Program.t) - (transformer : Stmt.Located.t list -> Stmt.Located.t list) : - (Expr.Typed.t, Stmt.Located.t) Program.t = - let transformed_functions = - List.map mir.functions_block ~f:(fun fs -> - let new_body = - match fs.fdbody with - | Some (Stmt.Fixed.{pattern= SList lst; _} as stmt) -> - Some {stmt with pattern= SList (transformer lst)} - | Some (Stmt.Fixed.{pattern= Block lst; _} as stmt) -> - Some {stmt with pattern= Block (transformer lst)} - | alt -> alt in - {fs with fdbody= new_body} ) in - { Program.functions_block= transformed_functions - ; input_vars= mir.input_vars - ; prepare_data= transformer mir.prepare_data - ; log_prob= transformer mir.log_prob - ; generate_quantities= transformer mir.generate_quantities - ; transform_inits= transformer mir.transform_inits - ; output_vars= mir.output_vars - ; prog_name= mir.prog_name - ; prog_path= mir.prog_path } + let transform_mir_blocks (mir : (Expr.Typed.t, Stmt.Located.t) Program.t) + (transformer : Stmt.Located.t list -> Stmt.Located.t list) : + (Expr.Typed.t, Stmt.Located.t) Program.t = + let transformed_functions = + List.map mir.functions_block ~f:(fun fs -> + let new_body = + match fs.fdbody with + | Some (Stmt.Fixed.{pattern= SList lst; _} as stmt) -> + Some {stmt with pattern= SList (transformer lst)} + | Some (Stmt.Fixed.{pattern= Block lst; _} as stmt) -> + Some {stmt with pattern= Block (transformer lst)} + | alt -> alt in + {fs with fdbody= new_body} ) in + { Program.functions_block= transformed_functions + ; input_vars= mir.input_vars + ; prepare_data= transformer mir.prepare_data + ; log_prob= transformer mir.log_prob + ; generate_quantities= transformer mir.generate_quantities + ; transform_inits= transformer mir.transform_inits + ; output_vars= mir.output_vars + ; prog_name= mir.prog_name + ; prog_path= mir.prog_path } -let allow_uninitialized_decls mir = - transform_mir_blocks mir unenforce_initialize + let allow_uninitialized_decls mir = + transform_mir_blocks mir unenforce_initialize -let lazy_code_motion ?(preserve_stability = false) (mir : Program.Typed.t) = - (* TODO: clean up this code. It is not very pretty. *) - (* TODO: make lazy code motion operate on transformed parameters and models blocks - simultaneously *) - let preprocess_flowgraph = - let preprocess_flowgraph_base - (stmt : (Expr.Typed.t, Stmt.Located.t) Stmt.Fixed.Pattern.t) = - match stmt with - | IfElse (e, b1, Some b2) -> - Stmt.Fixed.( - Pattern.IfElse + let lazy_code_motion ?(preserve_stability = false) (mir : Program.Typed.t) = + (* TODO: clean up this code. It is not very pretty. *) + (* TODO: make lazy code motion operate on transformed parameters and models blocks + simultaneously *) + let preprocess_flowgraph = + let preprocess_flowgraph_base + (stmt : (Expr.Typed.t, Stmt.Located.t) Stmt.Fixed.Pattern.t) = + match stmt with + | IfElse (e, b1, Some b2) -> + Stmt.Fixed.( + Pattern.IfElse + ( e + , { pattern= + Block [b1; {pattern= Skip; meta= Location_span.empty}] + ; meta= Location_span.empty } + , Some + { pattern= + Block [b2; {pattern= Skip; meta= Location_span.empty}] + ; meta= Location_span.empty } )) + | IfElse (e, b, None) -> + IfElse ( e - , { pattern= Block [b1; {pattern= Skip; meta= Location_span.empty}] + , { pattern= Block [b; {pattern= Skip; meta= Location_span.empty}] ; meta= Location_span.empty } - , Some + , Some {pattern= Skip; meta= Location_span.empty} ) + | While (e, b) -> + While + ( e + , { pattern= Block [b; {pattern= Skip; meta= Location_span.empty}] + ; meta= Location_span.empty } ) + | For {loopvar; lower; upper; body= b} -> + For + { loopvar + ; lower + ; upper + ; body= { pattern= - Block [b2; {pattern= Skip; meta= Location_span.empty}] - ; meta= Location_span.empty } )) - | IfElse (e, b, None) -> - IfElse - ( e - , { pattern= Block [b; {pattern= Skip; meta= Location_span.empty}] - ; meta= Location_span.empty } - , Some {pattern= Skip; meta= Location_span.empty} ) - | While (e, b) -> - While - ( e - , { pattern= Block [b; {pattern= Skip; meta= Location_span.empty}] - ; meta= Location_span.empty } ) - | For {loopvar; lower; upper; body= b} -> - For - { loopvar - ; lower - ; upper - ; body= - { pattern= Block [b; {pattern= Skip; meta= Location_span.empty}] - ; meta= Location_span.empty } } - | _ -> stmt in - map_rec_stmt_loc preprocess_flowgraph_base in - let transform s = - let rev_flowgraph, flowgraph_to_mir = - Monotone_framework.inverse_flowgraph_of_stmt ~blocks_after_body:false s - in - let fwd_flowgraph = Monotone_framework.reverse rev_flowgraph in - let latest_expr, used_not_latest_expressions_mfp = - Monotone_framework.lazy_expressions_mfp fwd_flowgraph rev_flowgraph - flowgraph_to_mir in - let expression_map = - let rec collect_expressions accum (e : Expr.Typed.t) = - match e.pattern with - | Lit (_, _) -> accum - | Var _ -> accum - | _ when cannot_duplicate_expr ~preserve_stability e -> - (* Immovable expressions might have movable subexpressions *) - Expr.Fixed.Pattern.fold collect_expressions accum e.pattern - | _ -> Map.set accum ~key:e ~data:(Gensym.generate ~prefix:"lcm_" ()) + Block [b; {pattern= Skip; meta= Location_span.empty}] + ; meta= Location_span.empty } } + | _ -> stmt in + map_rec_stmt_loc preprocess_flowgraph_base in + let transform s = + let rev_flowgraph, flowgraph_to_mir = + Monotone_framework.inverse_flowgraph_of_stmt ~blocks_after_body:false s in - Set.fold - (Monotone_framework.used_expressions_stmt s.pattern) - ~init:Expr.Typed.Map.empty ~f:collect_expressions in - (* TODO: it'd be more efficient to just not accumulate constants in the static analysis *) - let declarations_list = - Map.fold expression_map ~init:[] ~f:(fun ~key ~data accum -> - Stmt.Fixed. - { pattern= - Pattern.Decl - { decl_adtype= Expr.Typed.adlevel_of key - ; decl_id= data - ; decl_type= Type.Unsized (Expr.Typed.type_of key) - ; initialize= true } - ; meta= Location_span.empty } - :: accum ) in - let lazy_code_motion_base i stmt = - let latest_and_used_after_i = - Set.inter - (Map.find_exn latest_expr i) - (Map.find_exn used_not_latest_expressions_mfp i).entry in - let to_assign_in_s = - latest_and_used_after_i - |> Set.filter ~f:(fun x -> Map.mem expression_map x) - |> Set.to_list - |> List.sort ~compare:(fun e e' -> - compare_int (expr_depth e) (expr_depth e') ) in - (* TODO: is this sort doing anything or are they already stored in the right order by - chance? It appears to not do anything. *) - let assignments_to_add_to_s = - List.map - ~f:(fun e -> + let fwd_flowgraph = Monotone_framework.reverse rev_flowgraph in + let latest_expr, used_not_latest_expressions_mfp = + Monotone_framework.lazy_expressions_mfp fwd_flowgraph rev_flowgraph + flowgraph_to_mir in + let expression_map = + let rec collect_expressions accum (e : Expr.Typed.t) = + match e.pattern with + | Lit (_, _) -> accum + | Var _ -> accum + | _ when cannot_duplicate_expr ~preserve_stability e -> + (* Immovable expressions might have movable subexpressions *) + Expr.Fixed.Pattern.fold collect_expressions accum e.pattern + | _ -> Map.set accum ~key:e ~data:(Gensym.generate ~prefix:"lcm_" ()) + in + Set.fold + (Monotone_framework.used_expressions_stmt s.pattern) + ~init:Expr.Typed.Map.empty ~f:collect_expressions in + (* TODO: it'd be more efficient to just not accumulate constants in the static analysis *) + let declarations_list = + Map.fold expression_map ~init:[] ~f:(fun ~key ~data accum -> Stmt.Fixed. { pattern= - Assignment - ((Map.find_exn expression_map e, e.meta.type_, []), e) - ; meta= Location_span.empty } ) - to_assign_in_s in - let expr_subst_stmt_except_initial_assign m = - let f stmt = - match stmt with - | Stmt.Fixed.Pattern.Assignment ((x, _, []), e') - when Map.mem m e' - && Expr.Typed.equal {e' with pattern= Var x} - (Map.find_exn m e') -> - expr_subst_stmt_base (Map.remove m e') stmt - | _ -> expr_subst_stmt_base m stmt in - map_rec_stmt_loc f in - let expr_map = - Map.filter_keys - ~f:(fun key -> - Set.mem latest_and_used_after_i key - || Set.mem (Map.find_exn used_not_latest_expressions_mfp i).exit key - ) - (Map.mapi expression_map ~f:(fun ~key ~data -> - {key with pattern= Var data} ) ) in - let f = expr_subst_stmt_except_initial_assign expr_map in - if List.length assignments_to_add_to_s = 0 then - (f Stmt.Fixed.{pattern= stmt; meta= Location_span.empty}).pattern - else - SList - (List.map ~f - ( assignments_to_add_to_s - @ [{pattern= stmt; meta= Location_span.empty}] ) ) in - let lazy_code_motion_stmt = - map_rec_stmt_loc_num flowgraph_to_mir lazy_code_motion_base in - Stmt.Fixed. - { pattern= + Pattern.Decl + { decl_adtype= Expr.Typed.adlevel_of key + ; decl_id= data + ; decl_type= Type.Unsized (Expr.Typed.type_of key) + ; initialize= true } + ; meta= Location_span.empty } + :: accum ) in + let lazy_code_motion_base i stmt = + let latest_and_used_after_i = + Set.inter + (Map.find_exn latest_expr i) + (Map.find_exn used_not_latest_expressions_mfp i).entry in + let to_assign_in_s = + latest_and_used_after_i + |> Set.filter ~f:(fun x -> Map.mem expression_map x) + |> Set.to_list + |> List.sort ~compare:(fun e e' -> + compare_int (expr_depth e) (expr_depth e') ) in + (* TODO: is this sort doing anything or are they already stored in the right order by + chance? It appears to not do anything. *) + let assignments_to_add_to_s = + List.map + ~f:(fun e -> + Stmt.Fixed. + { pattern= + Assignment + ((Map.find_exn expression_map e, e.meta.type_, []), e) + ; meta= Location_span.empty } ) + to_assign_in_s in + let expr_subst_stmt_except_initial_assign m = + let f stmt = + match stmt with + | Stmt.Fixed.Pattern.Assignment ((x, _, []), e') + when Map.mem m e' + && Expr.Typed.equal {e' with pattern= Var x} + (Map.find_exn m e') -> + expr_subst_stmt_base (Map.remove m e') stmt + | _ -> expr_subst_stmt_base m stmt in + map_rec_stmt_loc f in + let expr_map = + Map.filter_keys + ~f:(fun key -> + Set.mem latest_and_used_after_i key + || Set.mem (Map.find_exn used_not_latest_expressions_mfp i).exit + key ) + (Map.mapi expression_map ~f:(fun ~key ~data -> + {key with pattern= Var data} ) ) in + let f = expr_subst_stmt_except_initial_assign expr_map in + if List.length assignments_to_add_to_s = 0 then + (f Stmt.Fixed.{pattern= stmt; meta= Location_span.empty}).pattern + else SList - ( declarations_list - @ [lazy_code_motion_stmt (Map.find_exn flowgraph_to_mir 1)] ) - ; meta= Location_span.empty } in - let cleanup = - let cleanup_base (stmt : (Expr.Typed.t, Stmt.Located.t) Stmt.Fixed.Pattern.t) - : (Expr.Typed.t, Stmt.Located.t) Stmt.Fixed.Pattern.t = - match stmt with - | Stmt.Fixed.( - Pattern.IfElse + (List.map ~f + ( assignments_to_add_to_s + @ [{pattern= stmt; meta= Location_span.empty}] ) ) in + let lazy_code_motion_stmt = + map_rec_stmt_loc_num flowgraph_to_mir lazy_code_motion_base in + Stmt.Fixed. + { pattern= + SList + ( declarations_list + @ [lazy_code_motion_stmt (Map.find_exn flowgraph_to_mir 1)] ) + ; meta= Location_span.empty } in + let cleanup = + let cleanup_base + (stmt : (Expr.Typed.t, Stmt.Located.t) Stmt.Fixed.Pattern.t) : + (Expr.Typed.t, Stmt.Located.t) Stmt.Fixed.Pattern.t = + match stmt with + | Stmt.Fixed.( + Pattern.IfElse + ( e + , {pattern= Block [b1; {pattern= Skip; _}]; _} + , Some {pattern= Block [b2; {pattern= Skip; _}]; _} )) -> + IfElse (e, b1, Some b2) + | IfElse ( e - , {pattern= Block [b1; {pattern= Skip; _}]; _} - , Some {pattern= Block [b2; {pattern= Skip; _}]; _} )) -> - IfElse (e, b1, Some b2) - | IfElse - ( e - , {pattern= Block [b; {pattern= Skip; _}]; _} - , Some {pattern= Skip; _} ) -> - IfElse (e, b, None) - | While (e, {pattern= Block [b; {pattern= Skip; _}]; _}) -> While (e, b) - | For - { loopvar - ; lower - ; upper - ; body= {pattern= Block [b; {pattern= Skip; _}]; _} } -> - For {loopvar; lower; upper; body= b} - | _ -> stmt in - map_rec_stmt_loc cleanup_base in - transform_program_blockwise mir (fun _ x -> - cleanup (transform (preprocess_flowgraph x)) ) + , {pattern= Block [b; {pattern= Skip; _}]; _} + , Some {pattern= Skip; _} ) -> + IfElse (e, b, None) + | While (e, {pattern= Block [b; {pattern= Skip; _}]; _}) -> While (e, b) + | For + { loopvar + ; lower + ; upper + ; body= {pattern= Block [b; {pattern= Skip; _}]; _} } -> + For {loopvar; lower; upper; body= b} + | _ -> stmt in + map_rec_stmt_loc cleanup_base in + transform_program_blockwise mir (fun _ x -> + cleanup (transform (preprocess_flowgraph x)) ) -let block_fixing mir = - transform_program_blockwise mir (fun _ x -> - (map_rec_stmt_loc (fun stmt -> - match stmt with - | IfElse - ( e - , {pattern= SList l; meta} - , Some {pattern= SList l'; meta= smeta'} ) -> - IfElse + let block_fixing mir = + transform_program_blockwise mir (fun _ x -> + (map_rec_stmt_loc (fun stmt -> + match stmt with + | IfElse ( e - , {pattern= Block l; meta} - , Some {pattern= Block l'; meta= smeta'} ) - | IfElse (e, {pattern= SList l; meta}, b) -> - IfElse (e, {pattern= Block l; meta}, b) - | IfElse (e, b, Some {pattern= SList l'; meta= smeta'}) -> - IfElse (e, b, Some {pattern= Block l'; meta= smeta'}) - | While (e, {pattern= SList l; meta}) -> - While (e, {pattern= Block l; meta}) - | For {loopvar; lower; upper; body= {pattern= SList l; meta}} -> - For {loopvar; lower; upper; body= {pattern= Block l; meta}} - | _ -> stmt ) ) - x ) + , {pattern= SList l; meta} + , Some {pattern= SList l'; meta= smeta'} ) -> + IfElse + ( e + , {pattern= Block l; meta} + , Some {pattern= Block l'; meta= smeta'} ) + | IfElse (e, {pattern= SList l; meta}, b) -> + IfElse (e, {pattern= Block l; meta}, b) + | IfElse (e, b, Some {pattern= SList l'; meta= smeta'}) -> + IfElse (e, b, Some {pattern= Block l'; meta= smeta'}) + | While (e, {pattern= SList l; meta}) -> + While (e, {pattern= Block l; meta}) + | For {loopvar; lower; upper; body= {pattern= SList l; meta}} -> + For {loopvar; lower; upper; body= {pattern= Block l; meta}} + | _ -> stmt ) ) + x ) -(* TODO: implement SlicStan style optimizer for choosing best program block for each statement. *) -(* TODO: add optimization pass to move declarations down as much as possible and introduce as - tight as possible local scopes *) -(* TODO: add tests *) -(* TODO: add pass to get rid of redundant declarations? *) + (* TODO: implement SlicStan style optimizer for choosing best program block for each statement. *) + (* TODO: add optimization pass to move declarations down as much as possible and introduce as + tight as possible local scopes *) + (* TODO: add tests *) + (* TODO: add pass to get rid of redundant declarations? *) -(** + (** * A generic optimization pass for finding a minimal set of variables that * are generated by some circumstance, and then updating the MIR with that set. * @param gen_variables: the variables that must be added to the set at @@ -1072,87 +1193,89 @@ let block_fixing mir = * @param initial_variables: the initial known members of the set of variables * @param stmt the MIR statement to optimize. *) -let optimize_minimal_variables - ~(gen_variables : - (int, Stmt.Located.Non_recursive.t) Map.Poly.t - -> int - -> string Set.Poly.t - -> string Set.Poly.t ) - ~(update_expr : string Set.Poly.t -> Expr.Typed.t -> Expr.Typed.t) - ~(update_stmt : - ( Expr.Typed.t - , (Expr.Typed.Meta.t, 'a) Stmt.Fixed.t ) - Stmt.Fixed.Pattern.t - -> string Core_kernel.Set.Poly.t - -> ( Expr.Typed.t - , (Expr.Typed.Meta.t, 'a) Stmt.Fixed.t ) - Stmt.Fixed.Pattern.t ) - ~(extra_variables : string -> string Set.Poly.t) - ~(initial_variables : string Set.Poly.t) (stmt : Stmt.Located.t) = - let rev_flowgraph, flowgraph_to_mir = - Monotone_framework.inverse_flowgraph_of_stmt stmt in - let fwd_flowgraph = Monotone_framework.reverse rev_flowgraph in - let (module Circular_Fwd_Flowgraph) = - Monotone_framework.make_circular_flowgraph fwd_flowgraph rev_flowgraph in - let mfp_variables = - Monotone_framework.minimal_variables_mfp - (module Circular_Fwd_Flowgraph) - flowgraph_to_mir initial_variables gen_variables in - let optimize_min_vars_stmt_base i stmt_pattern = - let variable_set = - let exits = (Map.find_exn mfp_variables i).exit in - Set.Poly.union exits (union_map exits ~f:extra_variables) in - let stmt_val = - Stmt.Fixed.Pattern.map (update_expr variable_set) - (fun x -> x) - stmt_pattern in - update_stmt stmt_val variable_set in - map_rec_stmt_loc_num flowgraph_to_mir optimize_min_vars_stmt_base - (Map.find_exn flowgraph_to_mir 1) + let optimize_minimal_variables + ~(gen_variables : + (int, Stmt.Located.Non_recursive.t) Map.Poly.t + -> int + -> string Set.Poly.t + -> string Set.Poly.t ) + ~(update_expr : string Set.Poly.t -> Expr.Typed.t -> Expr.Typed.t) + ~(update_stmt : + ( Expr.Typed.t + , (Expr.Typed.Meta.t, 'a) Stmt.Fixed.t ) + Stmt.Fixed.Pattern.t + -> string Core_kernel.Set.Poly.t + -> ( Expr.Typed.t + , (Expr.Typed.Meta.t, 'a) Stmt.Fixed.t ) + Stmt.Fixed.Pattern.t ) + ~(extra_variables : string -> string Set.Poly.t) + ~(initial_variables : string Set.Poly.t) (stmt : Stmt.Located.t) = + let rev_flowgraph, flowgraph_to_mir = + Monotone_framework.inverse_flowgraph_of_stmt stmt in + let fwd_flowgraph = Monotone_framework.reverse rev_flowgraph in + let (module Circular_Fwd_Flowgraph) = + Monotone_framework.make_circular_flowgraph fwd_flowgraph rev_flowgraph + in + let mfp_variables = + Monotone_framework.minimal_variables_mfp + (module Circular_Fwd_Flowgraph) + flowgraph_to_mir initial_variables gen_variables in + let optimize_min_vars_stmt_base i stmt_pattern = + let variable_set = + let exits = (Map.find_exn mfp_variables i).exit in + Set.Poly.union exits (union_map exits ~f:extra_variables) in + let stmt_val = + Stmt.Fixed.Pattern.map (update_expr variable_set) + (fun x -> x) + stmt_pattern in + update_stmt stmt_val variable_set in + map_rec_stmt_loc_num flowgraph_to_mir optimize_min_vars_stmt_base + (Map.find_exn flowgraph_to_mir 1) -let optimize_ad_levels (mir : Program.Typed.t) = - let gen_ad_variables - (flowgraph_to_mir : (int, Stmt.Located.Non_recursive.t) Map.Poly.t) - (l : int) (ad_variables : string Set.Poly.t) = - let mir_node = (Map.find_exn flowgraph_to_mir l).pattern in - match mir_node with - | Assignment ((x, _, _), e) - when Expr.Typed.adlevel_of (update_expr_ad_levels ad_variables e) - = AutoDiffable -> - Set.Poly.singleton x - | _ -> Set.Poly.empty in - let global_initial_ad_variables = - Set.Poly.of_list - (List.filter_map - ~f:(fun (v, Program.{out_block; _}) -> - match out_block with Parameters -> Some v | _ -> None ) - mir.output_vars ) in - let initial_ad_variables fundef_opt _ = - match (fundef_opt : Stmt.Located.t Program.fun_def option) with - | None -> global_initial_ad_variables - | Some {fdargs; _} -> - Set.Poly.union global_initial_ad_variables - (Set.Poly.of_list - (List.filter_map fdargs ~f:(fun (_, name, ut) -> - if UnsizedType.is_autodiffable ut then Some name else None ) - ) ) in - let extra_variables v = Set.Poly.singleton (v ^ "_in__") in - let update_stmt stmt_pattern variable_set = - match stmt_pattern with - | Stmt.Fixed.Pattern.Decl ({decl_id; _} as decl) - when Set.mem variable_set decl_id -> - Stmt.Fixed.Pattern.Decl {decl with decl_adtype= UnsizedType.AutoDiffable} - | Decl ({decl_id; _} as decl) when not (Set.mem variable_set decl_id) -> - Decl {decl with decl_adtype= DataOnly} - | s -> s in - let transform fundef_opt stmt = - optimize_minimal_variables ~gen_variables:gen_ad_variables - ~update_expr:update_expr_ad_levels ~update_stmt ~extra_variables - ~initial_variables:(initial_ad_variables fundef_opt stmt) - stmt in - transform_program_blockwise mir transform + let optimize_ad_levels (mir : Program.Typed.t) = + let gen_ad_variables + (flowgraph_to_mir : (int, Stmt.Located.Non_recursive.t) Map.Poly.t) + (l : int) (ad_variables : string Set.Poly.t) = + let mir_node = (Map.find_exn flowgraph_to_mir l).pattern in + match mir_node with + | Assignment ((x, _, _), e) + when Expr.Typed.adlevel_of (update_expr_ad_levels ad_variables e) + = AutoDiffable -> + Set.Poly.singleton x + | _ -> Set.Poly.empty in + let global_initial_ad_variables = + Set.Poly.of_list + (List.filter_map + ~f:(fun (v, Program.{out_block; _}) -> + match out_block with Parameters -> Some v | _ -> None ) + mir.output_vars ) in + let initial_ad_variables fundef_opt _ = + match (fundef_opt : Stmt.Located.t Program.fun_def option) with + | None -> global_initial_ad_variables + | Some {fdargs; _} -> + Set.Poly.union global_initial_ad_variables + (Set.Poly.of_list + (List.filter_map fdargs ~f:(fun (_, name, ut) -> + if UnsizedType.is_autodiffable ut then Some name else None ) + ) ) in + let extra_variables v = Set.Poly.singleton (v ^ "_in__") in + let update_stmt stmt_pattern variable_set = + match stmt_pattern with + | Stmt.Fixed.Pattern.Decl ({decl_id; _} as decl) + when Set.mem variable_set decl_id -> + Stmt.Fixed.Pattern.Decl + {decl with decl_adtype= UnsizedType.AutoDiffable} + | Decl ({decl_id; _} as decl) when not (Set.mem variable_set decl_id) -> + Decl {decl with decl_adtype= DataOnly} + | s -> s in + let transform fundef_opt stmt = + optimize_minimal_variables ~gen_variables:gen_ad_variables + ~update_expr:update_expr_ad_levels ~update_stmt ~extra_variables + ~initial_variables:(initial_ad_variables fundef_opt stmt) + stmt in + transform_program_blockwise mir transform -(** + (** * Deduces whether types can be Structures of Arrays (SoA/fast) or * Arrays of Structs (AoS/slow). See the docs in * Mem_pattern.query_demote_stmt/exprs* functions for @@ -1173,144 +1296,81 @@ let optimize_ad_levels (mir : Program.Typed.t) = * * @param mir: The program's whole MIR. *) -let optimize_soa (mir : Program.Typed.t) = - let module Mem = Mem_pattern.Make (Frontend.Std_library_utils.NullLibrary) - (*TODO*) in - let gen_aos_variables - (flowgraph_to_mir : (int, Stmt.Located.Non_recursive.t) Map.Poly.t) - (l : int) (aos_variables : string Set.Poly.t) = - let mir_node mir_idx = Map.find_exn flowgraph_to_mir mir_idx in - match (mir_node l).pattern with - | stmt -> Mem.query_demotable_stmt aos_variables stmt in - let initial_variables = - List.fold ~init:Set.Poly.empty - ~f:(Mem.query_initial_demotable_stmt false) - mir.log_prob in - let mod_exprs aos_exits mod_expr = - Mir_utils.map_rec_expr (Mem.modify_expr_pattern aos_exits) mod_expr in - let modify_stmt_patt stmt_pattern variable_set = - Mem.modify_stmt_pattern stmt_pattern variable_set in - let transform stmt = - optimize_minimal_variables ~gen_variables:gen_aos_variables - ~update_expr:mod_exprs ~update_stmt:modify_stmt_patt ~initial_variables - stmt ~extra_variables:(fun _ -> initial_variables) in - let transform' s = - match transform {pattern= SList s; meta= Location_span.empty} with - | {pattern= SList (l : Stmt.Located.t list); _} -> l - | _ -> - raise - (Failure "Something went wrong with program transformation packing!") - in - {mir with log_prob= transform' mir.log_prob} - -(* Apparently you need to completely copy/paste type definitions between - ml and mli files?*) -type optimization_settings = - { function_inlining: bool - ; static_loop_unrolling: bool - ; one_step_loop_unrolling: bool - ; list_collapsing: bool - ; block_fixing: bool - ; allow_uninitialized_decls: bool - ; constant_propagation: bool - ; expression_propagation: bool - ; copy_propagation: bool - ; dead_code_elimination: bool - ; partial_evaluation: bool - ; lazy_code_motion: bool - ; optimize_ad_levels: bool - ; preserve_stability: bool - ; optimize_soa: bool } - -let settings_const b = - { function_inlining= b - ; static_loop_unrolling= b - ; one_step_loop_unrolling= b - ; list_collapsing= b - ; block_fixing= b - ; allow_uninitialized_decls= b - ; constant_propagation= b - ; expression_propagation= b - ; copy_propagation= b - ; dead_code_elimination= b - ; partial_evaluation= b - ; lazy_code_motion= b - ; optimize_ad_levels= b - ; preserve_stability= not b - ; optimize_soa= b } - -let all_optimizations : optimization_settings = settings_const true -let no_optimizations : optimization_settings = settings_const false - -type optimization_level = O0 | O1 | Oexperimental - -let level_optimizations (lvl : optimization_level) : optimization_settings = - match lvl with - | O0 -> no_optimizations - | O1 -> - { function_inlining= true - ; static_loop_unrolling= false - ; one_step_loop_unrolling= false - ; list_collapsing= true - ; block_fixing= true - ; constant_propagation= true - ; expression_propagation= false - ; copy_propagation= true - ; dead_code_elimination= true - ; partial_evaluation= true - ; lazy_code_motion= false - ; allow_uninitialized_decls= true - ; optimize_ad_levels= false - ; preserve_stability= false - ; optimize_soa= true } - | Oexperimental -> all_optimizations + let optimize_soa (mir : Program.Typed.t) = + let gen_aos_variables + (flowgraph_to_mir : (int, Stmt.Located.Non_recursive.t) Map.Poly.t) + (l : int) (aos_variables : string Set.Poly.t) = + let mir_node mir_idx = Map.find_exn flowgraph_to_mir mir_idx in + match (mir_node l).pattern with + | stmt -> Mem.query_demotable_stmt aos_variables stmt in + let initial_variables = + List.fold ~init:Set.Poly.empty + ~f:(Mem.query_initial_demotable_stmt false) + mir.log_prob in + let mod_exprs aos_exits mod_expr = + Mir_utils.map_rec_expr (Mem.modify_expr_pattern aos_exits) mod_expr in + let modify_stmt_patt stmt_pattern variable_set = + Mem.modify_stmt_pattern stmt_pattern variable_set in + let transform stmt = + optimize_minimal_variables ~gen_variables:gen_aos_variables + ~update_expr:mod_exprs ~update_stmt:modify_stmt_patt ~initial_variables + stmt ~extra_variables:(fun _ -> initial_variables) in + let transform' s = + match transform {pattern= SList s; meta= Location_span.empty} with + | {pattern= SList (l : Stmt.Located.t list); _} -> l + | _ -> + raise + (Failure "Something went wrong with program transformation packing!") + in + {mir with log_prob= transform' mir.log_prob} -let optimization_suite ?(settings = all_optimizations) mir = - let preserve_stability = settings.preserve_stability in - let maybe_optimizations = - [ (* Phase order. See phase-ordering-nodes.org for details *) - (* Book section A *) - (* Book section B *) - (* Book: Procedure integration *) - (function_inlining, settings.function_inlining) - (* Book: Sparse conditional constant propagation *) - ; (constant_propagation ~preserve_stability, settings.constant_propagation) - (* Book section C *) - (* Book: Local and global copy propagation *) - ; (copy_propagation, settings.copy_propagation) - (* Book: Sparse conditional constant propagation *) - ; (constant_propagation ~preserve_stability, settings.constant_propagation) - (* Book: Dead-code elimination *) - ; (dead_code_elimination, settings.dead_code_elimination) - (* Matthijs: Before lazy code motion to get loop-invariant code motion *) - ; (one_step_loop_unrolling, settings.one_step_loop_unrolling) - (* Matthjis: expression_propagation < partial_evaluation *) - ; ( expression_propagation ~preserve_stability - , settings.expression_propagation ) - (* Matthjis: partial_evaluation < lazy_code_motion *) - ; (partial_evaluation, settings.partial_evaluation) - (* Book: Loop-invariant code motion *) - ; (lazy_code_motion ~preserve_stability, settings.lazy_code_motion) - (* Matthijs: lazy_code_motion < copy_propagation TODO: Check if this is necessary *) - ; (copy_propagation, settings.copy_propagation) - (* Matthijs: Constant propagation before static loop unrolling *) - ; (constant_propagation ~preserve_stability, settings.constant_propagation) - (* Book: Loop simplification *) - ; (static_loop_unrolling, settings.static_loop_unrolling) - (* Book: Dead-code elimination *) - (* Matthijs: Everything < Dead-code elimination *) - ; (dead_code_elimination, settings.dead_code_elimination) - (* Book: Machine idioms and instruction combining *) - ; (list_collapsing, settings.list_collapsing) - (* Book: Machine idioms and instruction combining *) - ; (optimize_ad_levels, settings.optimize_ad_levels) - (*Remove decls immediately assigned to*) - ; (allow_uninitialized_decls, settings.allow_uninitialized_decls) - (* Book: Machine idioms and instruction combining *) - (* Matthijs: Everything < block_fixing *) - ; (block_fixing, settings.block_fixing) - ; (optimize_soa, settings.optimize_soa) ] in - let optimizations = - List.filter_map maybe_optimizations ~f:(fun (fn, flag) -> - if flag then Some fn else None ) in - List.fold optimizations ~init:mir ~f:(fun mir opt -> opt mir) + let optimization_suite ?(settings = all_optimizations) mir = + let preserve_stability = settings.preserve_stability in + let maybe_optimizations = + [ (* Phase order. See phase-ordering-nodes.org for details *) + (* Book section A *) + (* Book section B *) + (* Book: Procedure integration *) + (function_inlining, settings.function_inlining) + (* Book: Sparse conditional constant propagation *) + ; (constant_propagation ~preserve_stability, settings.constant_propagation) + (* Book section C *) + (* Book: Local and global copy propagation *) + ; (copy_propagation, settings.copy_propagation) + (* Book: Sparse conditional constant propagation *) + ; (constant_propagation ~preserve_stability, settings.constant_propagation) + (* Book: Dead-code elimination *) + ; (dead_code_elimination, settings.dead_code_elimination) + (* Matthijs: Before lazy code motion to get loop-invariant code motion *) + ; (one_step_loop_unrolling, settings.one_step_loop_unrolling) + (* Matthjis: expression_propagation < partial_evaluation *) + ; ( expression_propagation ~preserve_stability + , settings.expression_propagation ) + (* Matthjis: partial_evaluation < lazy_code_motion *) + ; (partial_evaluation, settings.partial_evaluation) + (* Book: Loop-invariant code motion *) + ; (lazy_code_motion ~preserve_stability, settings.lazy_code_motion) + (* Matthijs: lazy_code_motion < copy_propagation TODO: Check if this is necessary *) + ; (copy_propagation, settings.copy_propagation) + (* Matthijs: Constant propagation before static loop unrolling *) + ; (constant_propagation ~preserve_stability, settings.constant_propagation) + (* Book: Loop simplification *) + ; (static_loop_unrolling, settings.static_loop_unrolling) + (* Book: Dead-code elimination *) + (* Matthijs: Everything < Dead-code elimination *) + ; (dead_code_elimination, settings.dead_code_elimination) + (* Book: Machine idioms and instruction combining *) + ; (list_collapsing, settings.list_collapsing) + (* Book: Machine idioms and instruction combining *) + ; (optimize_ad_levels, settings.optimize_ad_levels) + (*Remove decls immediately assigned to*) + ; (allow_uninitialized_decls, settings.allow_uninitialized_decls) + (* Book: Machine idioms and instruction combining *) + (* Matthijs: Everything < block_fixing *) + ; (block_fixing, settings.block_fixing) + ; (optimize_soa, settings.optimize_soa) ] in + let optimizations = + List.filter_map maybe_optimizations ~f:(fun (fn, flag) -> + if flag then Some fn else None ) in + List.fold optimizations ~init:mir ~f:(fun mir opt -> opt mir) +end diff --git a/src/analysis_and_optimization/Optimize.mli b/src/analysis_and_optimization/Optimize.mli index 2885ce9b44..8d1bc28331 100644 --- a/src/analysis_and_optimization/Optimize.mli +++ b/src/analysis_and_optimization/Optimize.mli @@ -1,93 +1,97 @@ (* Code for optimization passes on the MIR *) open Middle -val function_inlining : Program.Typed.t -> Program.Typed.t -(** Inline all functions except for ones with forward declarations +(** Interface for turning individual optimizations on/off. Useful for testing + and for top-level interface flags. *) +type optimization_settings = + { function_inlining: bool + ; static_loop_unrolling: bool + ; one_step_loop_unrolling: bool + ; list_collapsing: bool + ; block_fixing: bool + ; allow_uninitialized_decls: bool + ; constant_propagation: bool + ; expression_propagation: bool + ; copy_propagation: bool + ; dead_code_elimination: bool + ; partial_evaluation: bool + ; lazy_code_motion: bool + ; optimize_ad_levels: bool + ; preserve_stability: bool + ; optimize_soa: bool } + +val all_optimizations : optimization_settings +val no_optimizations : optimization_settings + +type optimization_level = O0 | O1 | Oexperimental + +val level_optimizations : optimization_level -> optimization_settings + +module type Optimizer = sig + val function_inlining : Program.Typed.t -> Program.Typed.t + (** Inline all functions except for ones with forward declarations (e.g. recursive functions, mutually recursive functions, and functions without a definition *) -val static_loop_unrolling : Program.Typed.t -> Program.Typed.t -(** Unroll all for-loops with constant bounds, as long as they do + val static_loop_unrolling : Program.Typed.t -> Program.Typed.t + (** Unroll all for-loops with constant bounds, as long as they do not contain break or continue statements in their body at the top level *) -val one_step_loop_unrolling : Program.Typed.t -> Program.Typed.t -(** Unroll all loops for one iteration, as long as they do + val one_step_loop_unrolling : Program.Typed.t -> Program.Typed.t + (** Unroll all loops for one iteration, as long as they do not contain break or continue statements in their body at the top level *) -val list_collapsing : Program.Typed.t -> Program.Typed.t -(** Remove redundant SList constructors from the Mir that might have + val list_collapsing : Program.Typed.t -> Program.Typed.t + (** Remove redundant SList constructors from the Mir that might have been introduced by other optimizations *) -val block_fixing : Program.Typed.t -> Program.Typed.t -(** Make sure that SList constructors directly under if, for, while or fundef + val block_fixing : Program.Typed.t -> Program.Typed.t + (** Make sure that SList constructors directly under if, for, while or fundef constructors are replaced with Block constructors. This should probably be run before we generate code. *) -val constant_propagation : - ?preserve_stability:bool -> Program.Typed.t -> Program.Typed.t -(** Propagate constant values through variable assignments *) + val constant_propagation : + ?preserve_stability:bool -> Program.Typed.t -> Program.Typed.t + (** Propagate constant values through variable assignments *) -val expression_propagation : - ?preserve_stability:bool -> Program.Typed.t -> Program.Typed.t -(** Propagate arbitrary expressions through variable assignments. + val expression_propagation : + ?preserve_stability:bool -> Program.Typed.t -> Program.Typed.t + (** Propagate arbitrary expressions through variable assignments. This can be useful for opening up new possibilities for partial evaluation. It should be followed by some CSE or lazy code motion pass, however. *) -val copy_propagation : Program.Typed.t -> Program.Typed.t -(** Propagate copies of variables through assignments. *) + val copy_propagation : Program.Typed.t -> Program.Typed.t + (** Propagate copies of variables through assignments. *) -val dead_code_elimination : Program.Typed.t -> Program.Typed.t -(** Eliminate semantically redundant code branches. + val dead_code_elimination : Program.Typed.t -> Program.Typed.t + (** Eliminate semantically redundant code branches. This includes removing redundant assignments (because they will be overwritten) and removing redundant code in program branches that will never be reached. *) -val partial_evaluation : Program.Typed.t -> Program.Typed.t -(** Partially evaluate expressions in the program. This includes simplification using + val partial_evaluation : Program.Typed.t -> Program.Typed.t + (** Partially evaluate expressions in the program. This includes simplification using algebraic identities of logical and arithmetic operators as well as Stan math functions. *) -val lazy_code_motion : - ?preserve_stability:bool -> Program.Typed.t -> Program.Typed.t -(** Perform partial redundancy elmination using the lazy code motion algorithm. This + val lazy_code_motion : + ?preserve_stability:bool -> Program.Typed.t -> Program.Typed.t + (** Perform partial redundancy elmination using the lazy code motion algorithm. This subsumes common subexpression elimination and loop-invariant code motion. *) -val optimize_ad_levels : Program.Typed.t -> Program.Typed.t -(** Assign the optimal ad-levels to local variables. That means, make sure that + val optimize_ad_levels : Program.Typed.t -> Program.Typed.t + (** Assign the optimal ad-levels to local variables. That means, make sure that variables only ever get treated as autodiff variables if they have some dependency on a parameter *) -val allow_uninitialized_decls : Program.Typed.t -> Program.Typed.t -(** Marks Decl types such that, if the first assignment after the decl - assigns to the full object, allow the object to be constructed but + val allow_uninitialized_decls : Program.Typed.t -> Program.Typed.t + (** Marks Decl types such that, if the first assignment after the decl + assigns to the full object, allow the object to be constructed but not uninitialized. *) -(** Interface for turning individual optimizations on/off. Useful for testing - and for top-level interface flags. *) -type optimization_settings = - { function_inlining: bool - ; static_loop_unrolling: bool - ; one_step_loop_unrolling: bool - ; list_collapsing: bool - ; block_fixing: bool - ; allow_uninitialized_decls: bool - ; constant_propagation: bool - ; expression_propagation: bool - ; copy_propagation: bool - ; dead_code_elimination: bool - ; partial_evaluation: bool - ; lazy_code_motion: bool - ; optimize_ad_levels: bool - ; preserve_stability: bool - ; optimize_soa: bool } - -val all_optimizations : optimization_settings -val no_optimizations : optimization_settings - -type optimization_level = O0 | O1 | Oexperimental - -val level_optimizations : optimization_level -> optimization_settings + val optimization_suite : + ?settings:optimization_settings -> Program.Typed.t -> Program.Typed.t + (** Perform all optimizations in this module on the MIR in an appropriate order. *) +end -val optimization_suite : - ?settings:optimization_settings -> Program.Typed.t -> Program.Typed.t -(** Perform all optimizations in this module on the MIR in an appropriate order. *) +module Make (StdLib : Frontend.Std_library_utils.Library) : Optimizer diff --git a/src/analysis_and_optimization/Partial_evaluation.ml b/src/analysis_and_optimization/Partial_evaluation.ml new file mode 100644 index 0000000000..fc3864fb94 --- /dev/null +++ b/src/analysis_and_optimization/Partial_evaluation.ml @@ -0,0 +1,1195 @@ +(* A partial evaluator for use in static analysis and optimization *) + +open Core_kernel +open Core_kernel.Poly +open Middle + +exception Rejected of Location_span.t * string + +let rec is_int query Expr.Fixed.{pattern; _} = + match pattern with + | Lit (Int, i) | Lit (Real, i) -> float_of_string i = float_of_int query + | Promotion (e, _, _) -> is_int query e + | _ -> false + +let apply_prefix_operator_int (op : string) i = + Expr.Fixed.Pattern.Lit + ( Int + , Int.to_string + ( match op with + | "PPlus__" -> i + | "PMinus__" -> -i + | "PNot__" -> if i = 0 then 1 else 0 + | s -> + Common.FatalError.fatal_error_msg + [%message "Not an int prefix operator: " s] ) ) + +let apply_prefix_operator_real (op : string) i = + Expr.Fixed.Pattern.Lit + ( Real + , Float.to_string + ( match op with + | "PPlus__" -> i + | "PMinus__" -> -.i + | s -> + Common.FatalError.fatal_error_msg + [%message "Not a real prefix operator: " s] ) ) + +let apply_operator_int (op : string) i1 i2 = + Expr.Fixed.Pattern.Lit + ( Int + , Int.to_string + ( match op with + | "Plus__" -> i1 + i2 + | "Minus__" -> i1 - i2 + | "Times__" -> i1 * i2 + | "Divide__" | "IntDivide__" -> i1 / i2 + | "Modulo__" -> i1 % i2 + | "Equals__" -> Bool.to_int (i1 = i2) + | "NEquals__" -> Bool.to_int (i1 <> i2) + | "Less__" -> Bool.to_int (i1 < i2) + | "Leq__" -> Bool.to_int (i1 <= i2) + | "Greater__" -> Bool.to_int (i1 > i2) + | "Geq__" -> Bool.to_int (i1 >= i2) + | s -> + Common.FatalError.fatal_error_msg + [%message "Not an int operator: " s] ) ) + +let apply_arithmetic_operator_real (op : string) r1 r2 = + Expr.Fixed.Pattern.Lit + ( Real + , Float.to_string + ( match op with + | "Plus__" -> r1 +. r2 + | "Minus__" -> r1 -. r2 + | "Times__" -> r1 *. r2 + | "Divide__" -> r1 /. r2 + | s -> + Common.FatalError.fatal_error_msg + [%message "Not a real operator: " s] ) ) + +let apply_logical_operator_real (op : string) r1 r2 = + Expr.Fixed.Pattern.Lit + ( Int + , Int.to_string + ( match op with + | "Equals__" -> Bool.to_int (r1 = r2) + | "NEquals__" -> Bool.to_int (r1 <> r2) + | "Less__" -> Bool.to_int (r1 < r2) + | "Leq__" -> Bool.to_int (r1 <= r2) + | "Greater__" -> Bool.to_int (r1 > r2) + | "Geq__" -> Bool.to_int (r1 >= r2) + | s -> + Common.FatalError.fatal_error_msg + [%message "Not a logical operator: " s] ) ) + +let is_multi_index = function + | Index.MultiIndex _ | Upfrom _ | Between _ | All -> true + | Single _ -> false + +module type PartialEvaluator = sig + val try_eval_expr : Expr.Typed.t -> Expr.Typed.t + val eval_prog : Program.Typed.t -> Program.Typed.t +end + +module Make (StdLib : Frontend.Std_library_utils.Library) = struct + module TC = Frontend.Typechecking.Make (StdLib) + + let rec eval_expr ?(preserve_stability = false) (e : Expr.Typed.t) = + { e with + pattern= + ( match e.pattern with + | Var _ | Lit (_, _) -> e.pattern + | Promotion (expr, ut, ad) -> + Promotion (eval_expr ~preserve_stability expr, ut, ad) + | FunApp (kind, l) -> ( + let l = List.map ~f:(eval_expr ~preserve_stability) l in + match kind with + | UserDefined _ | CompilerInternal _ -> FunApp (kind, l) + | StanLib (f, suffix, mem_type) -> + let get_fun_or_op_rt_opt name l' = + let argument_types = + List.map + ~f:(fun x -> Expr.Typed.(adlevel_of x, type_of x)) + l' in + Operator.of_string_opt name + |> Option.value_map + ~f:(fun op -> + TC.operator_return_type op argument_types + |> Option.map ~f:fst ) + ~default: + (TC.library_function_return_type name argument_types) + in + let try_partially_evaluate_stanlib e = + Expr.Fixed.Pattern.( + match e with + | FunApp (StanLib (f', suffix', mem_type), l') -> ( + match get_fun_or_op_rt_opt f' l' with + | Some _ -> FunApp (StanLib (f', suffix', mem_type), l') + | None -> FunApp (StanLib (f, suffix, mem_type), l) ) + | e -> e) in + let lub_mem_pat lst = + Common.Helpers.lub_mem_pat (List.cons mem_type lst) in + try_partially_evaluate_stanlib + ( match (f, l) with + (* TODO: deal with tilde statements and unnormalized distributions properly here *) + | ( "bernoulli_lpmf" + , [ y + ; { pattern= + FunApp + ( StanLib ("inv_logit", FnPlain, mem1) + , [ { pattern= + FunApp + ( StanLib ("Plus__", FnPlain, mem2) + , [ alpha + ; { pattern= + FunApp + ( StanLib + ("Times__", FnPlain, mem3) + , [x; beta] ) + ; _ } ] ) + ; _ } ] ) + ; _ } ] ) + when Expr.Typed.type_of x = UMatrix -> + let lub_mem = lub_mem_pat [mem1; mem2; mem3] in + FunApp + ( StanLib ("bernoulli_logit_glm_lpmf", suffix, lub_mem) + , [y; x; alpha; beta] ) + | ( "bernoulli_lpmf" + , [ y + ; { pattern= + FunApp + ( StanLib ("inv_logit", FnPlain, mem1) + , [ { pattern= + FunApp + ( StanLib ("Plus__", FnPlain, mem2) + , [ { pattern= + FunApp + ( StanLib + ("Times__", FnPlain, mem3) + , [x; beta] ) + ; _ }; alpha ] ) + ; _ } ] ) + ; _ } ] ) + when Expr.Typed.type_of x = UMatrix -> + let lub_mem = lub_mem_pat [mem1; mem2; mem3] in + FunApp + ( StanLib ("bernoulli_logit_glm_lpmf", suffix, lub_mem) + , [y; x; alpha; beta] ) + | ( "bernoulli_lpmf" + , [ y + ; { pattern= + FunApp + ( StanLib ("inv_logit", FnPlain, mem1) + , [ { pattern= + FunApp + ( StanLib ("Times__", FnPlain, mem2) + , [x; beta] ) + ; _ } ] ) + ; _ } ] ) + when Expr.Typed.type_of x = UMatrix -> + let lub_mem = lub_mem_pat [mem1; mem2] in + FunApp + ( StanLib ("bernoulli_logit_glm_lpmf", suffix, lub_mem) + , [y; x; Expr.Helpers.zero; beta] ) + | ( "bernoulli_logit_lpmf" + , [ y + ; { pattern= + FunApp + ( StanLib ("Plus__", FnPlain, mem1) + , [ alpha + ; { pattern= + FunApp + ( StanLib ("Times__", FnPlain, mem2) + , [x; beta] ) + ; _ } ] ) + ; _ } ] ) + when Expr.Typed.type_of x = UMatrix -> + let lub_mem = lub_mem_pat [mem1; mem2] in + FunApp + ( StanLib ("bernoulli_logit_glm_lpmf", suffix, lub_mem) + , [y; x; alpha; beta] ) + | ( "bernoulli_logit_lpmf" + , [ y + ; { pattern= + FunApp + ( StanLib ("Plus__", FnPlain, mem1) + , [ { pattern= + FunApp + ( StanLib ("Times__", FnPlain, mem2) + , [x; beta] ) + ; _ }; alpha ] ) + ; _ } ] ) + when Expr.Typed.type_of x = UMatrix -> + let lub_mem = lub_mem_pat [mem1; mem2] in + FunApp + ( StanLib ("bernoulli_logit_glm_lpmf", suffix, lub_mem) + , [y; x; alpha; beta] ) + | ( "bernoulli_logit_lpmf" + , [ y + ; { pattern= + FunApp (StanLib ("Times__", FnPlain, mem), [x; beta]) + ; _ } ] ) + when Expr.Typed.type_of x = UMatrix -> + FunApp + ( StanLib + ( "bernoulli_logit_glm_lpmf" + , suffix + , lub_mem_pat [mem] ) + , [y; x; Expr.Helpers.zero; beta] ) + | ( "bernoulli_lpmf" + , [ y + ; { pattern= + FunApp (StanLib ("inv_logit", FnPlain, mem), [alpha]) + ; _ } ] ) -> + FunApp + ( StanLib + ("bernoulli_logit_lpmf", suffix, lub_mem_pat [mem]) + , [y; alpha] ) + | ( "bernoulli_rng" + , [ { pattern= + FunApp (StanLib ("inv_logit", FnPlain, mem), [alpha]) + ; _ } ] ) -> + FunApp + ( StanLib + ("bernoulli_logit_rng", suffix, lub_mem_pat [mem]) + , [alpha] ) + | ( "binomial_lpmf" + , [ y; n + ; { pattern= + FunApp (StanLib ("inv_logit", FnPlain, mem), [alpha]) + ; _ } ] ) -> + FunApp + ( StanLib + ("binomial_logit_lpmf", suffix, lub_mem_pat [mem]) + , [y; n; alpha] ) + | ( "categorical_lpmf" + , [ y + ; { pattern= + FunApp (StanLib ("inv_logit", FnPlain, mem), [alpha]) + ; _ } ] ) -> + FunApp + ( StanLib + ("categorical_logit_lpmf", suffix, lub_mem_pat [mem]) + , [y; alpha] ) + | ( "categorical_rng" + , [ { pattern= + FunApp (StanLib ("inv_logit", FnPlain, mem), [alpha]) + ; _ } ] ) -> + FunApp + ( StanLib + ("categorical_logit_rng", suffix, lub_mem_pat [mem]) + , [alpha] ) + | "columns_dot_product", [x; y] when Expr.Typed.equal x y -> + FunApp + (StanLib ("columns_dot_self", suffix, mem_type), [x]) + | "dot_product", [x; y] when Expr.Typed.equal x y -> + FunApp (StanLib ("dot_self", suffix, mem_type), [x]) + | ( "inv" + , [{pattern= FunApp (StanLib ("sqrt", FnPlain, mem), l); _}] + ) -> + FunApp (StanLib ("inv_sqrt", suffix, mem), l) + | ( "inv" + , [ { pattern= FunApp (StanLib ("square", FnPlain, mem), [x]) + ; _ } ] ) -> + FunApp + (StanLib ("inv_square", suffix, lub_mem_pat [mem]), [x]) + | ( "log" + , [ { pattern= + FunApp + ( StanLib ("Minus__", FnPlain, mem1) + , [ y + ; { pattern= + FunApp + (StanLib ("exp", FnPlain, mem2), [x]) + ; _ } ] ) + ; _ } ] ) + when is_int 1 y && not preserve_stability -> + let lub_mem = lub_mem_pat [mem1; mem2] in + FunApp (StanLib ("log1m_exp", suffix, lub_mem), [x]) + | ( "log" + , [ { pattern= + FunApp + ( StanLib ("Minus__", FnPlain, mem1) + , [ y + ; { pattern= + FunApp + ( StanLib ("inv_logit", FnPlain, mem2) + , [x] ) + ; _ } ] ) + ; _ } ] ) + when is_int 1 y && not preserve_stability -> + let lub_mem = lub_mem_pat [mem1; mem2] in + FunApp (StanLib ("log1m_inv_logit", suffix, lub_mem), [x]) + | ( "log" + , [ { pattern= + FunApp (StanLib ("Minus__", FnPlain, mem), [y; x]) + ; _ } ] ) + when is_int 1 y && not preserve_stability -> + FunApp (StanLib ("log1m", suffix, lub_mem_pat [mem]), [x]) + | ( "log" + , [ { pattern= + FunApp + ( StanLib ("Plus__", FnPlain, mem1) + , [ y + ; { pattern= + FunApp + (StanLib ("exp", FnPlain, mem2), [x]) + ; _ } ] ) + ; _ } ] ) + when is_int 1 y && not preserve_stability -> + let lub_mem = lub_mem_pat [mem1; mem2] in + FunApp (StanLib ("log1p_exp", suffix, lub_mem), [x]) + | ( "log" + , [ { pattern= + FunApp (StanLib ("Plus__", FnPlain, mem), [y; x]) + ; _ } ] ) + when is_int 1 y && not preserve_stability -> + FunApp (StanLib ("log1p", suffix, lub_mem_pat [mem]), [x]) + | ( "log" + , [ { pattern= + FunApp + ( StanLib ("fabs", FnPlain, mem1) + , [ { pattern= + FunApp + ( StanLib ("determinant", FnPlain, mem2) + , [x] ) + ; _ } ] ) + ; _ } ] ) -> + let lub_mem = lub_mem_pat [mem1; mem2] in + FunApp (StanLib ("log_determinant", suffix, lub_mem), [x]) + | ( "log" + , [ { pattern= + FunApp + ( StanLib ("Minus__", FnPlain, mem1) + , [ { pattern= + FunApp + (StanLib ("exp", FnPlain, mem2), [x]) + ; _ } + ; { pattern= + FunApp + (StanLib ("exp", FnPlain, mem3), [y]) + ; _ } ] ) + ; _ } ] ) -> + let lub_mem = lub_mem_pat [mem1; mem2; mem3] in + FunApp (StanLib ("log_diff_exp", suffix, lub_mem), [x; y]) + (* TODO: log_mix?*) + | ( "log" + , [ { pattern= + FunApp + (StanLib ("falling_factorial", FnPlain, mem), l) + ; _ } ] ) -> + FunApp + ( StanLib + ("log_falling_factorial", suffix, lub_mem_pat [mem]) + , l ) + | ( "log" + , [ { pattern= + FunApp + (StanLib ("rising_factorial", FnPlain, mem), l) + ; _ } ] ) -> + FunApp + ( StanLib + ("log_rising_factorial", suffix, lub_mem_pat [mem]) + , l ) + | ( "log" + , [ { pattern= + FunApp (StanLib ("inv_logit", FnPlain, mem), l) + ; _ } ] ) -> + FunApp + (StanLib ("log_inv_logit", suffix, lub_mem_pat [mem]), l) + | ( "log" + , [ { pattern= FunApp (StanLib ("softmax", FnPlain, mem), l) + ; _ } ] ) -> + FunApp + (StanLib ("log_softmax", suffix, lub_mem_pat [mem]), l) + | ( "log" + , [ { pattern= + FunApp + ( StanLib ("sum", FnPlain, mem1) + , [ { pattern= + FunApp (StanLib ("exp", FnPlain, mem2), l) + ; _ } ] ) + ; _ } ] ) -> + let lub_mem = lub_mem_pat [mem1; mem2] in + FunApp (StanLib ("log_sum_exp", suffix, lub_mem), l) + | ( "log" + , [ { pattern= + FunApp + ( StanLib ("Plus__", FnPlain, mem1) + , [ { pattern= + FunApp + (StanLib ("exp", FnPlain, mem2), [x]) + ; _ } + ; { pattern= + FunApp + (StanLib ("exp", FnPlain, mem3), [y]) + ; _ } ] ) + ; _ } ] ) -> + let lub_mem = lub_mem_pat [mem1; mem2; mem3] in + FunApp (StanLib ("log_sum_exp", suffix, lub_mem), [x; y]) + | ( "multi_normal_lpdf" + , [ y; mu + ; { pattern= + FunApp (StanLib ("inverse", FnPlain, mem), [tau]) + ; _ } ] ) -> + let lub_mem = lub_mem_pat [mem] in + FunApp + ( StanLib ("multi_normal_prec_lpdf", suffix, lub_mem) + , [y; mu; tau] ) + | ( "neg_binomial_2_lpmf" + , [ y + ; { pattern= + FunApp + ( StanLib ("exp", FnPlain, mem1) + , [ { pattern= + FunApp + ( StanLib ("Plus__", FnPlain, mem2) + , [ alpha + ; { pattern= + FunApp + ( StanLib + ("Times__", FnPlain, mem3) + , [x; beta] ) + ; _ } ] ) + ; _ } ] ) + ; _ }; sigma ] ) + when Expr.Typed.type_of x = UMatrix -> + let lub_mem = lub_mem_pat [mem1; mem2; mem3] in + FunApp + ( StanLib + ("neg_binomial_2_log_glm_lpmf", suffix, lub_mem) + , [y; x; alpha; beta; sigma] ) + | ( "neg_binomial_2_lpmf" + , [ y + ; { pattern= + FunApp + ( StanLib ("exp", FnPlain, mem1) + , [ { pattern= + FunApp + ( StanLib ("Plus__", FnPlain, mem2) + , [ { pattern= + FunApp + ( StanLib + ("Times__", FnPlain, mem3) + , [x; beta] ) + ; _ }; alpha ] ) + ; _ } ] ) + ; _ }; sigma ] ) + when Expr.Typed.type_of x = UMatrix -> + let lub_mem = lub_mem_pat [mem1; mem2; mem3] in + FunApp + ( StanLib + ("neg_binomial_2_log_glm_lpmf", suffix, lub_mem) + , [y; x; alpha; beta; sigma] ) + | ( "neg_binomial_2_lpmf" + , [ y + ; { pattern= + FunApp + ( StanLib ("exp", FnPlain, mem1) + , [ { pattern= + FunApp + ( StanLib ("Times__", FnPlain, mem2) + , [x; beta] ) + ; _ } ] ) + ; _ }; sigma ] ) + when Expr.Typed.type_of x = UMatrix -> + let lub_mem = lub_mem_pat [mem1; mem2] in + FunApp + ( StanLib + ("neg_binomial_2_log_glm_lpmf", suffix, lub_mem) + , [y; x; Expr.Helpers.zero; beta; sigma] ) + | ( "neg_binomial_2_log_lpmf" + , [ y + ; { pattern= + FunApp + ( StanLib ("Plus__", FnPlain, mem1) + , [ alpha + ; { pattern= + FunApp + ( StanLib ("Times__", FnPlain, mem2) + , [x; beta] ) + ; _ } ] ) + ; _ }; sigma ] ) + when Expr.Typed.type_of x = UMatrix -> + let lub_mem = lub_mem_pat [mem1; mem2] in + FunApp + ( StanLib + ("neg_binomial_2_log_glm_lpmf", suffix, lub_mem) + , [y; x; alpha; beta; sigma] ) + | ( "neg_binomial_2_log_lpmf" + , [ y + ; { pattern= + FunApp + ( StanLib ("Plus__", FnPlain, mem1) + , [ { pattern= + FunApp + ( StanLib ("Times__", FnPlain, mem2) + , [x; beta] ) + ; _ }; alpha ] ) + ; _ }; sigma ] ) + when Expr.Typed.type_of x = UMatrix -> + let lub_mem = lub_mem_pat [mem1; mem2] in + FunApp + ( StanLib + ("neg_binomial_2_log_glm_lpmf", suffix, lub_mem) + , [y; x; alpha; beta; sigma] ) + | ( "neg_binomial_2_log_lpmf" + , [ y + ; { pattern= + FunApp (StanLib ("Times__", FnPlain, mem), [x; beta]) + ; _ }; sigma ] ) + when Expr.Typed.type_of x = UMatrix -> + let lub_mem = lub_mem_pat [mem] in + FunApp + ( StanLib + ("neg_binomial_2_log_glm_lpmf", suffix, lub_mem) + , [y; x; Expr.Helpers.zero; beta; sigma] ) + | ( "neg_binomial_2_lpmf" + , [ y + ; { pattern= FunApp (StanLib ("exp", FnPlain, mem), [eta]) + ; _ }; phi ] ) -> + let lub_mem = lub_mem_pat [mem] in + FunApp + ( StanLib ("neg_binomial_2_log_lpmf", suffix, lub_mem) + , [y; eta; phi] ) + | ( "neg_binomial_2_rng" + , [ { pattern= FunApp (StanLib ("exp", FnPlain, mem), [eta]) + ; _ }; phi ] ) -> + let lub_mem = lub_mem_pat [mem] in + FunApp + ( StanLib ("neg_binomial_2_log_rng", suffix, lub_mem) + , [eta; phi] ) + | ( "normal_lpdf" + , [ y + ; { pattern= + FunApp + ( StanLib ("Plus__", FnPlain, mem1) + , [ alpha + ; { pattern= + FunApp + ( StanLib ("Times__", FnPlain, mem2) + , [x; beta] ) + ; _ } ] ) + ; _ }; sigma ] ) + when Expr.Typed.type_of x = UMatrix -> + let lub_mem = lub_mem_pat [mem1; mem2] in + FunApp + ( StanLib ("normal_id_glm_lpdf", suffix, lub_mem) + , [y; x; alpha; beta; sigma] ) + | ( "normal_lpdf" + , [ y + ; { pattern= + FunApp + ( StanLib ("Plus__", FnPlain, mem1) + , [ { pattern= + FunApp + ( StanLib ("Times__", FnPlain, mem2) + , [x; beta] ) + ; _ }; alpha ] ) + ; _ }; sigma ] ) + when Expr.Typed.type_of x = UMatrix -> + let lub_mem = lub_mem_pat [mem1; mem2] in + FunApp + ( StanLib ("normal_id_glm_lpdf", suffix, lub_mem) + , [y; x; alpha; beta; sigma] ) + | ( "normal_lpdf" + , [ y + ; { pattern= + FunApp (StanLib ("Times__", FnPlain, mem), [x; beta]) + ; _ }; sigma ] ) + when Expr.Typed.type_of x = UMatrix -> + let lub_mem = lub_mem_pat [mem] in + FunApp + ( StanLib ("normal_id_glm_lpdf", suffix, lub_mem) + , [y; x; Expr.Helpers.zero; beta; sigma] ) + | ( "poisson_lpmf" + , [ y + ; { pattern= + FunApp + ( StanLib ("exp", FnPlain, mem1) + , [ { pattern= + FunApp + ( StanLib ("Plus__", FnPlain, mem2) + , [ alpha + ; { pattern= + FunApp + ( StanLib + ("Times__", FnPlain, mem3) + , [x; beta] ) + ; _ } ] ) + ; _ } ] ) + ; _ } ] ) + when Expr.Typed.type_of x = UMatrix -> + let lub_mem = lub_mem_pat [mem1; mem2; mem3] in + FunApp + ( StanLib ("poisson_log_glm_lpmf", suffix, lub_mem) + , [y; x; alpha; beta] ) + | ( "poisson_lpmf" + , [ y + ; { pattern= + FunApp + ( StanLib ("exp", FnPlain, mem1) + , [ { pattern= + FunApp + ( StanLib ("Plus__", FnPlain, mem2) + , [ { pattern= + FunApp + ( StanLib + ("Times__", FnPlain, mem3) + , [x; beta] ) + ; _ }; alpha ] ) + ; _ } ] ) + ; _ } ] ) + when Expr.Typed.type_of x = UMatrix -> + let lub_mem = lub_mem_pat [mem1; mem2; mem3] in + FunApp + ( StanLib ("poisson_log_glm_lpmf", suffix, lub_mem) + , [y; x; alpha; beta] ) + | ( "poisson_lpmf" + , [ y + ; { pattern= + FunApp + ( StanLib ("exp", FnPlain, mem1) + , [ { pattern= + FunApp + ( StanLib ("Times__", FnPlain, mem2) + , [x; beta] ) + ; _ } ] ) + ; _ } ] ) + when Expr.Typed.type_of x = UMatrix -> + let lub_mem = lub_mem_pat [mem1; mem2] in + FunApp + ( StanLib ("poisson_log_glm_lpmf", suffix, lub_mem) + , [y; x; Expr.Helpers.zero; beta] ) + | ( "poisson_log_lpmf" + , [ y + ; { pattern= + FunApp + ( StanLib ("Plus__", FnPlain, mem1) + , [ alpha + ; { pattern= + FunApp + ( StanLib ("Times__", FnPlain, mem2) + , [x; beta] ) + ; _ } ] ) + ; _ } ] ) + when Expr.Typed.type_of x = UMatrix -> + let lub_mem = lub_mem_pat [mem1; mem2] in + FunApp + ( StanLib ("poisson_log_glm_lpmf", suffix, lub_mem) + , [y; x; alpha; beta] ) + | ( "poisson_log_lpmf" + , [ y + ; { pattern= + FunApp + ( StanLib ("Plus__", FnPlain, mem1) + , [ { pattern= + FunApp + ( StanLib ("Times__", FnPlain, mem2) + , [x; beta] ) + ; _ }; alpha ] ) + ; _ } ] ) + when Expr.Typed.type_of x = UMatrix -> + let lub_mem = lub_mem_pat [mem1; mem2] in + FunApp + ( StanLib ("poisson_log_glm_lpmf", suffix, lub_mem) + , [y; x; alpha; beta] ) + | ( "poisson_log_lpmf" + , [ y + ; { pattern= + FunApp (StanLib ("Times__", FnPlain, mem), [x; beta]) + ; _ } ] ) + when Expr.Typed.type_of x = UMatrix -> + let lub_mem = lub_mem_pat [mem] in + FunApp + ( StanLib ("poisson_log_glm_lpmf", suffix, lub_mem) + , [y; x; Expr.Helpers.zero; beta] ) + | ( "poisson_lpmf" + , [ y + ; { pattern= FunApp (StanLib ("exp", FnPlain, mem), [eta]) + ; _ } ] ) -> + let lub_mem = lub_mem_pat [mem] in + FunApp + (StanLib ("poisson_log_lpmf", suffix, lub_mem), [y; eta]) + | ( "poisson_rng" + , [ { pattern= FunApp (StanLib ("exp", FnPlain, mem), [eta]) + ; _ } ] ) -> + let lub_mem = lub_mem_pat [mem] in + FunApp + (StanLib ("poisson_log_rng", suffix, lub_mem), [eta]) + | "pow", [y; x] when is_int 2 y -> + FunApp (StanLib ("exp2", suffix, mem_type), [x]) + | "rows_dot_product", [x; y] when Expr.Typed.equal x y -> + FunApp (StanLib ("rows_dot_self", suffix, mem_type), [x]) + | "pow", [x; {pattern= Lit (Int, "2"); _}] -> + FunApp (StanLib ("square", suffix, mem_type), [x]) + | "pow", [x; {pattern= Lit (Real, "0.5"); _}] -> + FunApp (StanLib ("sqrt", suffix, mem_type), [x]) + | ( "pow" + , [ x + ; { pattern= + FunApp (StanLib ("Divide__", FnPlain, mem), [y; z]) + ; _ } ] ) + when is_int 1 y && is_int 2 z -> + let lub_mem = lub_mem_pat [mem] in + FunApp (StanLib ("sqrt", suffix, lub_mem), [x]) + (* This is wrong; if both are type UInt the exponent is rounds down to zero. *) + | ( "square" + , [{pattern= FunApp (StanLib ("sd", FnPlain, mem), [x]); _}] + ) -> + let lub_mem = lub_mem_pat [mem] in + FunApp (StanLib ("variance", suffix, lub_mem), [x]) + | "sqrt", [x] when is_int 2 x -> + FunApp (StanLib ("sqrt2", suffix, mem_type), []) + | ( "sum" + , [ { pattern= + FunApp + ( StanLib ("square", FnPlain, mem1) + , [ { pattern= + FunApp + ( StanLib ("Minus__", FnPlain, mem2) + , [x; y] ) + ; _ } ] ) + ; _ } ] ) -> + let lub_mem = lub_mem_pat [mem1; mem2] in + FunApp + (StanLib ("squared_distance", suffix, lub_mem), [x; y]) + | ( "sum" + , [ { pattern= FunApp (StanLib ("diagonal", FnPlain, mem), l) + ; _ } ] ) -> + let lub_mem = lub_mem_pat [mem] in + FunApp (StanLib ("trace", suffix, lub_mem), l) + | ( "trace" + , [ { pattern= + FunApp + ( StanLib ("Times__", FnPlain, mem1) + , [ { pattern= + FunApp + ( StanLib ("Times__", FnPlain, mem2) + , [ { pattern= + FunApp + ( StanLib + ("Times__", FnPlain, mem3) + , [ d + ; { pattern= + FunApp + ( StanLib + ( "transpose" + , FnPlain + , mem4 ) + , [b] ) + ; _ } ] ) + ; _ }; a ] ) + ; _ }; c ] ) + ; _ } ] ) + when Expr.Typed.equal b c -> + let lub_mem = lub_mem_pat [mem1; mem2; mem3; mem4] in + FunApp + ( StanLib ("trace_gen_quad_form", suffix, lub_mem) + , [d; a; b] ) + | ( "trace" + , [ { pattern= + FunApp (StanLib ("quad_form", FnPlain, mem), [a; b]) + ; _ } ] ) -> + let lub_mem = lub_mem_pat [mem] in + FunApp + (StanLib ("trace_quad_form", suffix, lub_mem), [a; b]) + | ( ("Plus__" | "add") + , [ ({pattern= Lit (Imaginary, i); _} as im) + ; ({pattern= Lit ((Real | Int), _); _} as r) ] ) + |( ("Plus__" | "add") + , [ ({pattern= Lit ((Real | Int), _); _} as r) + ; ({pattern= Lit (Imaginary, i); _} as im) ] ) + |( ("Plus__" | "add") + , [ ({pattern= Lit (Imaginary, i); _} as im) + ; { pattern= + Promotion + ( ({pattern= Lit ((Real | Int), _); _} as r) + , UComplex + , _ ) + ; _ } ] ) + |( ("Plus__" | "add") + , [ { pattern= + Promotion + ( ({pattern= Lit ((Real | Int), _); _} as r) + , UComplex + , _ ) + ; _ }; ({pattern= Lit (Imaginary, i); _} as im) ] ) -> + let im_part = + Expr.Fixed. + { pattern= Lit (Real, i) + ; meta= {im.meta with type_= UReal} } in + FunApp + (StanLib ("to_complex", suffix, mem_type), [r; im_part]) + | ( "Minus__" + , [ x + ; {pattern= FunApp (StanLib ("erf", FnPlain, mem), l); _} + ] ) + when is_int 1 x -> + let lub_mem = lub_mem_pat [mem] in + FunApp (StanLib ("erfc", suffix, lub_mem), l) + | ( "Minus__" + , [ x + ; {pattern= FunApp (StanLib ("erfc", FnPlain, mem), l); _} + ] ) + when is_int 1 x -> + let lub_mem = lub_mem_pat [mem] in + FunApp (StanLib ("erf", suffix, lub_mem), l) + | ( "Minus__" + , [ {pattern= FunApp (StanLib ("exp", FnPlain, mem), l'); _} + ; x ] ) + when is_int 1 x && not preserve_stability -> + let lub_mem = lub_mem_pat [mem] in + FunApp (StanLib ("expm1", suffix, lub_mem), l') + | ( "Plus__" + , [ { pattern= + FunApp (StanLib ("Times__", FnPlain, mem), [x; y]) + ; _ }; z ] ) + when (not preserve_stability) + && not + ( UnsizedType.is_eigen_type x.meta.type_ + && UnsizedType.is_eigen_type y.meta.type_ ) -> + let lub_mem = lub_mem_pat [mem] in + FunApp (StanLib ("fma", suffix, lub_mem), [x; y; z]) + | ( "Plus__" + , [ z + ; { pattern= + FunApp (StanLib ("Times__", FnPlain, mem), [x; y]) + ; _ } ] ) + when (not preserve_stability) + && not + ( UnsizedType.is_eigen_type x.meta.type_ + && UnsizedType.is_eigen_type y.meta.type_ ) -> + let lub_mem = lub_mem_pat [mem] in + FunApp (StanLib ("fma", suffix, lub_mem), [x; y; z]) + | ( "Plus__" + , [ { pattern= + FunApp + ( StanLib + (("elt_multiply" | "EltTimes__"), FnPlain, mem) + , [x; y] ) + ; _ }; z ] ) + when not preserve_stability -> + let lub_mem = lub_mem_pat [mem] in + FunApp (StanLib ("fma", suffix, lub_mem), [x; y; z]) + | ( "Plus__" + , [ z + ; { pattern= + FunApp + ( StanLib + (("elt_multiply" | "EltTimes__"), FnPlain, mem) + , [x; y] ) + ; _ } ] ) + when not preserve_stability -> + let lub_mem = lub_mem_pat [mem] in + FunApp (StanLib ("fma", suffix, lub_mem), [x; y; z]) + | ( "Minus__" + , [ x + ; { pattern= FunApp (StanLib ("gamma_p", FnPlain, mem), l) + ; _ } ] ) + when is_int 1 x -> + let lub_mem = lub_mem_pat [mem] in + FunApp (StanLib ("gamma_q", suffix, lub_mem), l) + | ( "Minus__" + , [ x + ; { pattern= FunApp (StanLib ("gamma_q", FnPlain, mem), l) + ; _ } ] ) + when is_int 1 x -> + let lub_mem = lub_mem_pat [mem] in + FunApp (StanLib ("gamma_p", suffix, lub_mem), l) + | ( "Times__" + , [ { pattern= + FunApp + ( StanLib ("matrix_exp", FnPlain, mem1) + , [ { pattern= + FunApp + ( StanLib ("Times__", FnPlain, mem2) + , [t; a] ) + ; _ } ] ) + ; _ }; b ] ) + when Expr.Typed.type_of t = UInt + || Expr.Typed.type_of t = UReal -> + let lub_mem = lub_mem_pat [mem1; mem2] in + FunApp + ( StanLib ("scale_matrix_exp_multiply", suffix, lub_mem) + , [t; a; b] ) + | ( "Times__" + , [ { pattern= + FunApp + ( StanLib ("matrix_exp", FnPlain, mem1) + , [ { pattern= + FunApp + ( StanLib ("Times__", FnPlain, mem2) + , [a; t] ) + ; _ } ] ) + ; _ }; b ] ) + when Expr.Typed.type_of t = UInt + || Expr.Typed.type_of t = UReal -> + let lub_mem = lub_mem_pat [mem1; mem2] in + FunApp + ( StanLib ("scale_matrix_exp_multiply", suffix, lub_mem) + , [t; a; b] ) + | ( "Times__" + , [ { pattern= + FunApp (StanLib ("matrix_exp", FnPlain, mem), [a]) + ; _ }; b ] ) -> + let lub_mem = lub_mem_pat [mem] in + FunApp + ( StanLib ("matrix_exp_multiply", suffix, lub_mem) + , [a; b] ) + | ( "Times__" + , [ x + ; {pattern= FunApp (StanLib ("log", FnPlain, mem), [y]); _} + ] ) + |( "Times__" + , [ {pattern= FunApp (StanLib ("log", FnPlain, mem), [y]); _} + ; x ] ) + when not preserve_stability -> + let lub_mem = lub_mem_pat [mem] in + FunApp (StanLib ("lmultiply", suffix, lub_mem), [x; y]) + | ( "Times__" + , [ { pattern= + FunApp (StanLib ("diag_matrix", FnPlain, mem1), [v]) + ; _ } + ; { pattern= + FunApp + ( StanLib ("diag_post_multiply", FnPlain, mem2) + , [a; w] ) + ; _ } ] ) + when Expr.Typed.equal v w -> + let lub_mem = lub_mem_pat [mem1; mem2] in + FunApp + (StanLib ("quad_form_diag", suffix, lub_mem), [a; v]) + | ( "Times__" + , [ { pattern= + FunApp + ( StanLib ("diag_pre_multiply", FnPlain, mem1) + , [v; a] ) + ; _ } + ; { pattern= + FunApp (StanLib ("diag_matrix", FnPlain, mem2), [w]) + ; _ } ] ) + when Expr.Typed.equal v w -> + let lub_mem = lub_mem_pat [mem1; mem2] in + FunApp + (StanLib ("quad_form_diag", suffix, lub_mem), [a; v]) + | ( "Times__" + , [ { pattern= + FunApp (StanLib ("transpose", FnPlain, mem1), [b]) + ; _ } + ; { pattern= + FunApp (StanLib ("Times__", FnPlain, mem2), [a; c]) + ; _ } ] ) + when Expr.Typed.equal b c -> + let lub_mem = lub_mem_pat [mem1; mem2] in + FunApp (StanLib ("quad_form", suffix, lub_mem), [a; b]) + | ( "Times__" + , [ { pattern= + FunApp + ( StanLib ("Times__", FnPlain, mem1) + , [ { pattern= + FunApp + ( StanLib ("transpose", FnPlain, mem2) + , [b] ) + ; _ }; a ] ) + ; _ }; c ] ) + when Expr.Typed.equal b c -> + let lub_mem = lub_mem_pat [mem1; mem2] in + FunApp (StanLib ("quad_form", suffix, lub_mem), [a; b]) + | ( "Times__" + , [ e1' + ; { pattern= + FunApp (StanLib ("diag_matrix", FnPlain, mem), [v]) + ; _ } ] ) -> + let lub_mem = lub_mem_pat [mem] in + FunApp + ( StanLib ("diag_post_multiply", suffix, lub_mem) + , [e1'; v] ) + | ( "Times__" + , [ { pattern= + FunApp (StanLib ("diag_matrix", FnPlain, mem), [v]) + ; _ }; e2' ] ) -> + let lub_mem = lub_mem_pat [mem] in + FunApp + ( StanLib ("diag_pre_multiply", suffix, lub_mem) + , [v; e2'] ) + (* Constant folding for operators *) + | op, [{pattern= Lit (Int, i); _}] -> ( + match op with + | "PPlus__" | "PMinus__" | "PNot__" -> + apply_prefix_operator_int op (Int.of_string i) + | _ -> FunApp (kind, l) ) + | op, [{pattern= Lit (Real, r); _}] -> ( + match op with + | "PPlus__" | "PMinus__" -> + apply_prefix_operator_real op (Float.of_string r) + | _ -> FunApp (kind, l) ) + | ( ("Divide__" | "IntDivide__") + , [{meta= {type_= UInt; _}; _}; {pattern= Lit (Int, i2); _}] + ) + when Int.of_string i2 = 0 -> + raise (Rejected (e.meta.loc, "Integer division by zero")) + | ( op + , [{pattern= Lit (Int, i1); _}; {pattern= Lit (Int, i2); _}] + ) -> ( + match op with + | "Plus__" | "Minus__" | "Times__" | "Divide__" + |"IntDivide__" | "Modulo__" | "Or__" | "And__" + |"Equals__" | "NEquals__" | "Less__" | "Leq__" + |"Greater__" | "Geq__" -> + apply_operator_int op (Int.of_string i1) + (Int.of_string i2) + | _ -> FunApp (kind, l) ) + | ( op + , [ {pattern= Lit (Real, i1); _} + ; {pattern= Lit (Real, i2); _} ] ) + |( op + , [{pattern= Lit (Int, i1); _}; {pattern= Lit (Real, i2); _}] + ) + |( op + , [{pattern= Lit (Real, i1); _}; {pattern= Lit (Int, i2); _}] + ) -> ( + match op with + | "Plus__" | "Minus__" | "Times__" | "Divide__" -> + apply_arithmetic_operator_real op (Float.of_string i1) + (Float.of_string i2) + | "Or__" | "And__" | "Equals__" | "NEquals__" | "Less__" + |"Leq__" | "Greater__" | "Geq__" -> + apply_logical_operator_real op (Float.of_string i1) + (Float.of_string i2) + | _ -> FunApp (kind, l) ) + | _ -> FunApp (kind, l) ) ) + | TernaryIf (e1, e2, e3) -> ( + match + ( eval_expr ~preserve_stability e1 + , eval_expr ~preserve_stability e2 + , eval_expr ~preserve_stability e3 ) + with + | x, _, e3' when is_int 0 x -> e3'.pattern + | {pattern= Lit (Int, _); _}, e2', _ -> e2'.pattern + | e1', e2', e3' -> TernaryIf (e1', e2', e3') ) + | EAnd (e1, e2) -> ( + match + (eval_expr ~preserve_stability e1, eval_expr ~preserve_stability e2) + with + | {pattern= Lit (Int, s1); _}, {pattern= Lit (Int, s2); _} -> + let i1, i2 = (Int.of_string s1, Int.of_string s2) in + Lit (Int, Int.to_string (Bool.to_int (i1 <> 0 && i2 <> 0))) + | {pattern= Lit (_, s1); _}, {pattern= Lit (_, s2); _} -> + let r1, r2 = (Float.of_string s1, Float.of_string s2) in + Lit (Int, Int.to_string (Bool.to_int (r1 <> 0. && r2 <> 0.))) + | e1', e2' -> EAnd (e1', e2') ) + | EOr (e1, e2) -> ( + match + (eval_expr ~preserve_stability e1, eval_expr ~preserve_stability e2) + with + | {pattern= Lit (Int, s1); _}, {pattern= Lit (Int, s2); _} -> + let i1, i2 = (Int.of_string s1, Int.of_string s2) in + Lit (Int, Int.to_string (Bool.to_int (i1 <> 0 || i2 <> 0))) + | {pattern= Lit (_, s1); _}, {pattern= Lit (_, s2); _} -> + let r1, r2 = (Float.of_string s1, Float.of_string s2) in + Lit (Int, Int.to_string (Bool.to_int (r1 <> 0. || r2 <> 0.))) + | e1', e2' -> EOr (e1', e2') ) + | Indexed (e, l) -> + (* TODO: do something clever with array and matrix expressions here? + Note that we could also constant fold array sizes if we keep those around on declarations. *) + Indexed (eval_expr e, List.map ~f:(Index.map eval_expr) l) ) } + + let rec simplify_index_expr pattern = + Expr.Fixed.( + match pattern with + | Pattern.Indexed + ( { pattern= + Indexed (obj, inner_indices) + (* , Single ({emeta= {type_= UArray UInt; _} as emeta; _} as multi) + * :: inner_tl ) *) + ; meta } + , ( Single ({meta= Expr.Typed.Meta.{type_= UInt; _}; _} as single_e) + as single ) + :: outer_tl ) + when List.exists ~f:is_multi_index inner_indices -> ( + match List.split_while ~f:(Fn.non is_multi_index) inner_indices with + | inner_singles, MultiIndex first_multi :: inner_tl -> + (* foo [arr1, ..., arrN] [i1, ..., iN] -> + foo [arr1[i1]] [arr[i2]] ... [arrN[iN]] *) + simplify_index_expr + (Indexed + ( { pattern= + Indexed + ( obj + , inner_singles + @ [ Index.Single + { pattern= Indexed (first_multi, [single]) + ; meta= {meta with type_= UInt} } ] + @ inner_tl ) + ; meta } + , outer_tl ) ) + | inner_singles, All :: inner_tl -> + (* v[:x][i] -> v[i] *) + (* v[:][i] -> v[i] *) + (* XXX generate check *) + simplify_index_expr + (Indexed + ( { pattern= Indexed (obj, inner_singles @ [single] @ inner_tl) + ; meta } + , outer_tl ) ) + | inner_singles, Between (bot, _) :: inner_tl + |inner_singles, Upfrom bot :: inner_tl -> + (* v[x:y][z] -> v[x+z-1] *) + (* XXX generate check *) + simplify_index_expr + (Indexed + ( { pattern= + Indexed + ( obj + , inner_singles + @ [ Index.Single + Expr.Helpers.( + binop (binop bot Plus single_e) Minus + loop_bottom) ] + @ inner_tl ) + ; meta } + , outer_tl ) ) + | inner_singles, (([] | Single _ :: _) as multis) -> + Common.FatalError.fatal_error_msg + [%message + " There must be a multi-index." + (inner_singles : Expr.Typed.t Index.t list) + (multis : Expr.Typed.t Index.t list)] ) + | e -> e) + + let remove_trailing_alls_expr = function + | Expr.Fixed.Pattern.Indexed (obj, indices) -> + (* a[2][:] -> a[2] *) + let rec remove_trailing_alls indices = + match List.rev indices with + | Index.All :: tl -> remove_trailing_alls (List.rev tl) + | _ -> indices in + Expr.Fixed.Pattern.Indexed (obj, remove_trailing_alls indices) + | e -> e + + let rec simplify_indices_expr expr = + Expr.Fixed.( + let pattern = + expr.pattern |> remove_trailing_alls_expr |> simplify_index_expr + |> Expr.Fixed.Pattern.map simplify_indices_expr in + {expr with pattern}) + + let try_eval_expr expr = try eval_expr expr with Rejected _ -> expr + + let rec eval_stmt s = + try + Stmt.Fixed. + { s with + pattern= + Pattern.map + (Fn.compose eval_expr simplify_indices_expr) + eval_stmt s.pattern } + with Rejected (loc, m) -> + { Stmt.Fixed.pattern= + NRFunApp (CompilerInternal FnReject, [Expr.Helpers.str m]) + ; meta= loc } + + let eval_prog = Program.map try_eval_expr eval_stmt +end diff --git a/src/analysis_and_optimization/Partial_evaluation.mli b/src/analysis_and_optimization/Partial_evaluation.mli new file mode 100644 index 0000000000..a12d5db74a --- /dev/null +++ b/src/analysis_and_optimization/Partial_evaluation.mli @@ -0,0 +1,8 @@ +open Middle + +module type PartialEvaluator = sig + val try_eval_expr : Expr.Typed.t -> Expr.Typed.t + val eval_prog : Program.Typed.t -> Program.Typed.t +end + +module Make (StdLib : Frontend.Std_library_utils.Library) : PartialEvaluator diff --git a/src/analysis_and_optimization/Partial_evaluator.ml b/src/analysis_and_optimization/Partial_evaluator.ml deleted file mode 100644 index e63e77ee98..0000000000 --- a/src/analysis_and_optimization/Partial_evaluator.ml +++ /dev/null @@ -1,1152 +0,0 @@ -(* A partial evaluator for use in static analysis and optimization *) - -open Core_kernel -open Core_kernel.Poly -open Middle - -exception Rejected of Location_span.t * string - -let rec is_int query Expr.Fixed.{pattern; _} = - match pattern with - | Lit (Int, i) | Lit (Real, i) -> float_of_string i = float_of_int query - | Promotion (e, _, _) -> is_int query e - | _ -> false - -let apply_prefix_operator_int (op : string) i = - Expr.Fixed.Pattern.Lit - ( Int - , Int.to_string - ( match op with - | "PPlus__" -> i - | "PMinus__" -> -i - | "PNot__" -> if i = 0 then 1 else 0 - | s -> - Common.FatalError.fatal_error_msg - [%message "Not an int prefix operator: " s] ) ) - -let apply_prefix_operator_real (op : string) i = - Expr.Fixed.Pattern.Lit - ( Real - , Float.to_string - ( match op with - | "PPlus__" -> i - | "PMinus__" -> -.i - | s -> - Common.FatalError.fatal_error_msg - [%message "Not a real prefix operator: " s] ) ) - -let apply_operator_int (op : string) i1 i2 = - Expr.Fixed.Pattern.Lit - ( Int - , Int.to_string - ( match op with - | "Plus__" -> i1 + i2 - | "Minus__" -> i1 - i2 - | "Times__" -> i1 * i2 - | "Divide__" | "IntDivide__" -> i1 / i2 - | "Modulo__" -> i1 % i2 - | "Equals__" -> Bool.to_int (i1 = i2) - | "NEquals__" -> Bool.to_int (i1 <> i2) - | "Less__" -> Bool.to_int (i1 < i2) - | "Leq__" -> Bool.to_int (i1 <= i2) - | "Greater__" -> Bool.to_int (i1 > i2) - | "Geq__" -> Bool.to_int (i1 >= i2) - | s -> - Common.FatalError.fatal_error_msg - [%message "Not an int operator: " s] ) ) - -let apply_arithmetic_operator_real (op : string) r1 r2 = - Expr.Fixed.Pattern.Lit - ( Real - , Float.to_string - ( match op with - | "Plus__" -> r1 +. r2 - | "Minus__" -> r1 -. r2 - | "Times__" -> r1 *. r2 - | "Divide__" -> r1 /. r2 - | s -> - Common.FatalError.fatal_error_msg - [%message "Not a real operator: " s] ) ) - -let apply_logical_operator_real (op : string) r1 r2 = - Expr.Fixed.Pattern.Lit - ( Int - , Int.to_string - ( match op with - | "Equals__" -> Bool.to_int (r1 = r2) - | "NEquals__" -> Bool.to_int (r1 <> r2) - | "Less__" -> Bool.to_int (r1 < r2) - | "Leq__" -> Bool.to_int (r1 <= r2) - | "Greater__" -> Bool.to_int (r1 > r2) - | "Geq__" -> Bool.to_int (r1 >= r2) - | s -> - Common.FatalError.fatal_error_msg - [%message "Not a logical operator: " s] ) ) - -let is_multi_index = function - | Index.MultiIndex _ | Upfrom _ | Between _ | All -> true - | Single _ -> false - -let rec eval_expr ?(preserve_stability = false) (e : Expr.Typed.t) = - { e with - pattern= - ( match e.pattern with - | Var _ | Lit (_, _) -> e.pattern - | Promotion (expr, ut, ad) -> - Promotion (eval_expr ~preserve_stability expr, ut, ad) - | FunApp (kind, l) -> ( - let l = List.map ~f:(eval_expr ~preserve_stability) l in - match kind with - | UserDefined _ | CompilerInternal _ -> FunApp (kind, l) - | StanLib (f, suffix, mem_type) -> - let get_fun_or_op_rt_opt name l' = - let module TC = Frontend.Typechecking.Make ((* TODO *) - Frontend.Std_library_utils.NullLibrary) in - let argument_types = - List.map ~f:(fun x -> Expr.Typed.(adlevel_of x, type_of x)) l' - in - Operator.of_string_opt name - |> Option.value_map - ~f:(fun op -> - TC.operator_return_type op argument_types - |> Option.map ~f:fst ) - ~default: - (TC.library_function_return_type name argument_types) - in - let try_partially_evaluate_stanlib e = - Expr.Fixed.Pattern.( - match e with - | FunApp (StanLib (f', suffix', mem_type), l') -> ( - match get_fun_or_op_rt_opt f' l' with - | Some _ -> FunApp (StanLib (f', suffix', mem_type), l') - | None -> FunApp (StanLib (f, suffix, mem_type), l) ) - | e -> e) in - let lub_mem_pat lst = - Common.Helpers.lub_mem_pat (List.cons mem_type lst) in - try_partially_evaluate_stanlib - ( match (f, l) with - (* TODO: deal with tilde statements and unnormalized distributions properly here *) - | ( "bernoulli_lpmf" - , [ y - ; { pattern= - FunApp - ( StanLib ("inv_logit", FnPlain, mem1) - , [ { pattern= - FunApp - ( StanLib ("Plus__", FnPlain, mem2) - , [ alpha - ; { pattern= - FunApp - ( StanLib - ("Times__", FnPlain, mem3) - , [x; beta] ) - ; _ } ] ) - ; _ } ] ) - ; _ } ] ) - when Expr.Typed.type_of x = UMatrix -> - let lub_mem = lub_mem_pat [mem1; mem2; mem3] in - FunApp - ( StanLib ("bernoulli_logit_glm_lpmf", suffix, lub_mem) - , [y; x; alpha; beta] ) - | ( "bernoulli_lpmf" - , [ y - ; { pattern= - FunApp - ( StanLib ("inv_logit", FnPlain, mem1) - , [ { pattern= - FunApp - ( StanLib ("Plus__", FnPlain, mem2) - , [ { pattern= - FunApp - ( StanLib - ("Times__", FnPlain, mem3) - , [x; beta] ) - ; _ }; alpha ] ) - ; _ } ] ) - ; _ } ] ) - when Expr.Typed.type_of x = UMatrix -> - let lub_mem = lub_mem_pat [mem1; mem2; mem3] in - FunApp - ( StanLib ("bernoulli_logit_glm_lpmf", suffix, lub_mem) - , [y; x; alpha; beta] ) - | ( "bernoulli_lpmf" - , [ y - ; { pattern= - FunApp - ( StanLib ("inv_logit", FnPlain, mem1) - , [ { pattern= - FunApp - ( StanLib ("Times__", FnPlain, mem2) - , [x; beta] ) - ; _ } ] ) - ; _ } ] ) - when Expr.Typed.type_of x = UMatrix -> - let lub_mem = lub_mem_pat [mem1; mem2] in - FunApp - ( StanLib ("bernoulli_logit_glm_lpmf", suffix, lub_mem) - , [y; x; Expr.Helpers.zero; beta] ) - | ( "bernoulli_logit_lpmf" - , [ y - ; { pattern= - FunApp - ( StanLib ("Plus__", FnPlain, mem1) - , [ alpha - ; { pattern= - FunApp - ( StanLib ("Times__", FnPlain, mem2) - , [x; beta] ) - ; _ } ] ) - ; _ } ] ) - when Expr.Typed.type_of x = UMatrix -> - let lub_mem = lub_mem_pat [mem1; mem2] in - FunApp - ( StanLib ("bernoulli_logit_glm_lpmf", suffix, lub_mem) - , [y; x; alpha; beta] ) - | ( "bernoulli_logit_lpmf" - , [ y - ; { pattern= - FunApp - ( StanLib ("Plus__", FnPlain, mem1) - , [ { pattern= - FunApp - ( StanLib ("Times__", FnPlain, mem2) - , [x; beta] ) - ; _ }; alpha ] ) - ; _ } ] ) - when Expr.Typed.type_of x = UMatrix -> - let lub_mem = lub_mem_pat [mem1; mem2] in - FunApp - ( StanLib ("bernoulli_logit_glm_lpmf", suffix, lub_mem) - , [y; x; alpha; beta] ) - | ( "bernoulli_logit_lpmf" - , [ y - ; { pattern= - FunApp (StanLib ("Times__", FnPlain, mem), [x; beta]) - ; _ } ] ) - when Expr.Typed.type_of x = UMatrix -> - FunApp - ( StanLib - ("bernoulli_logit_glm_lpmf", suffix, lub_mem_pat [mem]) - , [y; x; Expr.Helpers.zero; beta] ) - | ( "bernoulli_lpmf" - , [ y - ; { pattern= - FunApp (StanLib ("inv_logit", FnPlain, mem), [alpha]) - ; _ } ] ) -> - FunApp - ( StanLib - ("bernoulli_logit_lpmf", suffix, lub_mem_pat [mem]) - , [y; alpha] ) - | ( "bernoulli_rng" - , [ { pattern= - FunApp (StanLib ("inv_logit", FnPlain, mem), [alpha]) - ; _ } ] ) -> - FunApp - ( StanLib - ("bernoulli_logit_rng", suffix, lub_mem_pat [mem]) - , [alpha] ) - | ( "binomial_lpmf" - , [ y; n - ; { pattern= - FunApp (StanLib ("inv_logit", FnPlain, mem), [alpha]) - ; _ } ] ) -> - FunApp - ( StanLib - ("binomial_logit_lpmf", suffix, lub_mem_pat [mem]) - , [y; n; alpha] ) - | ( "categorical_lpmf" - , [ y - ; { pattern= - FunApp (StanLib ("inv_logit", FnPlain, mem), [alpha]) - ; _ } ] ) -> - FunApp - ( StanLib - ("categorical_logit_lpmf", suffix, lub_mem_pat [mem]) - , [y; alpha] ) - | ( "categorical_rng" - , [ { pattern= - FunApp (StanLib ("inv_logit", FnPlain, mem), [alpha]) - ; _ } ] ) -> - FunApp - ( StanLib - ("categorical_logit_rng", suffix, lub_mem_pat [mem]) - , [alpha] ) - | "columns_dot_product", [x; y] when Expr.Typed.equal x y -> - FunApp (StanLib ("columns_dot_self", suffix, mem_type), [x]) - | "dot_product", [x; y] when Expr.Typed.equal x y -> - FunApp (StanLib ("dot_self", suffix, mem_type), [x]) - | ( "inv" - , [{pattern= FunApp (StanLib ("sqrt", FnPlain, mem), l); _}] ) - -> - FunApp (StanLib ("inv_sqrt", suffix, mem), l) - | ( "inv" - , [ { pattern= FunApp (StanLib ("square", FnPlain, mem), [x]) - ; _ } ] ) -> - FunApp - (StanLib ("inv_square", suffix, lub_mem_pat [mem]), [x]) - | ( "log" - , [ { pattern= - FunApp - ( StanLib ("Minus__", FnPlain, mem1) - , [ y - ; { pattern= - FunApp (StanLib ("exp", FnPlain, mem2), [x]) - ; _ } ] ) - ; _ } ] ) - when is_int 1 y && not preserve_stability -> - let lub_mem = lub_mem_pat [mem1; mem2] in - FunApp (StanLib ("log1m_exp", suffix, lub_mem), [x]) - | ( "log" - , [ { pattern= - FunApp - ( StanLib ("Minus__", FnPlain, mem1) - , [ y - ; { pattern= - FunApp - (StanLib ("inv_logit", FnPlain, mem2), [x]) - ; _ } ] ) - ; _ } ] ) - when is_int 1 y && not preserve_stability -> - let lub_mem = lub_mem_pat [mem1; mem2] in - FunApp (StanLib ("log1m_inv_logit", suffix, lub_mem), [x]) - | ( "log" - , [ { pattern= - FunApp (StanLib ("Minus__", FnPlain, mem), [y; x]) - ; _ } ] ) - when is_int 1 y && not preserve_stability -> - FunApp (StanLib ("log1m", suffix, lub_mem_pat [mem]), [x]) - | ( "log" - , [ { pattern= - FunApp - ( StanLib ("Plus__", FnPlain, mem1) - , [ y - ; { pattern= - FunApp (StanLib ("exp", FnPlain, mem2), [x]) - ; _ } ] ) - ; _ } ] ) - when is_int 1 y && not preserve_stability -> - let lub_mem = lub_mem_pat [mem1; mem2] in - FunApp (StanLib ("log1p_exp", suffix, lub_mem), [x]) - | ( "log" - , [ { pattern= - FunApp (StanLib ("Plus__", FnPlain, mem), [y; x]) - ; _ } ] ) - when is_int 1 y && not preserve_stability -> - FunApp (StanLib ("log1p", suffix, lub_mem_pat [mem]), [x]) - | ( "log" - , [ { pattern= - FunApp - ( StanLib ("fabs", FnPlain, mem1) - , [ { pattern= - FunApp - ( StanLib ("determinant", FnPlain, mem2) - , [x] ) - ; _ } ] ) - ; _ } ] ) -> - let lub_mem = lub_mem_pat [mem1; mem2] in - FunApp (StanLib ("log_determinant", suffix, lub_mem), [x]) - | ( "log" - , [ { pattern= - FunApp - ( StanLib ("Minus__", FnPlain, mem1) - , [ { pattern= - FunApp (StanLib ("exp", FnPlain, mem2), [x]) - ; _ } - ; { pattern= - FunApp (StanLib ("exp", FnPlain, mem3), [y]) - ; _ } ] ) - ; _ } ] ) -> - let lub_mem = lub_mem_pat [mem1; mem2; mem3] in - FunApp (StanLib ("log_diff_exp", suffix, lub_mem), [x; y]) - (* TODO: log_mix?*) - | ( "log" - , [ { pattern= - FunApp (StanLib ("falling_factorial", FnPlain, mem), l) - ; _ } ] ) -> - FunApp - ( StanLib - ("log_falling_factorial", suffix, lub_mem_pat [mem]) - , l ) - | ( "log" - , [ { pattern= - FunApp (StanLib ("rising_factorial", FnPlain, mem), l) - ; _ } ] ) -> - FunApp - ( StanLib - ("log_rising_factorial", suffix, lub_mem_pat [mem]) - , l ) - | ( "log" - , [ { pattern= FunApp (StanLib ("inv_logit", FnPlain, mem), l) - ; _ } ] ) -> - FunApp - (StanLib ("log_inv_logit", suffix, lub_mem_pat [mem]), l) - | ( "log" - , [{pattern= FunApp (StanLib ("softmax", FnPlain, mem), l); _}] - ) -> - FunApp - (StanLib ("log_softmax", suffix, lub_mem_pat [mem]), l) - | ( "log" - , [ { pattern= - FunApp - ( StanLib ("sum", FnPlain, mem1) - , [ { pattern= - FunApp (StanLib ("exp", FnPlain, mem2), l) - ; _ } ] ) - ; _ } ] ) -> - let lub_mem = lub_mem_pat [mem1; mem2] in - FunApp (StanLib ("log_sum_exp", suffix, lub_mem), l) - | ( "log" - , [ { pattern= - FunApp - ( StanLib ("Plus__", FnPlain, mem1) - , [ { pattern= - FunApp (StanLib ("exp", FnPlain, mem2), [x]) - ; _ } - ; { pattern= - FunApp (StanLib ("exp", FnPlain, mem3), [y]) - ; _ } ] ) - ; _ } ] ) -> - let lub_mem = lub_mem_pat [mem1; mem2; mem3] in - FunApp (StanLib ("log_sum_exp", suffix, lub_mem), [x; y]) - | ( "multi_normal_lpdf" - , [ y; mu - ; { pattern= - FunApp (StanLib ("inverse", FnPlain, mem), [tau]) - ; _ } ] ) -> - let lub_mem = lub_mem_pat [mem] in - FunApp - ( StanLib ("multi_normal_prec_lpdf", suffix, lub_mem) - , [y; mu; tau] ) - | ( "neg_binomial_2_lpmf" - , [ y - ; { pattern= - FunApp - ( StanLib ("exp", FnPlain, mem1) - , [ { pattern= - FunApp - ( StanLib ("Plus__", FnPlain, mem2) - , [ alpha - ; { pattern= - FunApp - ( StanLib - ("Times__", FnPlain, mem3) - , [x; beta] ) - ; _ } ] ) - ; _ } ] ) - ; _ }; sigma ] ) - when Expr.Typed.type_of x = UMatrix -> - let lub_mem = lub_mem_pat [mem1; mem2; mem3] in - FunApp - ( StanLib ("neg_binomial_2_log_glm_lpmf", suffix, lub_mem) - , [y; x; alpha; beta; sigma] ) - | ( "neg_binomial_2_lpmf" - , [ y - ; { pattern= - FunApp - ( StanLib ("exp", FnPlain, mem1) - , [ { pattern= - FunApp - ( StanLib ("Plus__", FnPlain, mem2) - , [ { pattern= - FunApp - ( StanLib - ("Times__", FnPlain, mem3) - , [x; beta] ) - ; _ }; alpha ] ) - ; _ } ] ) - ; _ }; sigma ] ) - when Expr.Typed.type_of x = UMatrix -> - let lub_mem = lub_mem_pat [mem1; mem2; mem3] in - FunApp - ( StanLib ("neg_binomial_2_log_glm_lpmf", suffix, lub_mem) - , [y; x; alpha; beta; sigma] ) - | ( "neg_binomial_2_lpmf" - , [ y - ; { pattern= - FunApp - ( StanLib ("exp", FnPlain, mem1) - , [ { pattern= - FunApp - ( StanLib ("Times__", FnPlain, mem2) - , [x; beta] ) - ; _ } ] ) - ; _ }; sigma ] ) - when Expr.Typed.type_of x = UMatrix -> - let lub_mem = lub_mem_pat [mem1; mem2] in - FunApp - ( StanLib ("neg_binomial_2_log_glm_lpmf", suffix, lub_mem) - , [y; x; Expr.Helpers.zero; beta; sigma] ) - | ( "neg_binomial_2_log_lpmf" - , [ y - ; { pattern= - FunApp - ( StanLib ("Plus__", FnPlain, mem1) - , [ alpha - ; { pattern= - FunApp - ( StanLib ("Times__", FnPlain, mem2) - , [x; beta] ) - ; _ } ] ) - ; _ }; sigma ] ) - when Expr.Typed.type_of x = UMatrix -> - let lub_mem = lub_mem_pat [mem1; mem2] in - FunApp - ( StanLib ("neg_binomial_2_log_glm_lpmf", suffix, lub_mem) - , [y; x; alpha; beta; sigma] ) - | ( "neg_binomial_2_log_lpmf" - , [ y - ; { pattern= - FunApp - ( StanLib ("Plus__", FnPlain, mem1) - , [ { pattern= - FunApp - ( StanLib ("Times__", FnPlain, mem2) - , [x; beta] ) - ; _ }; alpha ] ) - ; _ }; sigma ] ) - when Expr.Typed.type_of x = UMatrix -> - let lub_mem = lub_mem_pat [mem1; mem2] in - FunApp - ( StanLib ("neg_binomial_2_log_glm_lpmf", suffix, lub_mem) - , [y; x; alpha; beta; sigma] ) - | ( "neg_binomial_2_log_lpmf" - , [ y - ; { pattern= - FunApp (StanLib ("Times__", FnPlain, mem), [x; beta]) - ; _ }; sigma ] ) - when Expr.Typed.type_of x = UMatrix -> - let lub_mem = lub_mem_pat [mem] in - FunApp - ( StanLib ("neg_binomial_2_log_glm_lpmf", suffix, lub_mem) - , [y; x; Expr.Helpers.zero; beta; sigma] ) - | ( "neg_binomial_2_lpmf" - , [ y - ; {pattern= FunApp (StanLib ("exp", FnPlain, mem), [eta]); _} - ; phi ] ) -> - let lub_mem = lub_mem_pat [mem] in - FunApp - ( StanLib ("neg_binomial_2_log_lpmf", suffix, lub_mem) - , [y; eta; phi] ) - | ( "neg_binomial_2_rng" - , [ {pattern= FunApp (StanLib ("exp", FnPlain, mem), [eta]); _} - ; phi ] ) -> - let lub_mem = lub_mem_pat [mem] in - FunApp - ( StanLib ("neg_binomial_2_log_rng", suffix, lub_mem) - , [eta; phi] ) - | ( "normal_lpdf" - , [ y - ; { pattern= - FunApp - ( StanLib ("Plus__", FnPlain, mem1) - , [ alpha - ; { pattern= - FunApp - ( StanLib ("Times__", FnPlain, mem2) - , [x; beta] ) - ; _ } ] ) - ; _ }; sigma ] ) - when Expr.Typed.type_of x = UMatrix -> - let lub_mem = lub_mem_pat [mem1; mem2] in - FunApp - ( StanLib ("normal_id_glm_lpdf", suffix, lub_mem) - , [y; x; alpha; beta; sigma] ) - | ( "normal_lpdf" - , [ y - ; { pattern= - FunApp - ( StanLib ("Plus__", FnPlain, mem1) - , [ { pattern= - FunApp - ( StanLib ("Times__", FnPlain, mem2) - , [x; beta] ) - ; _ }; alpha ] ) - ; _ }; sigma ] ) - when Expr.Typed.type_of x = UMatrix -> - let lub_mem = lub_mem_pat [mem1; mem2] in - FunApp - ( StanLib ("normal_id_glm_lpdf", suffix, lub_mem) - , [y; x; alpha; beta; sigma] ) - | ( "normal_lpdf" - , [ y - ; { pattern= - FunApp (StanLib ("Times__", FnPlain, mem), [x; beta]) - ; _ }; sigma ] ) - when Expr.Typed.type_of x = UMatrix -> - let lub_mem = lub_mem_pat [mem] in - FunApp - ( StanLib ("normal_id_glm_lpdf", suffix, lub_mem) - , [y; x; Expr.Helpers.zero; beta; sigma] ) - | ( "poisson_lpmf" - , [ y - ; { pattern= - FunApp - ( StanLib ("exp", FnPlain, mem1) - , [ { pattern= - FunApp - ( StanLib ("Plus__", FnPlain, mem2) - , [ alpha - ; { pattern= - FunApp - ( StanLib - ("Times__", FnPlain, mem3) - , [x; beta] ) - ; _ } ] ) - ; _ } ] ) - ; _ } ] ) - when Expr.Typed.type_of x = UMatrix -> - let lub_mem = lub_mem_pat [mem1; mem2; mem3] in - FunApp - ( StanLib ("poisson_log_glm_lpmf", suffix, lub_mem) - , [y; x; alpha; beta] ) - | ( "poisson_lpmf" - , [ y - ; { pattern= - FunApp - ( StanLib ("exp", FnPlain, mem1) - , [ { pattern= - FunApp - ( StanLib ("Plus__", FnPlain, mem2) - , [ { pattern= - FunApp - ( StanLib - ("Times__", FnPlain, mem3) - , [x; beta] ) - ; _ }; alpha ] ) - ; _ } ] ) - ; _ } ] ) - when Expr.Typed.type_of x = UMatrix -> - let lub_mem = lub_mem_pat [mem1; mem2; mem3] in - FunApp - ( StanLib ("poisson_log_glm_lpmf", suffix, lub_mem) - , [y; x; alpha; beta] ) - | ( "poisson_lpmf" - , [ y - ; { pattern= - FunApp - ( StanLib ("exp", FnPlain, mem1) - , [ { pattern= - FunApp - ( StanLib ("Times__", FnPlain, mem2) - , [x; beta] ) - ; _ } ] ) - ; _ } ] ) - when Expr.Typed.type_of x = UMatrix -> - let lub_mem = lub_mem_pat [mem1; mem2] in - FunApp - ( StanLib ("poisson_log_glm_lpmf", suffix, lub_mem) - , [y; x; Expr.Helpers.zero; beta] ) - | ( "poisson_log_lpmf" - , [ y - ; { pattern= - FunApp - ( StanLib ("Plus__", FnPlain, mem1) - , [ alpha - ; { pattern= - FunApp - ( StanLib ("Times__", FnPlain, mem2) - , [x; beta] ) - ; _ } ] ) - ; _ } ] ) - when Expr.Typed.type_of x = UMatrix -> - let lub_mem = lub_mem_pat [mem1; mem2] in - FunApp - ( StanLib ("poisson_log_glm_lpmf", suffix, lub_mem) - , [y; x; alpha; beta] ) - | ( "poisson_log_lpmf" - , [ y - ; { pattern= - FunApp - ( StanLib ("Plus__", FnPlain, mem1) - , [ { pattern= - FunApp - ( StanLib ("Times__", FnPlain, mem2) - , [x; beta] ) - ; _ }; alpha ] ) - ; _ } ] ) - when Expr.Typed.type_of x = UMatrix -> - let lub_mem = lub_mem_pat [mem1; mem2] in - FunApp - ( StanLib ("poisson_log_glm_lpmf", suffix, lub_mem) - , [y; x; alpha; beta] ) - | ( "poisson_log_lpmf" - , [ y - ; { pattern= - FunApp (StanLib ("Times__", FnPlain, mem), [x; beta]) - ; _ } ] ) - when Expr.Typed.type_of x = UMatrix -> - let lub_mem = lub_mem_pat [mem] in - FunApp - ( StanLib ("poisson_log_glm_lpmf", suffix, lub_mem) - , [y; x; Expr.Helpers.zero; beta] ) - | ( "poisson_lpmf" - , [ y - ; {pattern= FunApp (StanLib ("exp", FnPlain, mem), [eta]); _} - ] ) -> - let lub_mem = lub_mem_pat [mem] in - FunApp - (StanLib ("poisson_log_lpmf", suffix, lub_mem), [y; eta]) - | ( "poisson_rng" - , [{pattern= FunApp (StanLib ("exp", FnPlain, mem), [eta]); _}] - ) -> - let lub_mem = lub_mem_pat [mem] in - FunApp (StanLib ("poisson_log_rng", suffix, lub_mem), [eta]) - | "pow", [y; x] when is_int 2 y -> - FunApp (StanLib ("exp2", suffix, mem_type), [x]) - | "rows_dot_product", [x; y] when Expr.Typed.equal x y -> - FunApp (StanLib ("rows_dot_self", suffix, mem_type), [x]) - | "pow", [x; {pattern= Lit (Int, "2"); _}] -> - FunApp (StanLib ("square", suffix, mem_type), [x]) - | "pow", [x; {pattern= Lit (Real, "0.5"); _}] -> - FunApp (StanLib ("sqrt", suffix, mem_type), [x]) - | ( "pow" - , [ x - ; { pattern= - FunApp (StanLib ("Divide__", FnPlain, mem), [y; z]) - ; _ } ] ) - when is_int 1 y && is_int 2 z -> - let lub_mem = lub_mem_pat [mem] in - FunApp (StanLib ("sqrt", suffix, lub_mem), [x]) - (* This is wrong; if both are type UInt the exponent is rounds down to zero. *) - | ( "square" - , [{pattern= FunApp (StanLib ("sd", FnPlain, mem), [x]); _}] ) - -> - let lub_mem = lub_mem_pat [mem] in - FunApp (StanLib ("variance", suffix, lub_mem), [x]) - | "sqrt", [x] when is_int 2 x -> - FunApp (StanLib ("sqrt2", suffix, mem_type), []) - | ( "sum" - , [ { pattern= - FunApp - ( StanLib ("square", FnPlain, mem1) - , [ { pattern= - FunApp - ( StanLib ("Minus__", FnPlain, mem2) - , [x; y] ) - ; _ } ] ) - ; _ } ] ) -> - let lub_mem = lub_mem_pat [mem1; mem2] in - FunApp - (StanLib ("squared_distance", suffix, lub_mem), [x; y]) - | ( "sum" - , [ { pattern= FunApp (StanLib ("diagonal", FnPlain, mem), l) - ; _ } ] ) -> - let lub_mem = lub_mem_pat [mem] in - FunApp (StanLib ("trace", suffix, lub_mem), l) - | ( "trace" - , [ { pattern= - FunApp - ( StanLib ("Times__", FnPlain, mem1) - , [ { pattern= - FunApp - ( StanLib ("Times__", FnPlain, mem2) - , [ { pattern= - FunApp - ( StanLib - ("Times__", FnPlain, mem3) - , [ d - ; { pattern= - FunApp - ( StanLib - ( "transpose" - , FnPlain - , mem4 ) - , [b] ) - ; _ } ] ) - ; _ }; a ] ) - ; _ }; c ] ) - ; _ } ] ) - when Expr.Typed.equal b c -> - let lub_mem = lub_mem_pat [mem1; mem2; mem3; mem4] in - FunApp - ( StanLib ("trace_gen_quad_form", suffix, lub_mem) - , [d; a; b] ) - | ( "trace" - , [ { pattern= - FunApp (StanLib ("quad_form", FnPlain, mem), [a; b]) - ; _ } ] ) -> - let lub_mem = lub_mem_pat [mem] in - FunApp (StanLib ("trace_quad_form", suffix, lub_mem), [a; b]) - | ( ("Plus__" | "add") - , [ ({pattern= Lit (Imaginary, i); _} as im) - ; ({pattern= Lit ((Real | Int), _); _} as r) ] ) - |( ("Plus__" | "add") - , [ ({pattern= Lit ((Real | Int), _); _} as r) - ; ({pattern= Lit (Imaginary, i); _} as im) ] ) - |( ("Plus__" | "add") - , [ ({pattern= Lit (Imaginary, i); _} as im) - ; { pattern= - Promotion - ( ({pattern= Lit ((Real | Int), _); _} as r) - , UComplex - , _ ) - ; _ } ] ) - |( ("Plus__" | "add") - , [ { pattern= - Promotion - ( ({pattern= Lit ((Real | Int), _); _} as r) - , UComplex - , _ ) - ; _ }; ({pattern= Lit (Imaginary, i); _} as im) ] ) -> - let im_part = - Expr.Fixed. - { pattern= Lit (Real, i) - ; meta= {im.meta with type_= UReal} } in - FunApp - (StanLib ("to_complex", suffix, mem_type), [r; im_part]) - | ( "Minus__" - , [x; {pattern= FunApp (StanLib ("erf", FnPlain, mem), l); _}] - ) - when is_int 1 x -> - let lub_mem = lub_mem_pat [mem] in - FunApp (StanLib ("erfc", suffix, lub_mem), l) - | ( "Minus__" - , [x; {pattern= FunApp (StanLib ("erfc", FnPlain, mem), l); _}] - ) - when is_int 1 x -> - let lub_mem = lub_mem_pat [mem] in - FunApp (StanLib ("erf", suffix, lub_mem), l) - | ( "Minus__" - , [{pattern= FunApp (StanLib ("exp", FnPlain, mem), l'); _}; x] - ) - when is_int 1 x && not preserve_stability -> - let lub_mem = lub_mem_pat [mem] in - FunApp (StanLib ("expm1", suffix, lub_mem), l') - | ( "Plus__" - , [ { pattern= - FunApp (StanLib ("Times__", FnPlain, mem), [x; y]) - ; _ }; z ] ) - when (not preserve_stability) - && not - ( UnsizedType.is_eigen_type x.meta.type_ - && UnsizedType.is_eigen_type y.meta.type_ ) -> - let lub_mem = lub_mem_pat [mem] in - FunApp (StanLib ("fma", suffix, lub_mem), [x; y; z]) - | ( "Plus__" - , [ z - ; { pattern= - FunApp (StanLib ("Times__", FnPlain, mem), [x; y]) - ; _ } ] ) - when (not preserve_stability) - && not - ( UnsizedType.is_eigen_type x.meta.type_ - && UnsizedType.is_eigen_type y.meta.type_ ) -> - let lub_mem = lub_mem_pat [mem] in - FunApp (StanLib ("fma", suffix, lub_mem), [x; y; z]) - | ( "Plus__" - , [ { pattern= - FunApp - ( StanLib - (("elt_multiply" | "EltTimes__"), FnPlain, mem) - , [x; y] ) - ; _ }; z ] ) - when not preserve_stability -> - let lub_mem = lub_mem_pat [mem] in - FunApp (StanLib ("fma", suffix, lub_mem), [x; y; z]) - | ( "Plus__" - , [ z - ; { pattern= - FunApp - ( StanLib - (("elt_multiply" | "EltTimes__"), FnPlain, mem) - , [x; y] ) - ; _ } ] ) - when not preserve_stability -> - let lub_mem = lub_mem_pat [mem] in - FunApp (StanLib ("fma", suffix, lub_mem), [x; y; z]) - | ( "Minus__" - , [ x - ; {pattern= FunApp (StanLib ("gamma_p", FnPlain, mem), l); _} - ] ) - when is_int 1 x -> - let lub_mem = lub_mem_pat [mem] in - FunApp (StanLib ("gamma_q", suffix, lub_mem), l) - | ( "Minus__" - , [ x - ; {pattern= FunApp (StanLib ("gamma_q", FnPlain, mem), l); _} - ] ) - when is_int 1 x -> - let lub_mem = lub_mem_pat [mem] in - FunApp (StanLib ("gamma_p", suffix, lub_mem), l) - | ( "Times__" - , [ { pattern= - FunApp - ( StanLib ("matrix_exp", FnPlain, mem1) - , [ { pattern= - FunApp - ( StanLib ("Times__", FnPlain, mem2) - , [t; a] ) - ; _ } ] ) - ; _ }; b ] ) - when Expr.Typed.type_of t = UInt - || Expr.Typed.type_of t = UReal -> - let lub_mem = lub_mem_pat [mem1; mem2] in - FunApp - ( StanLib ("scale_matrix_exp_multiply", suffix, lub_mem) - , [t; a; b] ) - | ( "Times__" - , [ { pattern= - FunApp - ( StanLib ("matrix_exp", FnPlain, mem1) - , [ { pattern= - FunApp - ( StanLib ("Times__", FnPlain, mem2) - , [a; t] ) - ; _ } ] ) - ; _ }; b ] ) - when Expr.Typed.type_of t = UInt - || Expr.Typed.type_of t = UReal -> - let lub_mem = lub_mem_pat [mem1; mem2] in - FunApp - ( StanLib ("scale_matrix_exp_multiply", suffix, lub_mem) - , [t; a; b] ) - | ( "Times__" - , [ { pattern= - FunApp (StanLib ("matrix_exp", FnPlain, mem), [a]) - ; _ }; b ] ) -> - let lub_mem = lub_mem_pat [mem] in - FunApp - (StanLib ("matrix_exp_multiply", suffix, lub_mem), [a; b]) - | ( "Times__" - , [ x - ; {pattern= FunApp (StanLib ("log", FnPlain, mem), [y]); _} - ] ) - |( "Times__" - , [ {pattern= FunApp (StanLib ("log", FnPlain, mem), [y]); _} - ; x ] ) - when not preserve_stability -> - let lub_mem = lub_mem_pat [mem] in - FunApp (StanLib ("lmultiply", suffix, lub_mem), [x; y]) - | ( "Times__" - , [ { pattern= - FunApp (StanLib ("diag_matrix", FnPlain, mem1), [v]) - ; _ } - ; { pattern= - FunApp - ( StanLib ("diag_post_multiply", FnPlain, mem2) - , [a; w] ) - ; _ } ] ) - when Expr.Typed.equal v w -> - let lub_mem = lub_mem_pat [mem1; mem2] in - FunApp (StanLib ("quad_form_diag", suffix, lub_mem), [a; v]) - | ( "Times__" - , [ { pattern= - FunApp - ( StanLib ("diag_pre_multiply", FnPlain, mem1) - , [v; a] ) - ; _ } - ; { pattern= - FunApp (StanLib ("diag_matrix", FnPlain, mem2), [w]) - ; _ } ] ) - when Expr.Typed.equal v w -> - let lub_mem = lub_mem_pat [mem1; mem2] in - FunApp (StanLib ("quad_form_diag", suffix, lub_mem), [a; v]) - | ( "Times__" - , [ { pattern= - FunApp (StanLib ("transpose", FnPlain, mem1), [b]) - ; _ } - ; { pattern= - FunApp (StanLib ("Times__", FnPlain, mem2), [a; c]) - ; _ } ] ) - when Expr.Typed.equal b c -> - let lub_mem = lub_mem_pat [mem1; mem2] in - FunApp (StanLib ("quad_form", suffix, lub_mem), [a; b]) - | ( "Times__" - , [ { pattern= - FunApp - ( StanLib ("Times__", FnPlain, mem1) - , [ { pattern= - FunApp - (StanLib ("transpose", FnPlain, mem2), [b]) - ; _ }; a ] ) - ; _ }; c ] ) - when Expr.Typed.equal b c -> - let lub_mem = lub_mem_pat [mem1; mem2] in - FunApp (StanLib ("quad_form", suffix, lub_mem), [a; b]) - | ( "Times__" - , [ e1' - ; { pattern= - FunApp (StanLib ("diag_matrix", FnPlain, mem), [v]) - ; _ } ] ) -> - let lub_mem = lub_mem_pat [mem] in - FunApp - (StanLib ("diag_post_multiply", suffix, lub_mem), [e1'; v]) - | ( "Times__" - , [ { pattern= - FunApp (StanLib ("diag_matrix", FnPlain, mem), [v]) - ; _ }; e2' ] ) -> - let lub_mem = lub_mem_pat [mem] in - FunApp - (StanLib ("diag_pre_multiply", suffix, lub_mem), [v; e2']) - (* Constant folding for operators *) - | op, [{pattern= Lit (Int, i); _}] -> ( - match op with - | "PPlus__" | "PMinus__" | "PNot__" -> - apply_prefix_operator_int op (Int.of_string i) - | _ -> FunApp (kind, l) ) - | op, [{pattern= Lit (Real, r); _}] -> ( - match op with - | "PPlus__" | "PMinus__" -> - apply_prefix_operator_real op (Float.of_string r) - | _ -> FunApp (kind, l) ) - | ( ("Divide__" | "IntDivide__") - , [{meta= {type_= UInt; _}; _}; {pattern= Lit (Int, i2); _}] ) - when Int.of_string i2 = 0 -> - raise (Rejected (e.meta.loc, "Integer division by zero")) - | op, [{pattern= Lit (Int, i1); _}; {pattern= Lit (Int, i2); _}] - -> ( - match op with - | "Plus__" | "Minus__" | "Times__" | "Divide__" - |"IntDivide__" | "Modulo__" | "Or__" | "And__" | "Equals__" - |"NEquals__" | "Less__" | "Leq__" | "Greater__" | "Geq__" -> - apply_operator_int op (Int.of_string i1) - (Int.of_string i2) - | _ -> FunApp (kind, l) ) - | ( op - , [{pattern= Lit (Real, i1); _}; {pattern= Lit (Real, i2); _}] - ) - |op, [{pattern= Lit (Int, i1); _}; {pattern= Lit (Real, i2); _}] - |op, [{pattern= Lit (Real, i1); _}; {pattern= Lit (Int, i2); _}] - -> ( - match op with - | "Plus__" | "Minus__" | "Times__" | "Divide__" -> - apply_arithmetic_operator_real op (Float.of_string i1) - (Float.of_string i2) - | "Or__" | "And__" | "Equals__" | "NEquals__" | "Less__" - |"Leq__" | "Greater__" | "Geq__" -> - apply_logical_operator_real op (Float.of_string i1) - (Float.of_string i2) - | _ -> FunApp (kind, l) ) - | _ -> FunApp (kind, l) ) ) - | TernaryIf (e1, e2, e3) -> ( - match - ( eval_expr ~preserve_stability e1 - , eval_expr ~preserve_stability e2 - , eval_expr ~preserve_stability e3 ) - with - | x, _, e3' when is_int 0 x -> e3'.pattern - | {pattern= Lit (Int, _); _}, e2', _ -> e2'.pattern - | e1', e2', e3' -> TernaryIf (e1', e2', e3') ) - | EAnd (e1, e2) -> ( - match - (eval_expr ~preserve_stability e1, eval_expr ~preserve_stability e2) - with - | {pattern= Lit (Int, s1); _}, {pattern= Lit (Int, s2); _} -> - let i1, i2 = (Int.of_string s1, Int.of_string s2) in - Lit (Int, Int.to_string (Bool.to_int (i1 <> 0 && i2 <> 0))) - | {pattern= Lit (_, s1); _}, {pattern= Lit (_, s2); _} -> - let r1, r2 = (Float.of_string s1, Float.of_string s2) in - Lit (Int, Int.to_string (Bool.to_int (r1 <> 0. && r2 <> 0.))) - | e1', e2' -> EAnd (e1', e2') ) - | EOr (e1, e2) -> ( - match - (eval_expr ~preserve_stability e1, eval_expr ~preserve_stability e2) - with - | {pattern= Lit (Int, s1); _}, {pattern= Lit (Int, s2); _} -> - let i1, i2 = (Int.of_string s1, Int.of_string s2) in - Lit (Int, Int.to_string (Bool.to_int (i1 <> 0 || i2 <> 0))) - | {pattern= Lit (_, s1); _}, {pattern= Lit (_, s2); _} -> - let r1, r2 = (Float.of_string s1, Float.of_string s2) in - Lit (Int, Int.to_string (Bool.to_int (r1 <> 0. || r2 <> 0.))) - | e1', e2' -> EOr (e1', e2') ) - | Indexed (e, l) -> - (* TODO: do something clever with array and matrix expressions here? - Note that we could also constant fold array sizes if we keep those around on declarations. *) - Indexed (eval_expr e, List.map ~f:(Index.map eval_expr) l) ) } - -let rec simplify_index_expr pattern = - Expr.Fixed.( - match pattern with - | Pattern.Indexed - ( { pattern= - Indexed (obj, inner_indices) - (* , Single ({emeta= {type_= UArray UInt; _} as emeta; _} as multi) - * :: inner_tl ) *) - ; meta } - , ( Single ({meta= Expr.Typed.Meta.{type_= UInt; _}; _} as single_e) as - single ) - :: outer_tl ) - when List.exists ~f:is_multi_index inner_indices -> ( - match List.split_while ~f:(Fn.non is_multi_index) inner_indices with - | inner_singles, MultiIndex first_multi :: inner_tl -> - (* foo [arr1, ..., arrN] [i1, ..., iN] -> - foo [arr1[i1]] [arr[i2]] ... [arrN[iN]] *) - simplify_index_expr - (Indexed - ( { pattern= - Indexed - ( obj - , inner_singles - @ [ Index.Single - { pattern= Indexed (first_multi, [single]) - ; meta= {meta with type_= UInt} } ] - @ inner_tl ) - ; meta } - , outer_tl ) ) - | inner_singles, All :: inner_tl -> - (* v[:x][i] -> v[i] *) - (* v[:][i] -> v[i] *) - (* XXX generate check *) - simplify_index_expr - (Indexed - ( { pattern= Indexed (obj, inner_singles @ [single] @ inner_tl) - ; meta } - , outer_tl ) ) - | inner_singles, Between (bot, _) :: inner_tl - |inner_singles, Upfrom bot :: inner_tl -> - (* v[x:y][z] -> v[x+z-1] *) - (* XXX generate check *) - simplify_index_expr - (Indexed - ( { pattern= - Indexed - ( obj - , inner_singles - @ [ Index.Single - Expr.Helpers.( - binop (binop bot Plus single_e) Minus - loop_bottom) ] - @ inner_tl ) - ; meta } - , outer_tl ) ) - | inner_singles, (([] | Single _ :: _) as multis) -> - Common.FatalError.fatal_error_msg - [%message - " There must be a multi-index." - (inner_singles : Expr.Typed.t Index.t list) - (multis : Expr.Typed.t Index.t list)] ) - | e -> e) - -let remove_trailing_alls_expr = function - | Expr.Fixed.Pattern.Indexed (obj, indices) -> - (* a[2][:] -> a[2] *) - let rec remove_trailing_alls indices = - match List.rev indices with - | Index.All :: tl -> remove_trailing_alls (List.rev tl) - | _ -> indices in - Expr.Fixed.Pattern.Indexed (obj, remove_trailing_alls indices) - | e -> e - -let rec simplify_indices_expr expr = - Expr.Fixed.( - let pattern = - expr.pattern |> remove_trailing_alls_expr |> simplify_index_expr - |> Expr.Fixed.Pattern.map simplify_indices_expr in - {expr with pattern}) - -let try_eval_expr expr = try eval_expr expr with Rejected _ -> expr - -let rec eval_stmt s = - try - Stmt.Fixed. - { s with - pattern= - Pattern.map - (Fn.compose eval_expr simplify_indices_expr) - eval_stmt s.pattern } - with Rejected (loc, m) -> - { Stmt.Fixed.pattern= - NRFunApp (CompilerInternal FnReject, [Expr.Helpers.str m]) - ; meta= loc } - -let eval_prog = Program.map try_eval_expr eval_stmt diff --git a/src/analysis_and_optimization/Pedantic_analysis.ml b/src/analysis_and_optimization/Pedantic_analysis.ml index 19b7777f40..1c2a5cd7cb 100644 --- a/src/analysis_and_optimization/Pedantic_analysis.ml +++ b/src/analysis_and_optimization/Pedantic_analysis.ml @@ -487,11 +487,14 @@ let settings_constant_prop = ; copy_propagation= true ; partial_evaluation= true } +module Optimizer = Optimize.Make (Stan_math_backend.Stan_math_library) + (* Collect all pedantic mode warnings, sorted, to stderr *) let warn_pedantic (mir_unopt : Program.Typed.t) = (* Some warnings will be stronger when constants are propagated *) let mir = - Optimize.optimization_suite ~settings:settings_constant_prop mir_unopt in + Optimizer.optimization_suite ~settings:settings_constant_prop mir_unopt + in (* Try to avoid recomputation by pre-building structures *) let distributions_info = list_distributions mir in let factor_graph = prog_factor_graph mir in diff --git a/src/analysis_and_optimization/dune b/src/analysis_and_optimization/dune index b656409cb0..dd915065a2 100644 --- a/src/analysis_and_optimization/dune +++ b/src/analysis_and_optimization/dune @@ -1,7 +1,7 @@ (library (name analysis_and_optimization) (public_name stanc.analysis) - (libraries core_kernel str fmt common middle frontend) + (libraries core_kernel str fmt common middle frontend stan_math_backend) (inline_tests) ;; TODO: Not sure what's going on but it's throwing an error that this module has no implementation (modules_without_implementation monotone_framework_sigs) diff --git a/src/frontend/Ast_to_Mir.ml b/src/frontend/Ast_to_Mir.ml index 26da24d737..413181ff73 100644 --- a/src/frontend/Ast_to_Mir.ml +++ b/src/frontend/Ast_to_Mir.ml @@ -30,14 +30,6 @@ module Make (StdLib : Std_library_utils.Library) = struct let format_number s = s |> without_underscores |> drop_leading_zeros - let%expect_test "format_number0" = - format_number "0_000." |> print_endline ; - [%expect "0."] - - let%expect_test "format_number1" = - format_number ".123_456" |> print_endline ; - [%expect ".123456"] - let rec op_to_funapp op args type_ = let loc = Ast.expr_loc_lub args in let adlevel = Ast.expr_ad_lub args in diff --git a/src/stan_math_backend/Stan_math_library.ml b/src/stan_math_backend/Stan_math_library.ml new file mode 100644 index 0000000000..07c526dad2 --- /dev/null +++ b/src/stan_math_backend/Stan_math_library.ml @@ -0,0 +1,2507 @@ +(** The signatures of the Stan Math library, which are used for type checking *) + +open Core_kernel +open Core_kernel.Poly +open Middle +open Frontend.Std_library_utils + +(** The "dimensionality" (bad name?) is supposed to help us represent the + vectorized nature of many Stan functions. It allows us to represent when + a function argument can be just a real or matrix, or some common forms of + vectorization over reals. This captures the most commonly used forms in our + previous signatures; there are a lot partially because we had a lot of + inconsistencies. +*) +type dimensionality = + | DInt + | DReal + | DVector + | DMatrix + | DIntArray + (* Vectorizable int *) + | DVInt + (* Vectorizable real *) + | DVReal + (* DEPRECATED; vectorizable ints or reals *) + | DIntAndReals + (* Vectorizable vectors - for multivariate functions *) + | DVectors + | DDeepVectorized + +(* all base types with up 8 levels of nested containers - + just used for element-wise vectorized unary functions now *) + +let rec bare_array_type (t, i) = + match i with 0 -> t | j -> UnsizedType.UArray (bare_array_type (t, j - 1)) + +let rec expand_arg = function + | DInt -> [UnsizedType.UInt] + | DReal -> [UReal] + | DVector -> [UVector] + | DMatrix -> [UMatrix] + | DIntArray -> [UArray UInt] + | DVInt -> [UInt; UArray UInt] + | DVReal -> [UReal; UArray UReal; UVector; URowVector] + | DIntAndReals -> expand_arg DVReal @ expand_arg DVInt + | DVectors -> [UVector; UArray UVector; URowVector; UArray URowVector] + | DDeepVectorized -> + let all_base = [UnsizedType.UInt; UReal; URowVector; UVector; UMatrix] in + List.( + concat_map all_base ~f:(fun a -> + map (range 0 8) ~f:(fun i -> bare_array_type (a, i)) )) + +type fkind = Lpmf | Lpdf | Rng | Cdf | Ccdf | UnaryVectorized +[@@deriving show {with_path= false}] + +let is_primitive = function + | UnsizedType.UReal -> true + | UInt -> true + | _ -> false + +(** The signatures hash table *) +let (function_signatures : (string, signature list) Hashtbl.t) = + String.Table.create () + +(** All of the signatures that are added by hand, rather than the ones + added "declaratively" *) +let (manual_stan_math_signatures : (string, signature list) Hashtbl.t) = + String.Table.create () + +(* XXX The correct word here isn't combination - what is it? *) +let all_combinations xx = + List.fold_right xx ~init:[[]] ~f:(fun x accum -> + List.concat_map accum ~f:(fun acc -> + List.map ~f:(fun arg -> arg :: acc) x ) ) + +let%expect_test "combinations " = + let a = all_combinations [[1; 2]; [3; 4]; [5; 6]] in + [%sexp (a : int list list)] |> Sexp.to_string_hum |> print_endline ; + [%expect + {| ((1 3 5) (2 3 5) (1 4 5) (2 4 5) (1 3 6) (2 3 6) (1 4 6) (2 4 6)) |}] + +let missing_math_functions = + String.Set.of_list + ["beta_proportion_cdf"; "loglogistic_lcdf"; "loglogistic_cdf_log"] + +let rng_return_type t lt = + if List.for_all ~f:is_primitive lt then t else UnsizedType.UArray t + +let add_unqualified (name, rt, uqargts, mem_pattern) = + Hashtbl.add_multi manual_stan_math_signatures ~key:name + ~data: + ( rt + , List.map ~f:(fun x -> (UnsizedType.AutoDiffable, x)) uqargts + , mem_pattern ) + +let rec ints_to_real unsized = + match unsized with + | UnsizedType.UInt -> UnsizedType.UReal + | UArray t -> UArray (ints_to_real t) + | x -> x + +let rec complex_to_real = function + | UnsizedType.UComplex -> UnsizedType.UReal + | UComplexVector -> UVector + | UComplexRowVector -> URowVector + | UComplexMatrix -> UMatrix + | UArray t -> UArray (complex_to_real t) + | x -> x + +let reduce_sum_allowed_dimensionalities = [1; 2; 3; 4; 5; 6; 7] + +let reduce_sum_slice_types = + let base_slice_type i = + [ bare_array_type (UnsizedType.UReal, i) + ; bare_array_type (UnsizedType.UInt, i) + ; bare_array_type (UnsizedType.UMatrix, i) + ; bare_array_type (UnsizedType.UVector, i) + ; bare_array_type (UnsizedType.URowVector, i) ] in + List.concat (List.map ~f:base_slice_type reduce_sum_allowed_dimensionalities) + +(* Variadic ODE *) +let variadic_ode_adjoint_ctl_tol_arg_types = + [ (UnsizedType.DataOnly, UnsizedType.UReal) + (* real relative_tolerance_forward *) + ; (DataOnly, UVector) (* vector absolute_tolerance_forward *) + ; (DataOnly, UReal) (* real relative_tolerance_backward *) + ; (DataOnly, UVector) (* real absolute_tolerance_backward *) + ; (DataOnly, UReal) (* real relative_tolerance_quadrature *) + ; (DataOnly, UReal) (* real absolute_tolerance_quadrature *) + ; (DataOnly, UInt) (* int max_num_steps *) + ; (DataOnly, UInt) (* int num_steps_between_checkpoints *) + ; (DataOnly, UInt) (* int interpolation_polynomial *) + ; (DataOnly, UInt) (* int solver_forward *); (DataOnly, UInt) + (* int solver_backward *) ] + +let variadic_ode_tol_arg_types = + [ (UnsizedType.DataOnly, UnsizedType.UReal); (DataOnly, UReal) + ; (DataOnly, UInt) ] + +let variadic_ode_mandatory_arg_types = + [ (UnsizedType.AutoDiffable, UnsizedType.UVector); (AutoDiffable, UReal) + ; (AutoDiffable, UArray UReal) ] + +let variadic_ode_mandatory_fun_args = + [ (UnsizedType.AutoDiffable, UnsizedType.UReal) + ; (UnsizedType.AutoDiffable, UnsizedType.UVector) ] + +let variadic_ode_fun_return_type = UnsizedType.UVector +let variadic_ode_return_type = UnsizedType.UArray UnsizedType.UVector + +let variadic_dae_tol_arg_types = + [ (UnsizedType.DataOnly, UnsizedType.UReal); (DataOnly, UReal) + ; (DataOnly, UInt) ] + +let variadic_dae_mandatory_arg_types = + [ (UnsizedType.AutoDiffable, UnsizedType.UVector); (* yy *) + (UnsizedType.AutoDiffable, UnsizedType.UVector); (* yp *) + (AutoDiffable, UReal); (AutoDiffable, UArray UReal) ] + +let variadic_dae_mandatory_fun_args = + [ (UnsizedType.AutoDiffable, UnsizedType.UReal) + ; (UnsizedType.AutoDiffable, UnsizedType.UVector) + ; (UnsizedType.AutoDiffable, UnsizedType.UVector) ] + +let variadic_dae_fun_return_type = UnsizedType.UVector +let variadic_dae_return_type = UnsizedType.UArray UnsizedType.UVector + +let mk_declarative_sig (fnkinds, name, args, mem_pattern) = + let is_glm = String.is_suffix ~suffix:"_glm" name in + let sfxes = function + | Lpmf when is_glm -> ["_lpmf"] + | Lpmf -> ["_lpmf"; "_log"] + | Lpdf when is_glm -> ["_lpdf"] + | Lpdf -> ["_lpdf"; "_log"] + | Rng -> ["_rng"] + | Cdf -> ["_cdf"; "_cdf_log"; "_lcdf"] + | Ccdf -> ["_ccdf_log"; "_lccdf"] + | UnaryVectorized -> [""] in + let add_ints = function DVReal -> DIntAndReals | x -> x in + let all_expanded args = all_combinations (List.map ~f:expand_arg args) in + let promoted_dim = function + | DInt | DIntArray | DVInt -> UnsizedType.UInt + (* XXX fix this up to work with more RNGs *) + | _ -> UReal in + let find_rt rt args = function + | Rng -> UnsizedType.ReturnType (rng_return_type rt args) + | UnaryVectorized -> ReturnType (ints_to_real (List.hd_exn args)) + | _ -> ReturnType UReal in + let create_from_fk_args fk arglists = + List.concat_map arglists ~f:(fun args -> + List.map (sfxes fk) ~f:(fun sfx -> + (name ^ sfx, find_rt UReal args fk, args, mem_pattern) ) ) in + let add_fnkind = function + | Rng -> + let rt, args = (List.hd_exn args, List.tl_exn args) in + let args = List.map ~f:add_ints args in + let rt = promoted_dim rt in + let name = name ^ "_rng" in + List.map (all_expanded args) ~f:(fun args -> + (name, find_rt rt args Rng, args, mem_pattern) ) + | UnaryVectorized -> create_from_fk_args UnaryVectorized (all_expanded args) + | fk -> create_from_fk_args fk (all_expanded args) in + List.concat_map fnkinds ~f:add_fnkind + |> List.filter ~f:(fun (n, _, _, _) -> not (Set.mem missing_math_functions n)) + |> List.map ~f:(fun (n, rt, args, support_soa) -> + ( n + , rt + , List.map ~f:(fun x -> (UnsizedType.AutoDiffable, x)) args + , support_soa ) ) + +let full_lpdf = [Lpdf; Rng; Ccdf; Cdf] +let full_lpmf = [Lpmf; Rng; Ccdf; Cdf] +let reduce_sum_functions = String.Set.of_list ["reduce_sum"; "reduce_sum_static"] +let variadic_ode_adjoint_fn = "ode_adjoint_tol_ctl" + +let variadic_ode_nonadjoint_fns = + String.Set.of_list + [ "ode_bdf_tol"; "ode_rk45_tol"; "ode_adams_tol"; "ode_bdf"; "ode_rk45" + ; "ode_adams"; "ode_ckrk"; "ode_ckrk_tol" ] + +let ode_tolerances_suffix = "_tol" +let is_reduce_sum_fn f = Set.mem reduce_sum_functions f +let is_variadic_ode_nonadjoint_fn f = Set.mem variadic_ode_nonadjoint_fns f + +let is_variadic_ode_fn f = + Set.mem variadic_ode_nonadjoint_fns f || f = variadic_ode_adjoint_fn + +let is_variadic_ode_nonadjoint_tol_fn f = + is_variadic_ode_nonadjoint_fn f + && String.is_suffix f ~suffix:ode_tolerances_suffix + +let variadic_dae_fns = String.Set.of_list ["dae_tol"; "dae"] +let dae_tolerances_suffix = "_tol" +let is_variadic_dae_fn f = Set.mem variadic_dae_fns f + +let is_variadic_dae_tol_fn f = + is_variadic_dae_fn f && String.is_suffix f ~suffix:dae_tolerances_suffix + +let is_variadic_function_name name = + is_reduce_sum_fn name || is_variadic_dae_fn name || is_variadic_ode_fn name + +let is_not_overloadable = is_variadic_dae_fn + +let distributions = + [ ( full_lpmf + , "beta_binomial" + , [DVInt; DVInt; DVReal; DVReal] + , Common.Helpers.SoA ); (full_lpdf, "beta", [DVReal; DVReal; DVReal], SoA) + ; ([Lpdf; Ccdf; Cdf], "beta_proportion", [DVReal; DVReal; DIntAndReals], SoA) + ; (full_lpmf, "bernoulli", [DVInt; DVReal], SoA) + ; ([Lpmf; Rng], "bernoulli_logit", [DVInt; DVReal], SoA) + ; ([Lpmf], "bernoulli_logit_glm", [DVInt; DMatrix; DReal; DVector], SoA) + ; (full_lpmf, "binomial", [DVInt; DVInt; DVReal], SoA) + ; ([Lpmf], "binomial_logit", [DVInt; DVInt; DVReal], SoA) + ; ([Lpmf], "categorical", [DVInt; DVector], AoS) + ; ([Lpmf], "categorical_logit", [DVInt; DVector], AoS) + ; ([Lpmf], "categorical_logit_glm", [DVInt; DMatrix; DVector; DMatrix], SoA) + ; (full_lpdf, "cauchy", [DVReal; DVReal; DVReal], SoA) + ; (full_lpdf, "chi_square", [DVReal; DVReal], SoA) + ; ([Lpdf], "dirichlet", [DVectors; DVectors], SoA) + ; (full_lpmf, "discrete_range", [DVInt; DVInt; DVInt], SoA) + ; (full_lpdf, "double_exponential", [DVReal; DVReal; DVReal], SoA) + ; (full_lpdf, "exp_mod_normal", [DVReal; DVReal; DVReal; DVReal], SoA) + ; (full_lpdf, "exponential", [DVReal; DVReal], SoA) + ; (full_lpdf, "frechet", [DVReal; DVReal; DVReal], SoA) + ; (full_lpdf, "gamma", [DVReal; DVReal; DVReal], SoA) + ; ( [Lpdf] + , "gaussian_dlm_obs" + , [DMatrix; DMatrix; DMatrix; DMatrix; DMatrix; DVector; DMatrix] + , AoS ); (full_lpdf, "gumbel", [DVReal; DVReal; DVReal], SoA) + ; ([Rng], "hmm_latent", [DIntArray; DMatrix; DMatrix; DVector], AoS) + ; ([Lpmf; Rng], "hypergeometric", [DInt; DInt; DInt; DInt], SoA) + ; (full_lpdf, "inv_chi_square", [DVReal; DVReal], SoA) + ; (full_lpdf, "inv_gamma", [DVReal; DVReal; DVReal], SoA) + ; ([Lpdf], "inv_wishart", [DMatrix; DReal; DMatrix], SoA) + ; ([Lpdf], "lkj_corr", [DMatrix; DReal], AoS) + ; ([Lpdf], "lkj_corr_cholesky", [DMatrix; DReal], AoS) + ; (full_lpdf, "logistic", [DVReal; DVReal; DVReal], SoA) + ; ([Lpdf; Rng; Cdf], "loglogistic", [DVReal; DVReal; DVReal], SoA) + ; (full_lpdf, "lognormal", [DVReal; DVReal; DVReal], SoA) + ; ([Lpdf], "multi_gp", [DMatrix; DMatrix; DVector], AoS) + ; ([Lpdf], "multi_gp_cholesky", [DMatrix; DMatrix; DVector], AoS) + ; ([Lpmf], "multinomial", [DIntArray; DVector], AoS) + ; ([Lpmf], "multinomial_logit", [DIntArray; DVector], AoS) + ; ([Lpdf], "multi_normal", [DVectors; DVectors; DMatrix], AoS) + ; ([Lpdf], "multi_normal_cholesky", [DVectors; DVectors; DMatrix], AoS) + ; ([Lpdf], "multi_normal_prec", [DVectors; DVectors; DMatrix], AoS) + ; ([Lpdf], "multi_student_t", [DVectors; DReal; DVectors; DMatrix], AoS) + ; (full_lpmf, "neg_binomial", [DVInt; DVReal; DVReal], SoA) + ; (full_lpmf, "neg_binomial_2", [DVInt; DVReal; DVReal], SoA) + ; ([Lpmf; Rng], "neg_binomial_2_log", [DVInt; DVReal; DVReal], SoA) + ; ( [Lpmf] + , "neg_binomial_2_log_glm" + , [DVInt; DMatrix; DReal; DVector; DReal] + , SoA ); (full_lpdf, "normal", [DVReal; DVReal; DVReal], SoA) + ; ([Lpdf], "normal_id_glm", [DVector; DMatrix; DReal; DVector; DReal], SoA) + ; ([Lpmf], "ordered_logistic", [DInt; DReal; DVector], SoA) + ; ([Lpmf], "ordered_logistic_glm", [DVInt; DMatrix; DVector; DVector], SoA) + ; ([Lpmf], "ordered_probit", [DInt; DReal; DVector], SoA) + ; (full_lpdf, "pareto", [DVReal; DVReal; DVReal], SoA) + ; (full_lpdf, "pareto_type_2", [DVReal; DVReal; DVReal; DVReal], SoA) + ; (full_lpmf, "poisson", [DVInt; DVReal], SoA) + ; ([Lpmf; Rng], "poisson_log", [DVInt; DVReal], SoA) + ; ([Lpmf], "poisson_log_glm", [DVInt; DMatrix; DReal; DVector], SoA) + ; (full_lpdf, "rayleigh", [DVReal; DVReal], SoA) + ; (full_lpdf, "scaled_inv_chi_square", [DVReal; DVReal; DVReal], SoA) + ; (full_lpdf, "skew_normal", [DVReal; DVReal; DVReal; DVReal], SoA) + ; (full_lpdf, "skew_double_exponential", [DVReal; DVReal; DVReal; DVReal], SoA) + ; (full_lpdf, "student_t", [DVReal; DVReal; DVReal; DVReal], SoA) + ; (full_lpdf, "std_normal", [DVReal], SoA) + ; (full_lpdf, "uniform", [DVReal; DVReal; DVReal], SoA) + ; (full_lpdf, "von_mises", [DVReal; DVReal; DVReal], SoA) + ; (full_lpdf, "weibull", [DVReal; DVReal; DVReal], SoA) + ; ([Lpdf], "wiener", [DVReal; DVReal; DVReal; DVReal; DVReal], SoA) + ; ([Lpdf], "wishart", [DMatrix; DReal; DMatrix], SoA) ] + +let distribution_families = + List.map ~f:(fun (_, name, _, _) -> name) distributions + +let math_sigs = + [ ([UnaryVectorized], "acos", [DDeepVectorized], Common.Helpers.SoA) + ; ([UnaryVectorized], "acosh", [DDeepVectorized], SoA) + ; ([UnaryVectorized], "asin", [DDeepVectorized], SoA) + ; ([UnaryVectorized], "asinh", [DDeepVectorized], SoA) + ; ([UnaryVectorized], "atan", [DDeepVectorized], SoA) + ; ([UnaryVectorized], "atanh", [DDeepVectorized], SoA) + ; ([UnaryVectorized], "cbrt", [DDeepVectorized], SoA) + ; ([UnaryVectorized], "ceil", [DDeepVectorized], SoA) + ; ([UnaryVectorized], "cos", [DDeepVectorized], SoA) + ; ([UnaryVectorized], "cosh", [DDeepVectorized], SoA) + ; ([UnaryVectorized], "digamma", [DDeepVectorized], SoA) + ; ([UnaryVectorized], "erf", [DDeepVectorized], SoA) + ; ([UnaryVectorized], "erfc", [DDeepVectorized], SoA) + ; ([UnaryVectorized], "exp", [DDeepVectorized], SoA) + ; ([UnaryVectorized], "exp2", [DDeepVectorized], SoA) + ; ([UnaryVectorized], "expm1", [DDeepVectorized], SoA) + ; ([UnaryVectorized], "fabs", [DDeepVectorized], SoA) + ; ([UnaryVectorized], "floor", [DDeepVectorized], SoA) + ; ([UnaryVectorized], "inv", [DDeepVectorized], SoA) + ; ([UnaryVectorized], "inv_cloglog", [DDeepVectorized], SoA) + ; ([UnaryVectorized], "inv_erfc", [DDeepVectorized], SoA) + ; ([UnaryVectorized], "inv_logit", [DDeepVectorized], SoA) + ; ([UnaryVectorized], "inv_Phi", [DDeepVectorized], SoA) + ; ([UnaryVectorized], "inv_sqrt", [DDeepVectorized], SoA) + ; ([UnaryVectorized], "inv_square", [DDeepVectorized], SoA) + ; ([UnaryVectorized], "lambert_w0", [DDeepVectorized], SoA) + ; ([UnaryVectorized], "lambert_wm1", [DDeepVectorized], SoA) + ; ([UnaryVectorized], "lgamma", [DDeepVectorized], SoA) + ; ([UnaryVectorized], "log", [DDeepVectorized], SoA) + ; ([UnaryVectorized], "log10", [DDeepVectorized], SoA) + ; ([UnaryVectorized], "log1m", [DDeepVectorized], SoA) + ; ([UnaryVectorized], "log1m_exp", [DDeepVectorized], SoA) + ; ([UnaryVectorized], "log1m_inv_logit", [DDeepVectorized], SoA) + ; ([UnaryVectorized], "log1p", [DDeepVectorized], SoA) + ; ([UnaryVectorized], "log1p_exp", [DDeepVectorized], SoA) + ; ([UnaryVectorized], "log2", [DDeepVectorized], SoA) + ; ([UnaryVectorized], "log_inv_logit", [DDeepVectorized], SoA) + ; ([UnaryVectorized], "logit", [DDeepVectorized], SoA) + ; ([UnaryVectorized], "Phi", [DDeepVectorized], SoA) + ; ([UnaryVectorized], "Phi_approx", [DDeepVectorized], SoA) + ; ([UnaryVectorized], "round", [DDeepVectorized], SoA) + ; ([UnaryVectorized], "sin", [DDeepVectorized], SoA) + ; ([UnaryVectorized], "sinh", [DDeepVectorized], SoA) + ; ([UnaryVectorized], "sqrt", [DDeepVectorized], SoA) + ; ([UnaryVectorized], "square", [DDeepVectorized], SoA) + ; ([UnaryVectorized], "step", [DReal], SoA) + ; ([UnaryVectorized], "tan", [DDeepVectorized], SoA) + ; ([UnaryVectorized], "tanh", [DDeepVectorized], SoA) + (* ; add_nullary ("target") *) + ; ([UnaryVectorized], "tgamma", [DDeepVectorized], SoA) + ; ([UnaryVectorized], "trunc", [DDeepVectorized], SoA) + ; ([UnaryVectorized], "trigamma", [DDeepVectorized], SoA) ] + +let all_declarative_sigs = distributions @ math_sigs + +let declarative_fnsigs = + List.concat_map ~f:mk_declarative_sig all_declarative_sigs + +let is_stdlib_function_name name = + let name = Utils.stdlib_distribution_name name in + Hashtbl.mem function_signatures name + +let operator_to_function_names op = + match op with + | Operator.Plus -> ["add"] + | PPlus -> ["plus"] + | Minus -> ["subtract"] + | PMinus -> ["minus"] + | Times -> ["multiply"] + | Divide -> ["mdivide_right"; "divide"] + | Modulo -> ["modulus"] + | IntDivide -> [] + | LDivide -> ["mdivide_left"] + | EltTimes -> ["elt_multiply"] + | EltDivide -> ["elt_divide"] + | Pow -> ["pow"] + | EltPow -> ["pow"] + | Or -> ["logical_or"] + | And -> ["logical_and"] + | Equals -> ["logical_eq"] + | NEquals -> ["logical_neq"] + | Less -> ["logical_lt"] + | Leq -> ["logical_lte"] + | Greater -> ["logical_gt"] + | Geq -> ["logical_gte"] + | PNot -> ["logical_negation"] + | Transpose -> ["transpose"] + +let get_signatures name = + let name = Utils.stdlib_distribution_name name in + Hashtbl.find_multi function_signatures name |> List.sort ~compare + +let get_assignment_operator_signatures assop = + ( match assop with + | Operator.Divide -> ["divide"] + | assop -> operator_to_function_names assop ) + |> List.concat_map ~f:get_signatures + |> List.concat_map ~f:(function + | ReturnType rtype, [(ad1, lhs); (ad2, rhs)], _ + when rtype = lhs + && not + ( (assop = Operator.EltTimes || assop = Operator.EltDivide) + && UnsizedType.is_scalar_type rtype ) -> + if rhs = UReal then + [ (UnsizedType.Void, [(ad1, lhs); (ad2, UInt)], Common.Helpers.SoA) + ; (Void, [(ad1, lhs); (ad2, UReal)], SoA) ] + else [(Void, [(ad1, lhs); (ad2, rhs)], SoA)] + | _ -> [] ) + +let string_operator_to_function_name str = + match str with + | "Plus__" -> "add" + | "PPlus__" -> "plus" + | "Minus__" -> "subtract" + | "PMinus__" -> "minus" + | "Times__" -> "multiply" + | "Divide__" -> "divide" + | "Modulo__" -> "modulus" + | "IntDivide__" -> "divide" + | "LDivide__" -> "mdivide_left" + | "EltTimes__" -> "elt_multiply" + | "EltDivide__" -> "elt_divide" + | "Pow__" -> "pow" + | "EltPow__" -> "pow" + | "Or__" -> "logical_or" + | "And__" -> "logical_and" + | "Equals__" -> "logical_eq" + | "NEquals__" -> "logical_neq" + | "Less__" -> "logical_lt" + | "Leq__" -> "logical_lte" + | "Greater__" -> "logical_gt" + | "Geq__" -> "logical_gte" + | "PNot__" -> "logical_negation" + | "Transpose__" -> "transpose" + | _ -> str + +let pretty_print_all_math_sigs ppf () = + let open Fmt in + let pp_sig ppf (name, (rt, args, _)) = + pf ppf "%s(@[%a@]) => %a" name + (list ~sep:comma UnsizedType.pp) + (List.map ~f:snd args) UnsizedType.pp_returntype rt in + let pp_sigs_for_name ppf name = + (list ~sep:cut pp_sig) ppf + (List.map ~f:(fun t -> (name, t)) (get_signatures name)) in + pf ppf "@[%a@]" + (list ~sep:cut pp_sigs_for_name) + (List.sort ~compare (Hashtbl.keys function_signatures)) + +let pretty_print_all_math_distributions ppf () = + let open Fmt in + let pp_dist ppf (kinds, name, _, _) = + pf ppf "@[%s: %a@]" name + (list ~sep:comma Fmt.string) + (List.map ~f:(Fn.compose String.lowercase show_fkind) kinds) in + pf ppf "@[%a@]" (list ~sep:cut pp_dist) distributions + +let int_divide_type = + UnsizedType. + ( ReturnType UInt + , [(AutoDiffable, UInt); (AutoDiffable, UInt)] + , Common.Helpers.AoS ) + +let get_operator_signatures op = + if op = Operator.IntDivide then [int_divide_type] + else operator_to_function_names op |> List.concat_map ~f:get_signatures + +let deprecated_distributions = + List.concat_map distributions ~f:(fun (fnkinds, name, _, _) -> + List.filter_map fnkinds ~f:(function + | Lpdf -> Some (name ^ "_log", name ^ "_lpdf") + | Lpmf -> Some (name ^ "_log", name ^ "_lpmf") + | Cdf -> Some (name ^ "_cdf_log", name ^ "_lcdf") + | Ccdf -> Some (name ^ "_ccdf_log", name ^ "_lccdf") + | Rng | UnaryVectorized -> None ) ) + |> List.map ~f:(fun (x, y) -> + ( x + , { replacement= y + ; version= "2.32.0" + ; extra_message= + "This can be automatically changed using the canonicalize flag \ + for stanc" } ) ) + |> String.Map.of_alist_exn + +let deprecated_functions = + let make extra_message version replacement = + {extra_message; replacement; version} in + let ode = + make + "The new interface is slightly different, see: \n\ + https://mc-stan.org/users/documentation/case-studies/convert_odes.html" + "3.0" in + let std = + make + "This can be automatically changed using the canonicalize flag for stanc" + "2.32" in + String.Map.of_alist_exn + [ ("multiply_log", std "lmultiply") + ; ("binomial_coefficient_log", std "lchoose") + ; ("cov_exp_quad", std "gp_exp_quad_cov") (* ode integrators *) + ; ("integrate_ode_rk45", ode "ode_rk45"); ("integrate_ode", ode "ode_rk45") + ; ("integrate_ode_bdf", ode "ode_bdf") + ; ("integrate_ode_adams", ode "ode_adams") ] + +(* -- Some helper definitions to populate stan_math_signatures -- *) +let bare_types = + [ UnsizedType.UInt; UReal; UComplex; UVector; URowVector; UMatrix + ; UComplexVector; UComplexRowVector; UComplexMatrix ] + +let vector_types = [UnsizedType.UReal; UArray UReal; UVector; URowVector] +let primitive_types = [UnsizedType.UInt; UReal] + +let complex_types = + [UnsizedType.UComplex; UComplexVector; UComplexRowVector; UComplexMatrix] + +let all_vector_types = + [UnsizedType.UReal; UArray UReal; UVector; URowVector; UInt; UArray UInt] + +let add_qualified (name, rt, argts, supports_soa) = + Hashtbl.add_multi function_signatures ~key:name ~data:(rt, argts, supports_soa) + +let add_nullary name = + add_unqualified (name, UnsizedType.ReturnType UReal, [], AoS) + +let add_binary name supports_soa = + add_unqualified + (name, ReturnType UReal, [UnsizedType.UReal; UReal], supports_soa) + +let add_binary_vec name supports_soa = + List.iter + ~f:(fun i -> + List.iter + ~f:(fun j -> + add_unqualified + (name, ReturnType (ints_to_real i), [i; j], supports_soa) ) + [UnsizedType.UInt; UReal] ) + [UnsizedType.UInt; UReal] ; + List.iter + ~f:(fun i -> + List.iter + ~f:(fun j -> + add_unqualified + ( name + , ReturnType (ints_to_real (bare_array_type (j, i))) + , [bare_array_type (j, i); bare_array_type (j, i)] + , supports_soa ) ) + [UnsizedType.UArray UInt; UArray UReal; UVector; URowVector; UMatrix] ) + (List.range 0 8) ; + List.iter + ~f:(fun i -> + List.iter + ~f:(fun j -> + List.iter + ~f:(fun k -> + add_unqualified + ( name + , ReturnType (ints_to_real (bare_array_type (k, j))) + , [bare_array_type (k, j); i] + , supports_soa ) ) + [UnsizedType.UArray UInt; UArray UReal; UVector; URowVector; UMatrix] + ) + (List.range 0 8) ) + [UnsizedType.UInt; UReal] ; + List.iter + ~f:(fun i -> + List.iter + ~f:(fun j -> + List.iter + ~f:(fun k -> + add_unqualified + ( name + , ReturnType (ints_to_real (bare_array_type (k, j))) + , [i; bare_array_type (k, j)] + , supports_soa ) ) + [UnsizedType.UArray UInt; UArray UReal; UVector; URowVector; UMatrix] + ) + (List.range 0 8) ) + [UnsizedType.UInt; UReal] + +let add_binary_vec_real_real name supports_soa = + add_binary name supports_soa ; + List.iter + ~f:(fun i -> + List.iter + ~f:(fun j -> + add_unqualified + ( name + , ReturnType (bare_array_type (j, i)) + , [bare_array_type (j, i); bare_array_type (j, i)] + , supports_soa ) ) + [UnsizedType.UArray UReal; UVector; URowVector; UMatrix] ) + (List.range 0 8) ; + List.iter + ~f:(fun i -> + List.iter + ~f:(fun j -> + List.iter + ~f:(fun k -> + add_unqualified + ( name + , ReturnType (bare_array_type (k, j)) + , [bare_array_type (k, j); i] + , supports_soa ) ) + [UnsizedType.UArray UReal; UVector; URowVector; UMatrix] ) + (List.range 0 8) ) + [UnsizedType.UReal] ; + List.iter + ~f:(fun i -> + List.iter + ~f:(fun j -> + List.iter + ~f:(fun k -> + add_unqualified + ( name + , ReturnType (bare_array_type (k, j)) + , [i; bare_array_type (k, j)] + , supports_soa ) ) + [UnsizedType.UArray UReal; UVector; URowVector; UMatrix] ) + (List.range 0 8) ) + [UnsizedType.UReal] + +let add_binary_vec_int_real name supports_soa = + List.iter + ~f:(fun i -> + List.iter + ~f:(fun j -> + add_unqualified + ( name + , ReturnType (bare_array_type (i, j)) + , [UInt; bare_array_type (i, j)] + , supports_soa ) ) + (List.range 0 8) ) + [UnsizedType.UArray UReal; UVector; URowVector; UMatrix] ; + List.iter + ~f:(fun i -> + List.iter + ~f:(fun j -> + add_unqualified + ( name + , ReturnType (bare_array_type (i, j)) + , [bare_array_type (UInt, j + 1); bare_array_type (i, j)] + , supports_soa ) ) + (List.range 0 8) ) + [UnsizedType.UArray UReal; UVector; URowVector] ; + List.iter + ~f:(fun i -> + add_unqualified + ( name + , ReturnType (bare_array_type (UMatrix, i)) + , [bare_array_type (UInt, i + 2); bare_array_type (UMatrix, i)] + , supports_soa ) ) + (List.range 0 8) ; + List.iter + ~f:(fun i -> + add_unqualified + ( name + , ReturnType (bare_array_type (UReal, i)) + , [bare_array_type (UInt, i); UReal] + , supports_soa ) ) + (List.range 0 8) + +let add_binary_vec_real_int name supports_soa = + List.iter + ~f:(fun i -> + List.iter + ~f:(fun j -> + add_unqualified + ( name + , ReturnType (bare_array_type (i, j)) + , [bare_array_type (i, j); UInt] + , supports_soa ) ) + (List.range 0 8) ) + [UnsizedType.UArray UReal; UVector; URowVector; UMatrix] ; + List.iter + ~f:(fun i -> + List.iter + ~f:(fun j -> + add_unqualified + ( name + , ReturnType (bare_array_type (i, j)) + , [bare_array_type (i, j); bare_array_type (UInt, j + 1)] + , supports_soa ) ) + (List.range 0 8) ) + [UnsizedType.UArray UReal; UVector; URowVector] ; + List.iter + ~f:(fun i -> + add_unqualified + ( name + , ReturnType (bare_array_type (UMatrix, i)) + , [bare_array_type (UMatrix, i); bare_array_type (UInt, i + 2)] + , supports_soa ) ) + (List.range 0 8) ; + List.iter + ~f:(fun i -> + add_unqualified + ( name + , ReturnType (bare_array_type (UReal, i)) + , [UReal; bare_array_type (UInt, i)] + , supports_soa ) ) + (List.range 0 8) + +let add_binary_vec_int_int name supports_soa = + List.iter + ~f:(fun i -> + add_unqualified + ( name + , ReturnType (bare_array_type (UInt, i)) + , [bare_array_type (UInt, i); UInt] + , supports_soa ) ) + (List.range 0 8) ; + List.iter + ~f:(fun i -> + add_unqualified + ( name + , ReturnType (bare_array_type (UInt, i)) + , [UInt; bare_array_type (UInt, i)] + , supports_soa ) ) + (List.range 1 8) ; + List.iter + ~f:(fun i -> + add_unqualified + ( name + , ReturnType (bare_array_type (UInt, i)) + , [bare_array_type (UInt, i); bare_array_type (UInt, i)] + , supports_soa ) ) + (List.range 1 8) + +let add_ternary name supports_soa = + add_unqualified (name, ReturnType UReal, [UReal; UReal; UReal], supports_soa) + +(*Adds functions that operate on matrix, double array and real types*) +let add_ternary_vec name supports_soa = + add_unqualified (name, ReturnType UReal, [UReal; UReal; UReal], supports_soa) ; + add_unqualified + (name, ReturnType UVector, [UVector; UReal; UReal], supports_soa) ; + add_unqualified + (name, ReturnType UVector, [UVector; UVector; UReal], supports_soa) ; + add_unqualified + (name, ReturnType UVector, [UVector; UReal; UVector], supports_soa) ; + add_unqualified + (name, ReturnType UVector, [UVector; UVector; UVector], supports_soa) ; + add_unqualified + (name, ReturnType UVector, [UReal; UVector; UReal], supports_soa) ; + add_unqualified + (name, ReturnType UVector, [UReal; UVector; UVector], supports_soa) ; + add_unqualified + (name, ReturnType UVector, [UReal; UReal; UVector], supports_soa) ; + add_unqualified + (name, ReturnType URowVector, [URowVector; UReal; UReal], supports_soa) ; + add_unqualified + (name, ReturnType URowVector, [URowVector; URowVector; UReal], supports_soa) ; + add_unqualified + (name, ReturnType URowVector, [URowVector; UReal; URowVector], supports_soa) ; + add_unqualified + ( name + , ReturnType URowVector + , [URowVector; URowVector; URowVector] + , supports_soa ) ; + add_unqualified + (name, ReturnType URowVector, [UReal; URowVector; UReal], supports_soa) ; + add_unqualified + (name, ReturnType URowVector, [UReal; URowVector; URowVector], supports_soa) ; + add_unqualified + (name, ReturnType URowVector, [UReal; UReal; URowVector], supports_soa) ; + add_unqualified + (name, ReturnType UMatrix, [UMatrix; UReal; UReal], supports_soa) ; + add_unqualified + (name, ReturnType UMatrix, [UMatrix; UMatrix; UReal], supports_soa) ; + add_unqualified + (name, ReturnType UMatrix, [UMatrix; UReal; UMatrix], supports_soa) ; + add_unqualified + (name, ReturnType UMatrix, [UMatrix; UMatrix; UMatrix], supports_soa) ; + add_unqualified + (name, ReturnType UMatrix, [UReal; UMatrix; UReal], supports_soa) ; + add_unqualified + (name, ReturnType UMatrix, [UReal; UMatrix; UMatrix], supports_soa) ; + add_unqualified + (name, ReturnType UMatrix, [UReal; UReal; UMatrix], supports_soa) + +let for_all_vector_types s = List.iter ~f:s all_vector_types +let for_vector_types s = List.iter ~f:s vector_types + +(* -- Start populating stan_math_signaturess -- *) +let () = + List.iter declarative_fnsigs ~f:(fun (key, rt, args, mem_pattern) -> + Hashtbl.add_multi function_signatures ~key ~data:(rt, args, mem_pattern) ) ; + add_unqualified ("abs", ReturnType UInt, [UInt], SoA) ; + add_unqualified ("abs", ReturnType UReal, [UReal], SoA) ; + add_unqualified ("abs", ReturnType UReal, [UComplex], AoS) ; + add_unqualified ("acos", ReturnType UComplex, [UComplex], AoS) ; + add_unqualified ("acosh", ReturnType UComplex, [UComplex], AoS) ; + List.iter + ~f:(fun x -> add_unqualified ("add", ReturnType x, [x; x], SoA)) + bare_types ; + add_unqualified ("add", ReturnType UVector, [UVector; UReal], SoA) ; + add_unqualified ("add", ReturnType URowVector, [URowVector; UReal], SoA) ; + add_unqualified ("add", ReturnType UMatrix, [UMatrix; UReal], SoA) ; + add_unqualified ("add", ReturnType UVector, [UReal; UVector], SoA) ; + add_unqualified ("add", ReturnType URowVector, [UReal; URowVector], SoA) ; + add_unqualified ("add", ReturnType UMatrix, [UReal; UMatrix], SoA) ; + add_unqualified ("add_diag", ReturnType UMatrix, [UMatrix; UReal], AoS) ; + add_unqualified ("add_diag", ReturnType UMatrix, [UMatrix; UVector], AoS) ; + add_unqualified ("add_diag", ReturnType UMatrix, [UMatrix; URowVector], AoS) ; + add_unqualified + ("add_diag", ReturnType UComplexMatrix, [UComplexMatrix; UComplex], AoS) ; + add_unqualified + ( "add_diag" + , ReturnType UComplexMatrix + , [UComplexMatrix; UComplexVector] + , AoS ) ; + add_unqualified + ( "add_diag" + , ReturnType UComplexMatrix + , [UComplexMatrix; UComplexRowVector] + , AoS ) ; + add_qualified + ( "algebra_solver" + , ReturnType UVector + , [ ( AutoDiffable + , UFun + ( [ (AutoDiffable, UVector); (AutoDiffable, UVector) + ; (DataOnly, UArray UReal); (DataOnly, UArray UInt) ] + , ReturnType UVector + , FnPlain + , AoS ) ); (AutoDiffable, UVector); (AutoDiffable, UVector) + ; (DataOnly, UArray UReal); (DataOnly, UArray UInt) ] + , AoS ) ; + add_qualified + ( "algebra_solver" + , ReturnType UVector + , [ ( AutoDiffable + , UFun + ( [ (AutoDiffable, UVector); (AutoDiffable, UVector) + ; (DataOnly, UArray UReal); (DataOnly, UArray UInt) ] + , ReturnType UVector + , FnPlain + , Common.Helpers.AoS ) ); (AutoDiffable, UVector) + ; (AutoDiffable, UVector); (DataOnly, UArray UReal) + ; (DataOnly, UArray UInt); (DataOnly, UReal); (DataOnly, UReal) + ; (DataOnly, UReal) ] + , AoS ) ; + add_qualified + ( "algebra_solver_newton" + , ReturnType UVector + , [ ( AutoDiffable + , UFun + ( [ (AutoDiffable, UVector); (AutoDiffable, UVector) + ; (DataOnly, UArray UReal); (DataOnly, UArray UInt) ] + , ReturnType UVector + , FnPlain + , Common.Helpers.AoS ) ); (AutoDiffable, UVector) + ; (AutoDiffable, UVector); (DataOnly, UArray UReal) + ; (DataOnly, UArray UInt) ] + , AoS ) ; + add_qualified + ( "algebra_solver_newton" + , ReturnType UVector + , [ ( AutoDiffable + , UFun + ( [ (AutoDiffable, UVector); (AutoDiffable, UVector) + ; (DataOnly, UArray UReal); (DataOnly, UArray UInt) ] + , ReturnType UVector + , FnPlain + , Common.Helpers.AoS ) ); (AutoDiffable, UVector) + ; (AutoDiffable, UVector); (DataOnly, UArray UReal) + ; (DataOnly, UArray UInt); (DataOnly, UReal); (DataOnly, UReal) + ; (DataOnly, UReal) ] + , AoS ) ; + List.iter + ~f:(fun i -> + List.iter + ~f:(fun t -> + add_unqualified + ( "append_array" + , ReturnType (bare_array_type (t, i)) + , [bare_array_type (t, i); bare_array_type (t, i)] + , AoS ) ) + bare_types ) + (List.range 1 8) ; + add_unqualified ("arg", ReturnType UReal, [UComplex], AoS) ; + add_unqualified ("asin", ReturnType UComplex, [UComplex], AoS) ; + add_unqualified ("asinh", ReturnType UComplex, [UComplex], AoS) ; + add_unqualified ("atan", ReturnType UComplex, [UComplex], AoS) ; + add_unqualified ("atanh", ReturnType UComplex, [UComplex], AoS) ; + add_binary "atan2" AoS ; + add_unqualified + ( "bernoulli_logit_glm_lpmf" + , ReturnType UReal + , [UArray UInt; UMatrix; UVector; UVector] + , SoA ) ; + add_unqualified + ( "bernoulli_logit_glm_lpmf" + , ReturnType UReal + , [UInt; UMatrix; UVector; UVector] + , SoA ) ; + add_unqualified + ( "bernoulli_logit_glm_lpmf" + , ReturnType UReal + , [UArray UInt; URowVector; UReal; UVector] + , SoA ) ; + add_unqualified + ( "bernoulli_logit_glm_lpmf" + , ReturnType UReal + , [UArray UInt; URowVector; UVector; UVector] + , SoA ) ; + add_unqualified + ( "bernoulli_logit_glm_rng" + , ReturnType (UArray UInt) + , [UMatrix; UVector; UVector] + , AoS ) ; + add_unqualified + ( "bernoulli_logit_glm_rng" + , ReturnType (UArray UInt) + , [URowVector; UVector; UVector] + , AoS ) ; + add_binary_vec_int_real "bessel_first_kind" SoA ; + add_binary_vec_int_real "bessel_second_kind" SoA ; + add_binary_vec "beta" SoA ; + (* XXX For some reason beta_proportion_rng doesn't take ints as first arg *) + for_vector_types (fun t -> + for_all_vector_types (fun u -> + add_unqualified + ( "beta_proportion_rng" + , ReturnType (rng_return_type UReal [t; u]) + , [t; u] + , AoS ) ) ) ; + add_binary_vec_int_real "binary_log_loss" AoS ; + add_binary_vec "binomial_coefficient_log" AoS ; + add_unqualified + ("block", ReturnType UMatrix, [UMatrix; UInt; UInt; UInt; UInt], SoA) ; + add_unqualified + ( "block" + , ReturnType UComplexMatrix + , [UComplexMatrix; UInt; UInt; UInt; UInt] + , AoS ) ; + add_unqualified ("categorical_rng", ReturnType UInt, [UVector], AoS) ; + add_unqualified ("categorical_logit_rng", ReturnType UInt, [UVector], AoS) ; + add_unqualified + ( "categorical_logit_glm_lpmf" + , ReturnType UReal + , [UArray UInt; URowVector; UVector; UMatrix] + , SoA ) ; + add_unqualified + ( "categorical_logit_glm_lpmf" + , ReturnType UReal + , [UInt; URowVector; UVector; UMatrix] + , SoA ) ; + add_unqualified ("append_col", ReturnType UMatrix, [UMatrix; UMatrix], AoS) ; + add_unqualified ("append_col", ReturnType UMatrix, [UVector; UMatrix], AoS) ; + add_unqualified ("append_col", ReturnType UMatrix, [UMatrix; UVector], AoS) ; + add_unqualified ("append_col", ReturnType UMatrix, [UVector; UVector], AoS) ; + add_unqualified + ("append_col", ReturnType URowVector, [URowVector; URowVector], AoS) ; + add_unqualified ("append_col", ReturnType URowVector, [UReal; URowVector], AoS) ; + add_unqualified ("append_col", ReturnType URowVector, [URowVector; UReal], AoS) ; + add_unqualified + ( "append_col" + , ReturnType UComplexMatrix + , [UComplexMatrix; UComplexMatrix] + , AoS ) ; + add_unqualified + ( "append_col" + , ReturnType UComplexMatrix + , [UComplexVector; UComplexMatrix] + , AoS ) ; + add_unqualified + ( "append_col" + , ReturnType UComplexMatrix + , [UComplexMatrix; UComplexVector] + , AoS ) ; + add_unqualified + ( "append_col" + , ReturnType UComplexMatrix + , [UComplexVector; UComplexVector] + , AoS ) ; + add_unqualified + ( "append_col" + , ReturnType UComplexRowVector + , [UComplexRowVector; UComplexRowVector] + , AoS ) ; + add_unqualified + ( "append_col" + , ReturnType UComplexRowVector + , [UComplex; UComplexRowVector] + , AoS ) ; + add_unqualified + ( "append_col" + , ReturnType UComplexRowVector + , [UComplexRowVector; UComplex] + , AoS ) ; + add_unqualified ("chol2inv", ReturnType UMatrix, [UMatrix], AoS) ; + add_unqualified ("cholesky_decompose", ReturnType UMatrix, [UMatrix], SoA) ; + add_binary_vec_int_int "choose" AoS ; + add_unqualified ("col", ReturnType UVector, [UMatrix; UInt], AoS) ; + add_unqualified ("col", ReturnType UComplexVector, [UComplexMatrix; UInt], SoA) ; + add_unqualified ("cols", ReturnType UInt, [UVector], SoA) ; + add_unqualified ("cols", ReturnType UInt, [URowVector], SoA) ; + add_unqualified ("cols", ReturnType UInt, [UMatrix], SoA) ; + add_unqualified ("cols", ReturnType UInt, [UComplexVector], SoA) ; + add_unqualified ("cols", ReturnType UInt, [UComplexRowVector], SoA) ; + add_unqualified ("cols", ReturnType UInt, [UComplexMatrix], SoA) ; + add_unqualified + ("columns_dot_product", ReturnType URowVector, [UVector; UVector], AoS) ; + add_unqualified + ("columns_dot_product", ReturnType URowVector, [URowVector; URowVector], AoS) ; + add_unqualified + ("columns_dot_product", ReturnType URowVector, [UMatrix; UMatrix], SoA) ; + add_unqualified + ( "columns_dot_product" + , ReturnType UComplexRowVector + , [UComplexVector; UComplexVector] + , AoS ) ; + add_unqualified + ( "columns_dot_product" + , ReturnType UComplexRowVector + , [UComplexRowVector; UComplexRowVector] + , AoS ) ; + add_unqualified + ( "columns_dot_product" + , ReturnType UComplexRowVector + , [UComplexMatrix; UComplexMatrix] + , AoS ) ; + add_unqualified ("columns_dot_self", ReturnType URowVector, [UVector], AoS) ; + add_unqualified ("columns_dot_self", ReturnType URowVector, [URowVector], AoS) ; + add_unqualified ("columns_dot_self", ReturnType URowVector, [UMatrix], AoS) ; + add_unqualified + ("columns_dot_self", ReturnType UComplexRowVector, [UComplexVector], AoS) ; + add_unqualified + ("columns_dot_self", ReturnType UComplexRowVector, [UComplexRowVector], AoS) ; + add_unqualified + ("columns_dot_self", ReturnType UComplexRowVector, [UComplexMatrix], AoS) ; + add_unqualified ("conj", ReturnType UComplex, [UComplex], AoS) ; + add_unqualified ("cos", ReturnType UComplex, [UComplex], AoS) ; + add_unqualified ("cosh", ReturnType UComplex, [UComplex], AoS) ; + add_unqualified + ("cov_exp_quad", ReturnType UMatrix, [UArray UReal; UReal; UReal], AoS) ; + add_unqualified + ("cov_exp_quad", ReturnType UMatrix, [UArray UVector; UReal; UReal], AoS) ; + add_unqualified + ("cov_exp_quad", ReturnType UMatrix, [UArray URowVector; UReal; UReal], AoS) ; + add_unqualified + ( "cov_exp_quad" + , ReturnType UMatrix + , [UArray UReal; UArray UReal; UReal; UReal] + , AoS ) ; + add_unqualified + ( "cov_exp_quad" + , ReturnType UMatrix + , [UArray UVector; UArray UVector; UReal; UReal] + , AoS ) ; + add_unqualified + ( "cov_exp_quad" + , ReturnType UMatrix + , [UArray URowVector; UArray URowVector; UReal; UReal] + , AoS ) ; + add_unqualified ("crossprod", ReturnType UMatrix, [UMatrix], AoS) ; + add_unqualified + ( "csr_matrix_times_vector" + , ReturnType UVector + , [UInt; UInt; UVector; UArray UInt; UArray UInt; UVector] + , SoA ) ; + add_unqualified + ( "csr_to_dense_matrix" + , ReturnType UMatrix + , [UInt; UInt; UVector; UArray UInt; UArray UInt] + , AoS ) ; + add_unqualified ("csr_extract_w", ReturnType UVector, [UMatrix], AoS) ; + add_unqualified ("csr_extract_v", ReturnType (UArray UInt), [UMatrix], AoS) ; + add_unqualified ("csr_extract_u", ReturnType (UArray UInt), [UMatrix], AoS) ; + add_unqualified + ("cumulative_sum", ReturnType (UArray UInt), [UArray UInt], AoS) ; + add_unqualified + ("cumulative_sum", ReturnType (UArray UReal), [UArray UReal], AoS) ; + add_unqualified ("cumulative_sum", ReturnType UVector, [UVector], SoA) ; + add_unqualified ("cumulative_sum", ReturnType URowVector, [URowVector], SoA) ; + add_unqualified + ("cumulative_sum", ReturnType (UArray UComplex), [UArray UComplex], AoS) ; + add_unqualified + ("cumulative_sum", ReturnType UComplexVector, [UComplexVector], AoS) ; + add_unqualified + ("cumulative_sum", ReturnType UComplexRowVector, [UComplexRowVector], AoS) ; + add_unqualified ("determinant", ReturnType UReal, [UMatrix], SoA) ; + add_unqualified ("diag_matrix", ReturnType UMatrix, [UVector], AoS) ; + add_unqualified + ("diag_matrix", ReturnType UComplexMatrix, [UComplexVector], AoS) ; + add_unqualified + ("diag_post_multiply", ReturnType UMatrix, [UMatrix; UVector], SoA) ; + add_unqualified + ("diag_post_multiply", ReturnType UMatrix, [UMatrix; URowVector], SoA) ; + add_unqualified + ( "diag_post_multiply" + , ReturnType UComplexMatrix + , [UComplexMatrix; UComplexVector] + , AoS ) ; + add_unqualified + ( "diag_post_multiply" + , ReturnType UComplexMatrix + , [UComplexMatrix; UComplexRowVector] + , AoS ) ; + add_unqualified + ("diag_pre_multiply", ReturnType UMatrix, [UVector; UMatrix], SoA) ; + add_unqualified + ("diag_pre_multiply", ReturnType UMatrix, [URowVector; UMatrix], SoA) ; + add_unqualified + ( "diag_pre_multiply" + , ReturnType UComplexMatrix + , [UComplexVector; UComplexMatrix] + , AoS ) ; + add_unqualified + ( "diag_pre_multiply" + , ReturnType UComplexMatrix + , [UComplexRowVector; UComplexMatrix] + , AoS ) ; + add_unqualified ("diagonal", ReturnType UVector, [UMatrix], SoA) ; + add_unqualified ("diagonal", ReturnType UComplexVector, [UComplexMatrix], SoA) ; + add_unqualified ("dims", ReturnType (UArray UInt), [UComplex], AoS) ; + add_unqualified ("dims", ReturnType (UArray UInt), [UInt], SoA) ; + add_unqualified ("dims", ReturnType (UArray UInt), [UReal], SoA) ; + add_unqualified ("dims", ReturnType (UArray UInt), [UVector], SoA) ; + add_unqualified ("dims", ReturnType (UArray UInt), [URowVector], SoA) ; + add_unqualified ("dims", ReturnType (UArray UInt), [UMatrix], SoA) ; + List.iter + ~f:(fun i -> + List.iter + ~f:(fun t -> + add_unqualified + ("dims", ReturnType (UArray UInt), [bare_array_type (t, i + 1)], SoA) + ) + bare_types ) + (List.range 0 8) ; + add_unqualified ("dirichlet_rng", ReturnType UVector, [UVector], AoS) ; + add_unqualified ("distance", ReturnType UReal, [UVector; UVector], SoA) ; + add_unqualified ("distance", ReturnType UReal, [URowVector; URowVector], SoA) ; + add_unqualified ("distance", ReturnType UReal, [UVector; URowVector], SoA) ; + add_unqualified ("distance", ReturnType UReal, [URowVector; UVector], SoA) ; + add_unqualified ("divide", ReturnType UComplex, [UComplex; UComplex], AoS) ; + add_unqualified ("divide", ReturnType UInt, [UInt; UInt], SoA) ; + add_unqualified ("divide", ReturnType UReal, [UReal; UReal], SoA) ; + add_unqualified ("divide", ReturnType UVector, [UVector; UReal], SoA) ; + add_unqualified ("divide", ReturnType URowVector, [URowVector; UReal], SoA) ; + add_unqualified ("divide", ReturnType UMatrix, [UMatrix; UReal], SoA) ; + add_unqualified ("dot_product", ReturnType UReal, [UVector; UVector], SoA) ; + add_unqualified + ("dot_product", ReturnType UReal, [URowVector; URowVector], SoA) ; + add_unqualified ("dot_product", ReturnType UReal, [UVector; URowVector], SoA) ; + add_unqualified ("dot_product", ReturnType UReal, [URowVector; UVector], SoA) ; + add_unqualified + ("dot_product", ReturnType UReal, [UArray UReal; UArray UReal], SoA) ; + add_unqualified + ("dot_product", ReturnType UComplex, [UComplexVector; UComplexVector], AoS) ; + add_unqualified + ( "dot_product" + , ReturnType UComplex + , [UComplexRowVector; UComplexRowVector] + , AoS ) ; + add_unqualified + ( "dot_product" + , ReturnType UComplex + , [UComplexVector; UComplexRowVector] + , AoS ) ; + add_unqualified + ( "dot_product" + , ReturnType UComplex + , [UComplexRowVector; UComplexVector] + , AoS ) ; + add_unqualified + ("dot_product", ReturnType UComplex, [UArray UComplex; UArray UComplex], AoS) ; + add_unqualified ("dot_self", ReturnType UReal, [UVector], SoA) ; + add_unqualified ("dot_self", ReturnType UReal, [URowVector], SoA) ; + add_unqualified ("dot_self", ReturnType UComplex, [UComplexVector], AoS) ; + add_unqualified ("dot_self", ReturnType UComplex, [UComplexRowVector], AoS) ; + add_nullary "e" ; + add_unqualified ("eigenvalues_sym", ReturnType UVector, [UMatrix], AoS) ; + add_unqualified ("eigenvectors_sym", ReturnType UMatrix, [UMatrix], AoS) ; + add_unqualified ("generalized_inverse", ReturnType UMatrix, [UMatrix], SoA) ; + add_unqualified ("qr_Q", ReturnType UMatrix, [UMatrix], AoS) ; + add_unqualified ("qr_R", ReturnType UMatrix, [UMatrix], AoS) ; + add_unqualified ("qr_thin_Q", ReturnType UMatrix, [UMatrix], AoS) ; + add_unqualified ("qr_thin_R", ReturnType UMatrix, [UMatrix], AoS) ; + List.iter + ~f:(fun x -> add_unqualified ("elt_divide", ReturnType x, [x; x], SoA)) + bare_types ; + add_unqualified ("elt_divide", ReturnType UVector, [UVector; UReal], SoA) ; + add_unqualified ("elt_divide", ReturnType URowVector, [URowVector; UReal], SoA) ; + add_unqualified ("elt_divide", ReturnType UMatrix, [UMatrix; UReal], SoA) ; + add_unqualified ("elt_divide", ReturnType UVector, [UReal; UVector], SoA) ; + add_unqualified ("elt_divide", ReturnType URowVector, [UReal; URowVector], SoA) ; + add_unqualified ("elt_divide", ReturnType UMatrix, [UReal; UMatrix], SoA) ; + List.iter + ~f:(fun x -> add_unqualified ("elt_multiply", ReturnType x, [x; x], SoA)) + bare_types ; + add_unqualified ("exp", ReturnType UComplex, [UComplex], AoS) ; + add_binary_vec_int_int "falling_factorial" SoA ; + add_binary_vec_real_int "falling_factorial" SoA ; + add_binary_vec "fdim" AoS ; + add_ternary_vec "fma" SoA ; + add_binary_vec "fmax" AoS ; + add_binary_vec "fmin" AoS ; + add_binary_vec "fmod" AoS ; + add_binary_vec_real_real "gamma_p" AoS ; + add_binary_vec_real_real "gamma_q" AoS ; + add_unqualified + ( "gaussian_dlm_obs_log" + , ReturnType UReal + , [UMatrix; UMatrix; UMatrix; UVector; UMatrix; UVector; UMatrix] + , AoS ) ; + add_unqualified + ( "gaussian_dlm_obs_lpdf" + , ReturnType UReal + , [UMatrix; UMatrix; UMatrix; UVector; UMatrix; UVector; UMatrix] + , AoS ) ; + List.iter + ~f:(fun i -> + List.iter + ~f:(fun t -> + add_unqualified + ( "get_imag" + , ReturnType (bare_array_type (complex_to_real t, i)) + , [bare_array_type (t, i)] + , AoS ) ) + complex_types ) + (List.range 0 8) ; + List.iter + ~f:(fun i -> + List.iter + ~f:(fun t -> + add_unqualified + ( "get_real" + , ReturnType (bare_array_type (complex_to_real t, i)) + , [bare_array_type (t, i)] + , AoS ) ) + complex_types ) + (List.range 0 8) ; + add_unqualified + ("gp_dot_prod_cov", ReturnType UMatrix, [UArray UReal; UReal], AoS) ; + add_unqualified + ( "gp_dot_prod_cov" + , ReturnType UMatrix + , [UArray UReal; UArray UReal; UReal] + , AoS ) ; + add_unqualified + ("gp_dot_prod_cov", ReturnType UMatrix, [UArray UVector; UReal], AoS) ; + add_unqualified + ( "gp_dot_prod_cov" + , ReturnType UMatrix + , [UArray UVector; UArray UVector; UReal] + , AoS ) ; + add_unqualified + ("gp_exp_quad_cov", ReturnType UMatrix, [UArray UReal; UReal; UReal], AoS) ; + add_unqualified + ( "gp_exp_quad_cov" + , ReturnType UMatrix + , [UArray UReal; UArray UReal; UReal; UReal] + , AoS ) ; + add_unqualified + ("gp_exp_quad_cov", ReturnType UMatrix, [UArray UVector; UReal; UReal], AoS) ; + add_unqualified + ( "gp_exp_quad_cov" + , ReturnType UMatrix + , [UArray UVector; UArray UVector; UReal; UReal] + , AoS ) ; + add_unqualified + ( "gp_exp_quad_cov" + , ReturnType UMatrix + , [UArray UVector; UReal; UArray UReal] + , AoS ) ; + add_unqualified + ( "gp_exp_quad_cov" + , ReturnType UMatrix + , [UArray UVector; UArray UVector; UReal; UArray UReal] + , AoS ) ; + add_unqualified + ("gp_matern32_cov", ReturnType UMatrix, [UArray UReal; UReal; UReal], AoS) ; + add_unqualified + ( "gp_matern32_cov" + , ReturnType UMatrix + , [UArray UReal; UArray UReal; UReal; UReal] + , AoS ) ; + add_unqualified + ("gp_matern32_cov", ReturnType UMatrix, [UArray UVector; UReal; UReal], AoS) ; + add_unqualified + ( "gp_matern32_cov" + , ReturnType UMatrix + , [UArray UVector; UArray UVector; UReal; UReal] + , AoS ) ; + add_unqualified + ( "gp_matern32_cov" + , ReturnType UMatrix + , [UArray UVector; UReal; UArray UReal] + , AoS ) ; + add_unqualified + ( "gp_matern32_cov" + , ReturnType UMatrix + , [UArray UVector; UArray UVector; UReal; UArray UReal] + , AoS ) ; + add_unqualified + ("gp_matern52_cov", ReturnType UMatrix, [UArray UReal; UReal; UReal], AoS) ; + add_unqualified + ( "gp_matern52_cov" + , ReturnType UMatrix + , [UArray UReal; UArray UReal; UReal; UReal] + , AoS ) ; + add_unqualified + ("gp_matern52_cov", ReturnType UMatrix, [UArray UVector; UReal; UReal], AoS) ; + add_unqualified + ( "gp_matern52_cov" + , ReturnType UMatrix + , [UArray UVector; UArray UVector; UReal; UReal] + , AoS ) ; + add_unqualified + ( "gp_matern52_cov" + , ReturnType UMatrix + , [UArray UVector; UReal; UArray UReal] + , AoS ) ; + add_unqualified + ( "gp_matern52_cov" + , ReturnType UMatrix + , [UArray UVector; UArray UVector; UReal; UArray UReal] + , AoS ) ; + add_unqualified + ("gp_exponential_cov", ReturnType UMatrix, [UArray UReal; UReal; UReal], AoS) ; + add_unqualified + ( "gp_exponential_cov" + , ReturnType UMatrix + , [UArray UReal; UArray UReal; UReal; UReal] + , AoS ) ; + add_unqualified + ( "gp_exponential_cov" + , ReturnType UMatrix + , [UArray UVector; UReal; UReal] + , AoS ) ; + add_unqualified + ( "gp_exponential_cov" + , ReturnType UMatrix + , [UArray UVector; UArray UVector; UReal; UReal] + , AoS ) ; + add_unqualified + ( "gp_exponential_cov" + , ReturnType UMatrix + , [UArray UVector; UReal; UArray UReal] + , AoS ) ; + add_unqualified + ( "gp_exponential_cov" + , ReturnType UMatrix + , [UArray UVector; UArray UVector; UReal; UArray UReal] + , AoS ) ; + add_unqualified + ( "gp_periodic_cov" + , ReturnType UMatrix + , [UArray UReal; UReal; UReal; UReal] + , AoS ) ; + add_unqualified + ( "gp_periodic_cov" + , ReturnType UMatrix + , [UArray UReal; UArray UReal; UReal; UReal; UReal] + , AoS ) ; + add_unqualified + ( "gp_periodic_cov" + , ReturnType UMatrix + , [UArray UVector; UReal; UReal; UReal] + , AoS ) ; + add_unqualified + ( "gp_periodic_cov" + , ReturnType UMatrix + , [UArray UVector; UArray UVector; UReal; UReal; UReal] + , AoS ) ; + (* ; add_nullary ("get_lp") *) + add_unqualified ("head", ReturnType URowVector, [URowVector; UInt], SoA) ; + add_unqualified ("head", ReturnType UVector, [UVector; UInt], SoA) ; + add_unqualified + ("head", ReturnType UComplexRowVector, [UComplexRowVector; UInt], AoS) ; + add_unqualified + ("head", ReturnType UComplexVector, [UComplexVector; UInt], AoS) ; + List.iter + ~f:(fun t -> + List.iter + ~f:(fun j -> + add_unqualified + ( "head" + , ReturnType (bare_array_type (t, j)) + , [bare_array_type (t, j); UInt] + , SoA ) ) + (List.range 1 4) ) + bare_types ; + add_unqualified + ("hmm_marginal", ReturnType UReal, [UMatrix; UMatrix; UVector], AoS) ; + add_qualified + ( "hmm_hidden_state_prob" + , ReturnType UMatrix + , [(DataOnly, UMatrix); (DataOnly, UMatrix); (DataOnly, UVector)] + , AoS ) ; + add_binary_vec "hypot" AoS ; + add_unqualified ("identity_matrix", ReturnType UMatrix, [UInt], SoA) ; + add_unqualified ("if_else", ReturnType UInt, [UInt; UInt; UInt], SoA) ; + add_unqualified ("if_else", ReturnType UReal, [UInt; UReal; UReal], SoA) ; + add_unqualified ("inc_beta", ReturnType UReal, [UReal; UReal; UReal], SoA) ; + add_unqualified ("int_step", ReturnType UInt, [UReal], SoA) ; + add_unqualified ("int_step", ReturnType UInt, [UInt], SoA) ; + add_qualified + ( "integrate_1d" + , ReturnType UReal + , [ ( AutoDiffable + , UFun + ( [ (AutoDiffable, UReal); (AutoDiffable, UReal) + ; (AutoDiffable, UArray UReal); (DataOnly, UArray UReal) + ; (DataOnly, UArray UInt) ] + , ReturnType UReal + , FnPlain + , Common.Helpers.AoS ) ); (AutoDiffable, UReal) + ; (AutoDiffable, UReal); (AutoDiffable, UArray UReal) + ; (DataOnly, UArray UReal); (DataOnly, UArray UInt) ] + , AoS ) ; + add_qualified + ( "integrate_1d" + , ReturnType UReal + , [ ( AutoDiffable + , UFun + ( [ (AutoDiffable, UReal); (AutoDiffable, UReal) + ; (AutoDiffable, UArray UReal); (DataOnly, UArray UReal) + ; (DataOnly, UArray UInt) ] + , ReturnType UReal + , FnPlain + , Common.Helpers.AoS ) ); (AutoDiffable, UReal) + ; (AutoDiffable, UReal); (AutoDiffable, UArray UReal) + ; (DataOnly, UArray UReal); (DataOnly, UArray UInt); (DataOnly, UReal) ] + , AoS ) ; + add_qualified + ( "integrate_ode" + , ReturnType (UArray (UArray UReal)) + , [ ( AutoDiffable + , UFun + ( [ (AutoDiffable, UReal); (AutoDiffable, UArray UReal) + ; (AutoDiffable, UArray UReal); (DataOnly, UArray UReal) + ; (DataOnly, UArray UInt) ] + , ReturnType (UArray UReal) + , FnPlain + , Common.Helpers.AoS ) ); (AutoDiffable, UArray UReal) + ; (AutoDiffable, UReal); (AutoDiffable, UArray UReal) + ; (AutoDiffable, UArray UReal); (DataOnly, UArray UReal) + ; (DataOnly, UArray UInt) ] + , AoS ) ; + add_qualified + ( "integrate_ode_adams" + , ReturnType (UArray (UArray UReal)) + , [ ( AutoDiffable + , UFun + ( [ (AutoDiffable, UReal); (AutoDiffable, UArray UReal) + ; (AutoDiffable, UArray UReal); (DataOnly, UArray UReal) + ; (DataOnly, UArray UInt) ] + , ReturnType (UArray UReal) + , FnPlain + , Common.Helpers.AoS ) ); (AutoDiffable, UArray UReal) + ; (AutoDiffable, UReal); (AutoDiffable, UArray UReal) + ; (AutoDiffable, UArray UReal); (DataOnly, UArray UReal) + ; (DataOnly, UArray UInt) ] + , AoS ) ; + add_qualified + ( "integrate_ode_adams" + , ReturnType (UArray (UArray UReal)) + , [ ( AutoDiffable + , UFun + ( [ (AutoDiffable, UReal); (AutoDiffable, UArray UReal) + ; (AutoDiffable, UArray UReal); (DataOnly, UArray UReal) + ; (DataOnly, UArray UInt) ] + , ReturnType (UArray UReal) + , FnPlain + , Common.Helpers.AoS ) ); (AutoDiffable, UArray UReal) + ; (AutoDiffable, UReal); (AutoDiffable, UArray UReal) + ; (AutoDiffable, UArray UReal); (DataOnly, UArray UReal) + ; (DataOnly, UArray UInt); (DataOnly, UReal); (DataOnly, UReal) + ; (DataOnly, UReal) ] + , AoS ) ; + add_qualified + ( "integrate_ode_bdf" + , ReturnType (UArray (UArray UReal)) + , [ ( AutoDiffable + , UFun + ( [ (AutoDiffable, UReal); (AutoDiffable, UArray UReal) + ; (AutoDiffable, UArray UReal); (DataOnly, UArray UReal) + ; (DataOnly, UArray UInt) ] + , ReturnType (UArray UReal) + , FnPlain + , Common.Helpers.AoS ) ); (AutoDiffable, UArray UReal) + ; (AutoDiffable, UReal); (AutoDiffable, UArray UReal) + ; (AutoDiffable, UArray UReal); (DataOnly, UArray UReal) + ; (DataOnly, UArray UInt) ] + , AoS ) ; + add_qualified + ( "integrate_ode_bdf" + , ReturnType (UArray (UArray UReal)) + , [ ( AutoDiffable + , UFun + ( [ (AutoDiffable, UReal); (AutoDiffable, UArray UReal) + ; (AutoDiffable, UArray UReal); (DataOnly, UArray UReal) + ; (DataOnly, UArray UInt) ] + , ReturnType (UArray UReal) + , FnPlain + , Common.Helpers.AoS ) ); (AutoDiffable, UArray UReal) + ; (AutoDiffable, UReal); (AutoDiffable, UArray UReal) + ; (AutoDiffable, UArray UReal); (DataOnly, UArray UReal) + ; (DataOnly, UArray UInt); (DataOnly, UReal); (DataOnly, UReal) + ; (DataOnly, UReal) ] + , AoS ) ; + add_qualified + ( "integrate_ode_rk45" + , ReturnType (UArray (UArray UReal)) + , [ ( AutoDiffable + , UFun + ( [ (AutoDiffable, UReal); (AutoDiffable, UArray UReal) + ; (AutoDiffable, UArray UReal); (DataOnly, UArray UReal) + ; (DataOnly, UArray UInt) ] + , ReturnType (UArray UReal) + , FnPlain + , Common.Helpers.AoS ) ); (AutoDiffable, UArray UReal) + ; (AutoDiffable, UReal); (AutoDiffable, UArray UReal) + ; (AutoDiffable, UArray UReal); (DataOnly, UArray UReal) + ; (DataOnly, UArray UInt) ] + , AoS ) ; + add_qualified + ( "integrate_ode_rk45" + , ReturnType (UArray (UArray UReal)) + , [ ( AutoDiffable + , UFun + ( [ (AutoDiffable, UReal); (AutoDiffable, UArray UReal) + ; (AutoDiffable, UArray UReal); (DataOnly, UArray UReal) + ; (DataOnly, UArray UInt) ] + , ReturnType (UArray UReal) + , FnPlain + , Common.Helpers.AoS ) ); (AutoDiffable, UArray UReal) + ; (AutoDiffable, UReal); (AutoDiffable, UArray UReal) + ; (AutoDiffable, UArray UReal); (DataOnly, UArray UReal) + ; (DataOnly, UArray UInt); (DataOnly, UReal); (DataOnly, UReal) + ; (DataOnly, UReal) ] + , AoS ) ; + add_unqualified ("inv_wishart_rng", ReturnType UMatrix, [UReal; UMatrix], AoS) ; + add_unqualified ("inverse", ReturnType UMatrix, [UMatrix], SoA) ; + add_unqualified ("inverse_spd", ReturnType UMatrix, [UMatrix], AoS) ; + add_unqualified ("is_inf", ReturnType UInt, [UReal], SoA) ; + add_unqualified ("is_nan", ReturnType UInt, [UReal], SoA) ; + add_binary_vec "lbeta" AoS ; + add_binary_vec "lchoose" AoS ; + add_binary_vec_real_int "ldexp" AoS ; + add_qualified + ( "linspaced_int_array" + , ReturnType (UArray UInt) + , [(DataOnly, UInt); (DataOnly, UInt); (DataOnly, UInt)] + , SoA ) ; + add_qualified + ( "linspaced_array" + , ReturnType (UArray UReal) + , [(DataOnly, UInt); (DataOnly, UReal); (DataOnly, UReal)] + , SoA ) ; + add_qualified + ( "linspaced_row_vector" + , ReturnType URowVector + , [(DataOnly, UInt); (DataOnly, UReal); (DataOnly, UReal)] + , SoA ) ; + add_qualified + ( "linspaced_vector" + , ReturnType UVector + , [(DataOnly, UInt); (DataOnly, UReal); (DataOnly, UReal)] + , SoA ) ; + add_unqualified + ("lkj_corr_cholesky_rng", ReturnType UMatrix, [UInt; UReal], AoS) ; + add_unqualified ("lkj_corr_rng", ReturnType UMatrix, [UInt; UReal], AoS) ; + add_unqualified + ("lkj_cov_log", ReturnType UReal, [UMatrix; UVector; UVector; UReal], AoS) ; + add_binary_vec_int_real "lmgamma" AoS ; + add_binary_vec "lmultiply" SoA ; + add_unqualified ("log", ReturnType UComplex, [UComplex], AoS) ; + add_nullary "log10" ; + add_unqualified ("log10", ReturnType UComplex, [UComplex], AoS) ; + add_nullary "log2" ; + add_unqualified ("log_determinant", ReturnType UReal, [UMatrix], SoA) ; + add_binary_vec "log_diff_exp" AoS ; + add_binary_vec "log_falling_factorial" AoS ; + add_binary_vec "log_inv_logit_diff" AoS ; + add_ternary "log_mix" AoS ; + List.iter + ~f:(fun v1 -> + List.iter + ~f:(fun v2 -> + add_unqualified ("log_mix", ReturnType UReal, [v1; v2], AoS) ) + (List.tl_exn vector_types) ; + add_unqualified ("log_mix", ReturnType UReal, [v1; UArray UVector], AoS) ; + add_unqualified ("log_mix", ReturnType UReal, [v1; UArray URowVector], AoS) + ) + (List.tl_exn vector_types) ; + add_binary_vec "log_modified_bessel_first_kind" AoS ; + add_binary_vec "log_rising_factorial" AoS ; + add_unqualified ("log_softmax", ReturnType UVector, [UVector], SoA) ; + add_unqualified ("log_sum_exp", ReturnType UReal, [UArray UReal], SoA) ; + add_unqualified ("log_sum_exp", ReturnType UReal, [UVector], SoA) ; + add_unqualified ("log_sum_exp", ReturnType UReal, [URowVector], SoA) ; + add_unqualified ("log_sum_exp", ReturnType UReal, [UMatrix], SoA) ; + add_binary "log_sum_exp" SoA ; + let logical_binops = + [ "logical_or"; "logical_and"; "logical_eq"; "logical_neq"; "logical_lt" + ; "logical_lte"; "logical_gt"; "logical_gte" ] in + List.iter + ~f:(fun t1 -> + add_unqualified ("logical_negation", ReturnType UInt, [t1], SoA) ; + List.iter + ~f:(fun t2 -> + List.iter + ~f:(fun o -> add_unqualified (o, ReturnType UInt, [t1; t2], SoA)) + logical_binops ) + primitive_types ) + primitive_types ; + add_unqualified ("logical_eq", ReturnType UInt, [UComplex; UReal], SoA) ; + add_unqualified ("logical_eq", ReturnType UInt, [UComplex; UComplex], SoA) ; + add_unqualified ("logical_neq", ReturnType UInt, [UComplex; UReal], SoA) ; + add_unqualified ("logical_neq", ReturnType UInt, [UComplex; UComplex], SoA) ; + add_nullary "machine_precision" ; + add_qualified + ( "map_rect" + , ReturnType UVector + , [ ( AutoDiffable + , UFun + ( [ (AutoDiffable, UVector); (AutoDiffable, UVector) + ; (DataOnly, UArray UReal); (DataOnly, UArray UInt) ] + , ReturnType UVector + , FnPlain + , Common.Helpers.AoS ) ); (AutoDiffable, UVector) + ; (AutoDiffable, UArray UVector); (DataOnly, UArray (UArray UReal)) + ; (DataOnly, UArray (UArray UInt)) ] + , AoS ) ; + add_unqualified ("matrix_exp", ReturnType UMatrix, [UMatrix], AoS) ; + add_unqualified + ("matrix_exp_multiply", ReturnType UMatrix, [UMatrix; UMatrix], AoS) ; + add_unqualified ("matrix_power", ReturnType UMatrix, [UMatrix; UInt], SoA) ; + add_unqualified ("max", ReturnType UInt, [UArray UInt], AoS) ; + add_unqualified ("max", ReturnType UReal, [UArray UReal], AoS) ; + add_unqualified ("max", ReturnType UReal, [UVector], AoS) ; + add_unqualified ("max", ReturnType UReal, [URowVector], AoS) ; + add_unqualified ("max", ReturnType UReal, [UMatrix], AoS) ; + add_unqualified ("max", ReturnType UInt, [UInt; UInt], AoS) ; + add_unqualified ("mdivide_left", ReturnType UVector, [UMatrix; UVector], SoA) ; + add_unqualified ("mdivide_left", ReturnType UMatrix, [UMatrix; UMatrix], SoA) ; + add_unqualified + ("mdivide_left_spd", ReturnType UVector, [UMatrix; UVector], SoA) ; + add_unqualified + ("mdivide_left_spd", ReturnType UMatrix, [UMatrix; UMatrix], SoA) ; + add_unqualified + ("mdivide_left_tri_low", ReturnType UMatrix, [UMatrix; UMatrix], AoS) ; + add_unqualified + ("mdivide_left_tri_low", ReturnType UVector, [UMatrix; UVector], AoS) ; + add_unqualified + ("mdivide_right", ReturnType URowVector, [URowVector; UMatrix], AoS) ; + add_unqualified ("mdivide_right", ReturnType UMatrix, [UMatrix; UMatrix], AoS) ; + add_unqualified + ( "mdivide_right" + , ReturnType UComplexRowVector + , [UComplexRowVector; UComplexMatrix] + , AoS ) ; + add_unqualified + ( "mdivide_right" + , ReturnType UComplexMatrix + , [UComplexMatrix; UComplexMatrix] + , AoS ) ; + add_unqualified + ("mdivide_right_spd", ReturnType UMatrix, [UMatrix; UMatrix], AoS) ; + add_unqualified + ("mdivide_right_spd", ReturnType URowVector, [URowVector; UMatrix], AoS) ; + add_unqualified + ("mdivide_right_tri_low", ReturnType URowVector, [URowVector; UMatrix], AoS) ; + add_unqualified + ("mdivide_right_tri_low", ReturnType UMatrix, [UMatrix; UMatrix], AoS) ; + add_unqualified ("mean", ReturnType UReal, [UArray UReal], SoA) ; + add_unqualified ("mean", ReturnType UReal, [UVector], AoS) ; + add_unqualified ("mean", ReturnType UReal, [URowVector], AoS) ; + add_unqualified ("mean", ReturnType UReal, [UMatrix], AoS) ; + add_unqualified ("min", ReturnType UInt, [UArray UInt], AoS) ; + add_unqualified ("min", ReturnType UReal, [UArray UReal], AoS) ; + add_unqualified ("min", ReturnType UReal, [UVector], AoS) ; + add_unqualified ("min", ReturnType UReal, [URowVector], AoS) ; + add_unqualified ("min", ReturnType UReal, [UMatrix], AoS) ; + add_unqualified ("min", ReturnType UInt, [UInt; UInt], AoS) ; + List.iter + ~f:(fun x -> add_unqualified ("minus", ReturnType x, [x], SoA)) + bare_types ; + add_binary_vec_int_real "modified_bessel_first_kind" AoS ; + add_binary_vec_int_real "modified_bessel_second_kind" AoS ; + add_unqualified ("modulus", ReturnType UInt, [UInt; UInt], AoS) ; + add_unqualified + ("multi_normal_rng", ReturnType UVector, [UVector; UMatrix], AoS) ; + add_unqualified + ( "multi_normal_rng" + , ReturnType (UArray UVector) + , [UArray UVector; UMatrix] + , AoS ) ; + add_unqualified + ("multi_normal_rng", ReturnType UVector, [URowVector; UMatrix], AoS) ; + add_unqualified + ( "multi_normal_rng" + , ReturnType (UArray UVector) + , [UArray URowVector; UMatrix] + , AoS ) ; + add_unqualified + ("multi_normal_cholesky_rng", ReturnType UVector, [UVector; UMatrix], AoS) ; + add_unqualified + ( "multi_normal_cholesky_rng" + , ReturnType (UArray UVector) + , [UArray UVector; UMatrix] + , AoS ) ; + add_unqualified + ("multi_normal_cholesky_rng", ReturnType UVector, [URowVector; UMatrix], AoS) ; + add_unqualified + ( "multi_normal_cholesky_rng" + , ReturnType (UArray UVector) + , [UArray URowVector; UMatrix] + , AoS ) ; + add_unqualified + ("multi_student_t_rng", ReturnType UVector, [UReal; UVector; UMatrix], AoS) ; + add_unqualified + ( "multi_student_t_rng" + , ReturnType (UArray UVector) + , [UReal; UArray UVector; UMatrix] + , AoS ) ; + add_unqualified + ( "multi_student_t_rng" + , ReturnType UVector + , [UReal; URowVector; UMatrix] + , AoS ) ; + add_unqualified + ( "multi_student_t_rng" + , ReturnType (UArray UVector) + , [UReal; UArray URowVector; UMatrix] + , AoS ) ; + add_unqualified + ("multinomial_logit_rng", ReturnType (UArray UInt), [UVector; UInt], AoS) ; + add_unqualified + ("multinomial_rng", ReturnType (UArray UInt), [UVector; UInt], AoS) ; + add_unqualified ("multiply", ReturnType UComplex, [UComplex; UComplex], AoS) ; + add_unqualified ("multiply", ReturnType UInt, [UInt; UInt], SoA) ; + add_unqualified ("multiply", ReturnType UReal, [UReal; UReal], SoA) ; + add_unqualified ("multiply", ReturnType UVector, [UVector; UReal], SoA) ; + add_unqualified ("multiply", ReturnType URowVector, [URowVector; UReal], SoA) ; + add_unqualified ("multiply", ReturnType UMatrix, [UMatrix; UReal], SoA) ; + add_unqualified ("multiply", ReturnType UReal, [URowVector; UVector], SoA) ; + add_unqualified ("multiply", ReturnType UMatrix, [UVector; URowVector], SoA) ; + add_unqualified ("multiply", ReturnType UVector, [UMatrix; UVector], SoA) ; + add_unqualified ("multiply", ReturnType URowVector, [URowVector; UMatrix], SoA) ; + add_unqualified ("multiply", ReturnType UMatrix, [UMatrix; UMatrix], SoA) ; + add_unqualified ("multiply", ReturnType UVector, [UReal; UVector], SoA) ; + add_unqualified ("multiply", ReturnType URowVector, [UReal; URowVector], SoA) ; + add_unqualified ("multiply", ReturnType UMatrix, [UReal; UMatrix], SoA) ; + (* TODO more complex overloads *) + add_unqualified + ( "multiply" + , ReturnType UComplexMatrix + , [UComplexMatrix; UComplexMatrix] + , SoA ) ; + add_unqualified + ("multiply", ReturnType UComplexMatrix, [UComplexMatrix; UComplex], SoA) ; + add_unqualified + ("multiply", ReturnType UComplexMatrix, [UComplex; UComplexMatrix], SoA) ; + add_unqualified + ( "multiply" + , ReturnType UComplexMatrix + , [UComplexVector; UComplexRowVector] + , SoA ) ; + add_unqualified + ("multiply", ReturnType UComplex, [UComplexRowVector; UComplexVector], SoA) ; + add_unqualified + ( "multiply" + , ReturnType UComplexVector + , [UComplexMatrix; UComplexVector] + , SoA ) ; + add_unqualified + ("multiply", ReturnType UComplexVector, [UComplexVector; UComplex], SoA) ; + add_unqualified + ("multiply", ReturnType UComplexVector, [UComplex; UComplexVector], SoA) ; + add_unqualified + ( "multiply" + , ReturnType UComplexRowVector + , [UComplexRowVector; UComplex] + , SoA ) ; + add_unqualified + ( "multiply" + , ReturnType UComplexRowVector + , [UComplex; UComplexRowVector] + , SoA ) ; + add_unqualified + ( "multiply" + , ReturnType UComplexRowVector + , [UComplexRowVector; UComplexMatrix] + , SoA ) ; + add_binary_vec "multiply_log" SoA ; + add_unqualified + ("multiply_lower_tri_self_transpose", ReturnType UMatrix, [UMatrix], SoA) ; + add_unqualified + ( "neg_binomial_2_log_glm_lpmf" + , ReturnType UReal + , [UArray UInt; UMatrix; UVector; UVector; UReal] + , SoA ) ; + add_unqualified + ( "neg_binomial_2_log_glm_lpmf" + , ReturnType UReal + , [UInt; UMatrix; UVector; UVector; UReal] + , SoA ) ; + add_unqualified + ( "neg_binomial_2_log_glm_lpmf" + , ReturnType UReal + , [UArray UInt; URowVector; UReal; UVector; UReal] + , SoA ) ; + add_unqualified + ( "neg_binomial_2_log_glm_lpmf" + , ReturnType UReal + , [UArray UInt; URowVector; UVector; UVector; UReal] + , SoA ) ; + add_nullary "negative_infinity" ; + add_unqualified ("norm", ReturnType UReal, [UComplex], AoS) ; + add_unqualified + ( "normal_id_glm_lpdf" + , ReturnType UReal + , [UVector; UMatrix; UVector; UVector; UReal] + , SoA ) ; + add_unqualified + ( "normal_id_glm_lpdf" + , ReturnType UReal + , [UReal; UMatrix; UReal; UVector; UReal] + , SoA ) ; + add_unqualified + ( "normal_id_glm_lpdf" + , ReturnType UReal + , [UReal; UMatrix; UVector; UVector; UReal] + , SoA ) ; + add_unqualified + ( "normal_id_glm_lpdf" + , ReturnType UReal + , [UReal; UMatrix; UReal; UVector; UVector] + , SoA ) ; + add_unqualified + ( "normal_id_glm_lpdf" + , ReturnType UReal + , [UReal; UMatrix; UVector; UVector; UVector] + , SoA ) ; + add_unqualified + ( "normal_id_glm_lpdf" + , ReturnType UReal + , [UVector; URowVector; UReal; UVector; UVector] + , SoA ) ; + add_unqualified + ( "normal_id_glm_lpdf" + , ReturnType UReal + , [UVector; URowVector; UVector; UVector; UReal] + , SoA ) ; + add_unqualified + ( "normal_id_glm_lpdf" + , ReturnType UReal + , [UVector; URowVector; UVector; UVector; UVector] + , SoA ) ; + add_unqualified + ( "normal_id_glm_lpdf" + , ReturnType UReal + , [UVector; URowVector; UReal; UVector; UReal] + , SoA ) ; + add_unqualified + ( "normal_id_glm_lpdf" + , ReturnType UReal + , [UVector; UMatrix; UReal; UVector; UVector] + , SoA ) ; + add_unqualified + ( "normal_id_glm_lpdf" + , ReturnType UReal + , [UVector; UMatrix; UVector; UVector; UVector] + , SoA ) ; + add_nullary "not_a_number" ; + add_unqualified ("num_elements", ReturnType UInt, [UMatrix], SoA) ; + add_unqualified ("num_elements", ReturnType UInt, [UVector], SoA) ; + add_unqualified ("num_elements", ReturnType UInt, [URowVector], SoA) ; + add_unqualified ("num_elements", ReturnType UInt, [UComplexMatrix], SoA) ; + add_unqualified ("num_elements", ReturnType UInt, [UComplexVector], SoA) ; + add_unqualified ("num_elements", ReturnType UInt, [UComplexRowVector], SoA) ; + List.iter + ~f:(fun i -> + List.iter + ~f:(fun t -> + add_unqualified + ("num_elements", ReturnType UInt, [bare_array_type (t, i)], SoA) ) + bare_types ) + (List.range 1 10) ; + add_unqualified + ("one_hot_int_array", ReturnType (UArray UInt), [UInt; UInt], SoA) ; + add_unqualified ("one_hot_array", ReturnType (UArray UReal), [UInt; UInt], SoA) ; + add_unqualified + ("one_hot_row_vector", ReturnType URowVector, [UInt; UInt], SoA) ; + add_unqualified ("one_hot_vector", ReturnType UVector, [UInt; UInt], SoA) ; + add_unqualified ("ones_int_array", ReturnType (UArray UInt), [UInt], SoA) ; + add_unqualified ("ones_array", ReturnType (UArray UReal), [UInt], SoA) ; + add_unqualified ("ones_row_vector", ReturnType URowVector, [UInt], SoA) ; + add_unqualified ("ones_vector", ReturnType UVector, [UInt], SoA) ; + add_unqualified + ( "ordered_logistic_glm_lpmf" + , ReturnType UReal + , [UArray UInt; URowVector; UVector; UVector] + , SoA ) ; + add_unqualified + ( "ordered_logistic_glm_lpmf" + , ReturnType UReal + , [UInt; URowVector; UVector; UVector] + , SoA ) ; + add_unqualified + ( "ordered_logistic_log" + , ReturnType UReal + , [UArray UInt; UVector; UVector] + , SoA ) ; + add_unqualified + ( "ordered_logistic_log" + , ReturnType UReal + , [UArray UInt; UVector; UArray UVector] + , SoA ) ; + add_unqualified + ( "ordered_logistic_lpmf" + , ReturnType UReal + , [UArray UInt; UVector; UVector] + , SoA ) ; + add_unqualified + ( "ordered_logistic_lpmf" + , ReturnType UReal + , [UArray UInt; UVector; UArray UVector] + , SoA ) ; + add_unqualified + ("ordered_logistic_rng", ReturnType UInt, [UReal; UVector], AoS) ; + add_unqualified + ( "ordered_probit_log" + , ReturnType UReal + , [UArray UInt; UVector; UVector] + , AoS ) ; + add_unqualified + ( "ordered_probit_log" + , ReturnType UReal + , [UArray UInt; UVector; UArray UVector] + , AoS ) ; + add_unqualified + ("ordered_probit_lpmf", ReturnType UReal, [UArray UInt; UReal; UVector], AoS) ; + add_unqualified + ( "ordered_probit_lpmf" + , ReturnType UReal + , [UArray UInt; UReal; UArray UVector] + , AoS ) ; + add_unqualified + ( "ordered_probit_lpmf" + , ReturnType UReal + , [UArray UInt; UVector; UVector] + , AoS ) ; + add_unqualified + ( "ordered_probit_lpmf" + , ReturnType UReal + , [UArray UInt; UVector; UArray UVector] + , AoS ) ; + add_unqualified ("ordered_probit_rng", ReturnType UInt, [UReal; UVector], AoS) ; + add_binary_vec_real_real "owens_t" AoS ; + add_nullary "pi" ; + add_unqualified ("plus", ReturnType UComplex, [UComplex], AoS) ; + add_unqualified ("plus", ReturnType UInt, [UInt], SoA) ; + add_unqualified ("plus", ReturnType UReal, [UReal], SoA) ; + add_unqualified ("plus", ReturnType UVector, [UVector], SoA) ; + add_unqualified ("plus", ReturnType URowVector, [URowVector], SoA) ; + add_unqualified ("plus", ReturnType UMatrix, [UMatrix], SoA) ; + add_unqualified + ( "poisson_log_glm_lpmf" + , ReturnType UReal + , [UArray UInt; UMatrix; UVector; UVector] + , SoA ) ; + add_unqualified + ( "poisson_log_glm_lpmf" + , ReturnType UReal + , [UInt; UMatrix; UVector; UVector] + , SoA ) ; + add_unqualified + ( "poisson_log_glm_lpmf" + , ReturnType UReal + , [UArray UInt; URowVector; UReal; UVector] + , SoA ) ; + add_unqualified + ( "poisson_log_glm_lpmf" + , ReturnType UReal + , [UArray UInt; URowVector; UVector; UVector] + , SoA ) ; + add_unqualified ("polar", ReturnType UComplex, [UReal; UReal], AoS) ; + add_nullary "positive_infinity" ; + add_binary_vec "pow" AoS ; + add_unqualified ("pow", ReturnType UComplex, [UComplex; UReal], AoS) ; + add_unqualified ("pow", ReturnType UComplex, [UComplex; UComplex], AoS) ; + add_unqualified ("prod", ReturnType UInt, [UArray UInt], AoS) ; + add_unqualified ("prod", ReturnType UReal, [UArray UReal], AoS) ; + add_unqualified ("prod", ReturnType UReal, [UVector], AoS) ; + add_unqualified ("prod", ReturnType UReal, [URowVector], AoS) ; + add_unqualified ("prod", ReturnType UReal, [UMatrix], AoS) ; + add_unqualified ("prod", ReturnType UComplex, [UArray UComplex], AoS) ; + add_unqualified ("prod", ReturnType UComplex, [UComplexVector], AoS) ; + add_unqualified ("prod", ReturnType UComplex, [UComplexRowVector], AoS) ; + add_unqualified ("prod", ReturnType UComplex, [UComplexMatrix], AoS) ; + add_unqualified ("proj", ReturnType UComplex, [UComplex], AoS) ; + add_unqualified ("quad_form", ReturnType UReal, [UMatrix; UVector], SoA) ; + add_unqualified ("quad_form", ReturnType UMatrix, [UMatrix; UMatrix], SoA) ; + add_unqualified ("quad_form_sym", ReturnType UReal, [UMatrix; UVector], AoS) ; + add_unqualified ("quad_form_sym", ReturnType UMatrix, [UMatrix; UMatrix], AoS) ; + add_unqualified ("quad_form_diag", ReturnType UMatrix, [UMatrix; UVector], AoS) ; + add_unqualified + ("quad_form_diag", ReturnType UMatrix, [UMatrix; URowVector], AoS) ; + add_qualified + ( "quantile" + , ReturnType UReal + , [(DataOnly, UArray UReal); (DataOnly, UReal)] + , SoA ) ; + add_qualified + ( "quantile" + , ReturnType (UArray UReal) + , [(DataOnly, UArray UReal); (DataOnly, UArray UReal)] + , SoA ) ; + add_qualified + ("quantile", ReturnType UReal, [(DataOnly, UVector); (DataOnly, UReal)], SoA) ; + add_qualified + ( "quantile" + , ReturnType (UArray UReal) + , [(DataOnly, UVector); (DataOnly, UArray UReal)] + , SoA ) ; + add_qualified + ( "quantile" + , ReturnType UReal + , [(DataOnly, URowVector); (DataOnly, UReal)] + , SoA ) ; + add_qualified + ( "quantile" + , ReturnType (UArray UReal) + , [(DataOnly, URowVector); (DataOnly, UArray UReal)] + , SoA ) ; + add_unqualified ("rank", ReturnType UInt, [UArray UInt; UInt], AoS) ; + add_unqualified ("rank", ReturnType UInt, [UArray UReal; UInt], AoS) ; + add_unqualified ("rank", ReturnType UInt, [UVector; UInt], AoS) ; + add_unqualified ("rank", ReturnType UInt, [URowVector; UInt], AoS) ; + add_unqualified ("append_row", ReturnType UMatrix, [UMatrix; UMatrix], AoS) ; + add_unqualified ("append_row", ReturnType UMatrix, [URowVector; UMatrix], AoS) ; + add_unqualified ("append_row", ReturnType UMatrix, [UMatrix; URowVector], AoS) ; + add_unqualified + ("append_row", ReturnType UMatrix, [URowVector; URowVector], AoS) ; + add_unqualified ("append_row", ReturnType UVector, [UVector; UVector], AoS) ; + add_unqualified ("append_row", ReturnType UVector, [UReal; UVector], AoS) ; + add_unqualified ("append_row", ReturnType UVector, [UVector; UReal], AoS) ; + add_unqualified + ( "append_row" + , ReturnType UComplexMatrix + , [UComplexMatrix; UComplexMatrix] + , AoS ) ; + add_unqualified + ( "append_row" + , ReturnType UComplexMatrix + , [UComplexRowVector; UComplexMatrix] + , AoS ) ; + add_unqualified + ( "append_row" + , ReturnType UComplexMatrix + , [UComplexMatrix; UComplexRowVector] + , AoS ) ; + add_unqualified + ( "append_row" + , ReturnType UComplexMatrix + , [UComplexRowVector; UComplexRowVector] + , AoS ) ; + add_unqualified + ( "append_row" + , ReturnType UComplexVector + , [UComplexVector; UComplexVector] + , AoS ) ; + add_unqualified + ("append_row", ReturnType UComplexVector, [UComplex; UComplexVector], AoS) ; + add_unqualified + ("append_row", ReturnType UComplexVector, [UComplexVector; UComplex], AoS) ; + List.iter + ~f:(fun t -> + add_unqualified + ("rep_array", ReturnType (bare_array_type (t, 1)), [t; UInt], SoA) ; + add_unqualified + ("rep_array", ReturnType (bare_array_type (t, 2)), [t; UInt; UInt], SoA) ; + add_unqualified + ( "rep_array" + , ReturnType (bare_array_type (t, 3)) + , [t; UInt; UInt; UInt] + , SoA ) ; + List.iter + ~f:(fun j -> + add_unqualified + ( "rep_array" + , ReturnType (bare_array_type (t, j + 1)) + , [bare_array_type (t, j); UInt] + , SoA ) ; + add_unqualified + ( "rep_array" + , ReturnType (bare_array_type (t, j + 2)) + , [bare_array_type (t, j); UInt; UInt] + , SoA ) ; + add_unqualified + ( "rep_array" + , ReturnType (bare_array_type (t, j + 3)) + , [bare_array_type (t, j); UInt; UInt; UInt] + , SoA ) ) + (List.range 1 3) ) + bare_types ; + add_unqualified ("rep_matrix", ReturnType UMatrix, [UReal; UInt; UInt], SoA) ; + add_unqualified ("rep_matrix", ReturnType UMatrix, [UVector; UInt], AoS) ; + add_unqualified ("rep_matrix", ReturnType UMatrix, [URowVector; UInt], AoS) ; + add_unqualified + ("rep_matrix", ReturnType UComplexMatrix, [UComplex; UInt; UInt], AoS) ; + add_unqualified + ("rep_matrix", ReturnType UComplexMatrix, [UComplexVector; UInt], AoS) ; + add_unqualified + ("rep_matrix", ReturnType UComplexMatrix, [UComplexRowVector; UInt], AoS) ; + add_unqualified ("rep_row_vector", ReturnType URowVector, [UReal; UInt], SoA) ; + add_unqualified + ("rep_row_vector", ReturnType UComplexRowVector, [UComplex; UInt], AoS) ; + add_unqualified ("rep_vector", ReturnType UVector, [UReal; UInt], SoA) ; + add_unqualified + ("rep_vector", ReturnType UComplexVector, [UComplex; UInt], AoS) ; + add_unqualified ("reverse", ReturnType UVector, [UVector], SoA) ; + add_unqualified ("reverse", ReturnType URowVector, [URowVector], SoA) ; + List.iter + ~f:(fun i -> + List.iter + ~f:(fun t -> + add_unqualified + ( "reverse" + , ReturnType (bare_array_type (t, i)) + , [bare_array_type (t, i)] + , SoA ) ) + bare_types ) + (List.range 1 8) ; + add_unqualified ("reverse", ReturnType UComplexVector, [UComplexVector], SoA) ; + add_unqualified + ("reverse", ReturnType UComplexRowVector, [UComplexRowVector], SoA) ; + add_binary_vec_int_int "rising_factorial" AoS ; + add_binary_vec_real_int "rising_factorial" AoS ; + add_unqualified ("row", ReturnType URowVector, [UMatrix; UInt], SoA) ; + add_unqualified + ("row", ReturnType UComplexRowVector, [UComplexMatrix; UInt], AoS) ; + add_unqualified ("rows", ReturnType UInt, [UVector], SoA) ; + add_unqualified ("rows", ReturnType UInt, [URowVector], SoA) ; + add_unqualified ("rows", ReturnType UInt, [UMatrix], SoA) ; + add_unqualified ("rows", ReturnType UInt, [UComplexVector], SoA) ; + add_unqualified ("rows", ReturnType UInt, [UComplexRowVector], SoA) ; + add_unqualified ("rows", ReturnType UInt, [UComplexMatrix], SoA) ; + add_unqualified + ("rows_dot_product", ReturnType UVector, [UVector; UVector], AoS) ; + add_unqualified + ("rows_dot_product", ReturnType UVector, [URowVector; URowVector], AoS) ; + add_unqualified + ("rows_dot_product", ReturnType UVector, [UMatrix; UMatrix], SoA) ; + add_unqualified + ( "rows_dot_product" + , ReturnType UComplexVector + , [UComplexVector; UComplexVector] + , AoS ) ; + add_unqualified + ( "rows_dot_product" + , ReturnType UComplexVector + , [UComplexRowVector; UComplexRowVector] + , AoS ) ; + add_unqualified + ( "rows_dot_product" + , ReturnType UComplexVector + , [UComplexMatrix; UComplexMatrix] + , AoS ) ; + add_unqualified ("rows_dot_self", ReturnType UVector, [UVector], SoA) ; + add_unqualified ("rows_dot_self", ReturnType UVector, [URowVector], SoA) ; + add_unqualified ("rows_dot_self", ReturnType UVector, [UMatrix], SoA) ; + add_unqualified + ("rows_dot_self", ReturnType UComplexVector, [UComplexVector], AoS) ; + add_unqualified + ("rows_dot_self", ReturnType UComplexVector, [UComplexRowVector], AoS) ; + add_unqualified + ("rows_dot_self", ReturnType UComplexVector, [UComplexMatrix], AoS) ; + add_unqualified + ( "scale_matrix_exp_multiply" + , ReturnType UMatrix + , [UReal; UMatrix; UMatrix] + , AoS ) ; + add_unqualified ("sd", ReturnType UReal, [UArray UReal], SoA) ; + add_unqualified ("sd", ReturnType UReal, [UVector], SoA) ; + add_unqualified ("sd", ReturnType UReal, [URowVector], SoA) ; + add_unqualified ("sd", ReturnType UReal, [UMatrix], SoA) ; + add_unqualified + ("segment", ReturnType URowVector, [URowVector; UInt; UInt], SoA) ; + add_unqualified ("segment", ReturnType UVector, [UVector; UInt; UInt], SoA) ; + add_unqualified + ( "segment" + , ReturnType UComplexRowVector + , [UComplexRowVector; UInt; UInt] + , AoS ) ; + add_unqualified + ("segment", ReturnType UComplexVector, [UComplexVector; UInt; UInt], AoS) ; + List.iter + ~f:(fun t -> + List.iter + ~f:(fun j -> + add_unqualified + ( "segment" + , ReturnType (bare_array_type (t, j)) + , [bare_array_type (t, j); UInt; UInt] + , SoA ) ) + (List.range 1 4) ) + bare_types ; + add_unqualified ("sin", ReturnType UComplex, [UComplex], AoS) ; + add_unqualified ("sinh", ReturnType UComplex, [UComplex], AoS) ; + add_unqualified ("singular_values", ReturnType UVector, [UMatrix], SoA) ; + List.iter + ~f:(fun i -> + List.iter + ~f:(fun t -> + add_unqualified + ("size", ReturnType UInt, [bare_array_type (t, i)], SoA) ) + bare_types ) + (List.range 1 8) ; + List.iter + ~f:(fun t -> add_unqualified ("size", ReturnType UInt, [t], SoA)) + bare_types ; + add_unqualified ("softmax", ReturnType UVector, [UVector], SoA) ; + add_unqualified ("sort_asc", ReturnType (UArray UInt), [UArray UInt], AoS) ; + add_unqualified ("sort_asc", ReturnType (UArray UReal), [UArray UReal], AoS) ; + add_unqualified ("sort_asc", ReturnType UVector, [UVector], AoS) ; + add_unqualified ("sort_asc", ReturnType URowVector, [URowVector], AoS) ; + add_unqualified ("sort_desc", ReturnType (UArray UInt), [UArray UInt], AoS) ; + add_unqualified ("sort_desc", ReturnType (UArray UReal), [UArray UReal], AoS) ; + add_unqualified ("sort_desc", ReturnType UVector, [UVector], AoS) ; + add_unqualified ("sort_desc", ReturnType URowVector, [URowVector], AoS) ; + add_unqualified + ("sort_indices_asc", ReturnType (UArray UInt), [UArray UInt], AoS) ; + add_unqualified + ("sort_indices_asc", ReturnType (UArray UInt), [UArray UReal], AoS) ; + add_unqualified ("sort_indices_asc", ReturnType (UArray UInt), [UVector], AoS) ; + add_unqualified + ("sort_indices_asc", ReturnType (UArray UInt), [URowVector], AoS) ; + add_unqualified + ("sort_indices_desc", ReturnType (UArray UInt), [UArray UInt], AoS) ; + add_unqualified + ("sort_indices_desc", ReturnType (UArray UInt), [UArray UReal], AoS) ; + add_unqualified ("sort_indices_desc", ReturnType (UArray UInt), [UVector], AoS) ; + add_unqualified + ("sort_indices_desc", ReturnType (UArray UInt), [URowVector], AoS) ; + add_unqualified ("squared_distance", ReturnType UReal, [UReal; UReal], SoA) ; + add_unqualified ("squared_distance", ReturnType UReal, [UVector; UVector], SoA) ; + add_unqualified + ("squared_distance", ReturnType UReal, [URowVector; URowVector], SoA) ; + add_unqualified + ("squared_distance", ReturnType UReal, [UVector; URowVector], SoA) ; + add_unqualified + ("squared_distance", ReturnType UReal, [URowVector; UVector], SoA) ; + add_unqualified ("sqrt", ReturnType UComplex, [UComplex], AoS) ; + add_nullary "sqrt2" ; + add_unqualified + ("sub_col", ReturnType UVector, [UMatrix; UInt; UInt; UInt], SoA) ; + add_unqualified + ( "sub_col" + , ReturnType UComplexVector + , [UComplexMatrix; UInt; UInt; UInt] + , AoS ) ; + add_unqualified + ("sub_row", ReturnType URowVector, [UMatrix; UInt; UInt; UInt], SoA) ; + add_unqualified + ( "sub_row" + , ReturnType UComplexRowVector + , [UComplexMatrix; UInt; UInt; UInt] + , AoS ) ; + List.iter + ~f:(fun x -> add_unqualified ("subtract", ReturnType x, [x; x], SoA)) + bare_types ; + add_unqualified ("subtract", ReturnType UVector, [UVector; UReal], SoA) ; + add_unqualified ("subtract", ReturnType URowVector, [URowVector; UReal], SoA) ; + add_unqualified ("subtract", ReturnType UMatrix, [UMatrix; UReal], SoA) ; + add_unqualified ("subtract", ReturnType UVector, [UReal; UVector], SoA) ; + add_unqualified ("subtract", ReturnType URowVector, [UReal; URowVector], SoA) ; + add_unqualified ("subtract", ReturnType UMatrix, [UReal; UMatrix], SoA) ; + add_unqualified ("sum", ReturnType UInt, [UArray UInt], SoA) ; + add_unqualified ("sum", ReturnType UReal, [UArray UReal], SoA) ; + add_unqualified ("sum", ReturnType UReal, [UVector], SoA) ; + add_unqualified ("sum", ReturnType UReal, [URowVector], SoA) ; + add_unqualified ("sum", ReturnType UReal, [UMatrix], SoA) ; + add_unqualified ("sum", ReturnType UComplex, [UArray UComplex], SoA) ; + add_unqualified ("sum", ReturnType UComplex, [UComplexVector], SoA) ; + add_unqualified ("sum", ReturnType UComplex, [UComplexRowVector], SoA) ; + add_unqualified ("sum", ReturnType UComplex, [UComplexMatrix], SoA) ; + add_unqualified ("svd_U", ReturnType UMatrix, [UMatrix], SoA) ; + add_unqualified ("svd_V", ReturnType UMatrix, [UMatrix], SoA) ; + add_unqualified + ("symmetrize_from_lower_tri", ReturnType UMatrix, [UMatrix], AoS) ; + add_unqualified + ( "symmetrize_from_lower_tri" + , ReturnType UComplexMatrix + , [UComplexMatrix] + , AoS ) ; + add_unqualified ("tail", ReturnType URowVector, [URowVector; UInt], SoA) ; + add_unqualified ("tail", ReturnType UVector, [UVector; UInt], SoA) ; + add_unqualified + ("tail", ReturnType UComplexRowVector, [UComplexRowVector; UInt], AoS) ; + add_unqualified + ("tail", ReturnType UComplexVector, [UComplexVector; UInt], AoS) ; + List.iter + ~f:(fun t -> + List.iter + ~f:(fun j -> + add_unqualified + ( "tail" + , ReturnType (bare_array_type (t, j)) + , [bare_array_type (t, j); UInt] + , SoA ) ) + (List.range 1 4) ) + bare_types ; + add_unqualified ("tan", ReturnType UComplex, [UComplex], AoS) ; + add_unqualified ("tanh", ReturnType UComplex, [UComplex], AoS) ; + add_unqualified ("tcrossprod", ReturnType UMatrix, [UMatrix], SoA) ; + add_unqualified ("to_array_1d", ReturnType (UArray UReal), [UMatrix], AoS) ; + add_unqualified ("to_array_1d", ReturnType (UArray UReal), [UVector], AoS) ; + add_unqualified ("to_array_1d", ReturnType (UArray UReal), [URowVector], AoS) ; + add_unqualified + ("to_array_1d", ReturnType (UArray UComplex), [UComplexMatrix], AoS) ; + add_unqualified + ("to_array_1d", ReturnType (UArray UComplex), [UComplexVector], AoS) ; + add_unqualified + ("to_array_1d", ReturnType (UArray UComplex), [UComplexRowVector], AoS) ; + List.iter + ~f:(fun i -> + add_unqualified + ( "to_array_1d" + , ReturnType (UArray UReal) + , [bare_array_type (UReal, i)] + , AoS ) ; + add_unqualified + ( "to_array_1d" + , ReturnType (UArray UInt) + , [bare_array_type (UInt, i)] + , AoS ) ) + (List.range 1 10) ; + add_unqualified + ("to_array_2d", ReturnType (bare_array_type (UReal, 2)), [UMatrix], AoS) ; + add_unqualified + ( "to_array_2d" + , ReturnType (bare_array_type (UComplex, 2)) + , [UComplexMatrix] + , AoS ) ; + add_unqualified ("to_complex", ReturnType UComplex, [], AoS) ; + add_unqualified ("to_complex", ReturnType UComplex, [UReal; UReal], AoS) ; + add_unqualified ("to_complex", ReturnType UComplex, [UReal], AoS) ; + add_unqualified ("to_matrix", ReturnType UMatrix, [UMatrix], AoS) ; + add_unqualified ("to_matrix", ReturnType UMatrix, [UMatrix; UInt; UInt], AoS) ; + add_unqualified + ("to_matrix", ReturnType UMatrix, [UMatrix; UInt; UInt; UInt], AoS) ; + add_unqualified ("to_matrix", ReturnType UMatrix, [UVector], AoS) ; + add_unqualified ("to_matrix", ReturnType UMatrix, [UVector; UInt; UInt], AoS) ; + add_unqualified + ("to_matrix", ReturnType UMatrix, [UVector; UInt; UInt; UInt], AoS) ; + add_unqualified ("to_matrix", ReturnType UMatrix, [URowVector], AoS) ; + add_unqualified ("to_matrix", ReturnType UMatrix, [UArray URowVector], AoS) ; + add_unqualified + ("to_matrix", ReturnType UMatrix, [URowVector; UInt; UInt], AoS) ; + add_unqualified + ("to_matrix", ReturnType UMatrix, [URowVector; UInt; UInt; UInt], AoS) ; + add_unqualified + ("to_matrix", ReturnType UMatrix, [UArray UReal; UInt; UInt], AoS) ; + add_unqualified + ("to_matrix", ReturnType UMatrix, [UArray UReal; UInt; UInt; UInt], AoS) ; + add_unqualified + ("to_matrix", ReturnType UMatrix, [UArray UInt; UInt; UInt], AoS) ; + add_unqualified + ("to_matrix", ReturnType UMatrix, [UArray UInt; UInt; UInt; UInt], AoS) ; + add_unqualified + ("to_matrix", ReturnType UMatrix, [bare_array_type (UReal, 2)], AoS) ; + add_unqualified + ("to_matrix", ReturnType UMatrix, [bare_array_type (UInt, 2)], AoS) ; + add_unqualified ("to_matrix", ReturnType UComplexMatrix, [UComplexMatrix], AoS) ; + add_unqualified + ("to_matrix", ReturnType UComplexMatrix, [UComplexMatrix; UInt; UInt], AoS) ; + add_unqualified + ( "to_matrix" + , ReturnType UComplexMatrix + , [UComplexMatrix; UInt; UInt; UInt] + , AoS ) ; + add_unqualified ("to_matrix", ReturnType UComplexMatrix, [UComplexVector], AoS) ; + add_unqualified + ("to_matrix", ReturnType UComplexMatrix, [UComplexVector; UInt; UInt], AoS) ; + add_unqualified + ( "to_matrix" + , ReturnType UComplexMatrix + , [UComplexVector; UInt; UInt; UInt] + , AoS ) ; + add_unqualified + ("to_matrix", ReturnType UComplexMatrix, [UComplexRowVector], AoS) ; + add_unqualified + ("to_matrix", ReturnType UComplexMatrix, [UArray UComplexRowVector], AoS) ; + add_unqualified + ( "to_matrix" + , ReturnType UComplexMatrix + , [UComplexRowVector; UInt; UInt] + , AoS ) ; + add_unqualified + ( "to_matrix" + , ReturnType UComplexMatrix + , [UComplexRowVector; UInt; UInt; UInt] + , AoS ) ; + add_unqualified + ("to_matrix", ReturnType UComplexMatrix, [UArray UComplex; UInt; UInt], AoS) ; + add_unqualified + ( "to_matrix" + , ReturnType UComplexMatrix + , [UArray UComplex; UInt; UInt; UInt] + , AoS ) ; + add_unqualified + ( "to_matrix" + , ReturnType UComplexMatrix + , [bare_array_type (UComplex, 2)] + , AoS ) ; + add_unqualified ("to_row_vector", ReturnType URowVector, [UMatrix], AoS) ; + add_unqualified ("to_row_vector", ReturnType URowVector, [UVector], AoS) ; + add_unqualified ("to_row_vector", ReturnType URowVector, [URowVector], AoS) ; + add_unqualified ("to_row_vector", ReturnType URowVector, [UArray UReal], AoS) ; + add_unqualified ("to_row_vector", ReturnType URowVector, [UArray UInt], AoS) ; + add_unqualified + ("to_row_vector", ReturnType UComplexRowVector, [UComplexMatrix], AoS) ; + add_unqualified + ("to_row_vector", ReturnType UComplexRowVector, [UComplexVector], AoS) ; + add_unqualified + ("to_row_vector", ReturnType UComplexRowVector, [UComplexRowVector], AoS) ; + add_unqualified + ("to_row_vector", ReturnType UComplexRowVector, [UArray UComplex], AoS) ; + add_unqualified ("to_vector", ReturnType UVector, [UMatrix], SoA) ; + add_unqualified ("to_vector", ReturnType UVector, [UVector], SoA) ; + add_unqualified ("to_vector", ReturnType UVector, [URowVector], SoA) ; + add_unqualified ("to_vector", ReturnType UVector, [UArray UReal], AoS) ; + add_unqualified ("to_vector", ReturnType UVector, [UArray UInt], AoS) ; + add_unqualified ("to_vector", ReturnType UComplexVector, [UComplexMatrix], AoS) ; + add_unqualified ("to_vector", ReturnType UComplexVector, [UComplexVector], AoS) ; + add_unqualified + ("to_vector", ReturnType UComplexVector, [UComplexRowVector], AoS) ; + add_unqualified + ("to_vector", ReturnType UComplexVector, [UArray UComplex], AoS) ; + add_unqualified ("trace", ReturnType UReal, [UMatrix], SoA) ; + add_unqualified ("trace", ReturnType UComplex, [UComplexMatrix], AoS) ; + add_unqualified + ("trace_gen_quad_form", ReturnType UReal, [UMatrix; UMatrix; UMatrix], SoA) ; + add_unqualified ("trace_quad_form", ReturnType UReal, [UMatrix; UVector], SoA) ; + add_unqualified ("trace_quad_form", ReturnType UReal, [UMatrix; UMatrix], SoA) ; + add_unqualified ("transpose", ReturnType URowVector, [UVector], SoA) ; + add_unqualified ("transpose", ReturnType UVector, [URowVector], SoA) ; + add_unqualified ("transpose", ReturnType UMatrix, [UMatrix], SoA) ; + add_unqualified + ("transpose", ReturnType UComplexRowVector, [UComplexVector], SoA) ; + add_unqualified + ("transpose", ReturnType UComplexVector, [UComplexRowVector], SoA) ; + add_unqualified ("transpose", ReturnType UComplexMatrix, [UComplexMatrix], SoA) ; + add_unqualified ("uniform_simplex", ReturnType UVector, [UInt], SoA) ; + add_unqualified ("variance", ReturnType UReal, [UArray UReal], SoA) ; + add_unqualified ("variance", ReturnType UReal, [UVector], SoA) ; + add_unqualified ("variance", ReturnType UReal, [URowVector], SoA) ; + add_unqualified ("variance", ReturnType UReal, [UMatrix], SoA) ; + add_unqualified ("wishart_rng", ReturnType UMatrix, [UReal; UMatrix], AoS) ; + add_unqualified ("zeros_int_array", ReturnType (UArray UInt), [UInt], SoA) ; + add_unqualified ("zeros_array", ReturnType (UArray UReal), [UInt], SoA) ; + add_unqualified ("zeros_row_vector", ReturnType URowVector, [UInt], SoA) ; + add_unqualified ("zeros_vector", ReturnType UVector, [UInt], SoA) ; + (* Now add all the manually added stuff to the main hashtable used + for type-checking *) + Hashtbl.iteri manual_stan_math_signatures ~f:(fun ~key ~data -> + List.iter data ~f:(fun data -> + Hashtbl.add_multi function_signatures ~key ~data ) ) + +let%expect_test "declarative distributions" = + let special_suffixes = + String.Set.of_list + Utils.(["lpmf"; "lpdf"; "log"] @ cumulative_distribution_suffices_w_rng) + in + let d = + distributions + |> List.map ~f:(function _, n, _, _ -> n) + |> String.Set.of_list in + Hashtbl.keys function_signatures + |> List.filter ~f:(fun name -> + match Utils.split_distribution_suffix name with + | Some (name, suffix) + when Set.mem special_suffixes suffix && not (Set.mem d name) -> + true + | _ -> false ) + |> Fmt.str "@[%a@]" Fmt.(list ~sep:cut string) + |> print_endline ; + [%expect {| + binomial_coefficient_log + multiply_log + lkj_cov_log |}] diff --git a/src/stan_math_backend/Stan_math_library.mli b/src/stan_math_backend/Stan_math_library.mli new file mode 100644 index 0000000000..f22e5f4296 --- /dev/null +++ b/src/stan_math_backend/Stan_math_library.mli @@ -0,0 +1,41 @@ +(** This module stores a table of all signatures from the Stan + math C++ library which are exposed to Stan, and some helper + functions for dealing with those signatures. +*) + +open Middle +open Frontend.Std_library_utils +include Library + +val pretty_print_all_math_sigs : unit Fmt.t +val pretty_print_all_math_distributions : unit Fmt.t + +(* TODO: We should think of a better encapsulization for these, + this doesn't scale well. +*) + +(* reduce_sum helpers *) +val is_reduce_sum_fn : string -> bool +val reduce_sum_slice_types : UnsizedType.t list + +(* variadic ODE helpers *) +val is_variadic_ode_fn : string -> bool +val is_variadic_ode_nonadjoint_tol_fn : string -> bool +val ode_tolerances_suffix : string +val variadic_ode_adjoint_fn : string +val variadic_ode_mandatory_arg_types : fun_arg list +val variadic_ode_mandatory_fun_args : fun_arg list +val variadic_ode_tol_arg_types : fun_arg list +val variadic_ode_adjoint_ctl_tol_arg_types : fun_arg list +val variadic_ode_fun_return_type : UnsizedType.t +val variadic_ode_return_type : UnsizedType.t + +(* variadic DAE helpers *) +val is_variadic_dae_fn : string -> bool +val is_variadic_dae_tol_fn : string -> bool +val dae_tolerances_suffix : string +val variadic_dae_mandatory_arg_types : fun_arg list +val variadic_dae_mandatory_fun_args : fun_arg list +val variadic_dae_tol_arg_types : fun_arg list +val variadic_dae_fun_return_type : UnsizedType.t +val variadic_dae_return_type : UnsizedType.t diff --git a/src/stanc/stanc.ml b/src/stanc/stanc.ml index 35151e9ed9..efbf8da480 100644 --- a/src/stanc/stanc.ml +++ b/src/stanc/stanc.ml @@ -16,6 +16,7 @@ module Deprecations = Deprecation_analysis.Make (CppLibrary) module Canonicalizer = Canonicalize.Make (Deprecations) module ModelInfo = Info.Make (CppLibrary) module Ast2Mir = Ast_to_Mir.Make (CppLibrary) +module Optimizer = Optimize.Make (CppLibrary) (** The main program. *) let version = "%%NAME%%3 %%VERSION%%" @@ -332,7 +333,7 @@ let use_file filename = if !no_soa_opt then {base_optims with optimize_soa= false} else if !soa_opt then {base_optims with optimize_soa= true} else base_optims in - Optimize.optimization_suite ~settings:set_optims tx_mir in + Optimizer.optimization_suite ~settings:set_optims tx_mir in if !dump_opt_mir then Sexp.pp_hum Format.std_formatter [%sexp (opt : Middle.Program.Typed.t)] ; if !dump_opt_mir_pretty then Program.Typed.pp Format.std_formatter opt ; diff --git a/test/unit/Debug_data_generation_tests.ml b/test/unit/Debug_data_generation_tests.ml index 6ab94c65a8..f3d97521e7 100644 --- a/test/unit/Debug_data_generation_tests.ml +++ b/test/unit/Debug_data_generation_tests.ml @@ -1,9 +1,4 @@ -open Analysis_and_optimization open Core_kernel -open Frontend -open Debug_data_generation - -let print_data_prog ast = print_data_prog (Ast_to_Mir.gather_data ast) let%expect_test "whole program data generation check" = let ast = @@ -17,7 +12,7 @@ let%expect_test "whole program data generation check" = } |} in - let str = print_data_prog ast in + let str = Test_utils.print_data_prog ast in print_string str ; [%expect {| @@ -48,7 +43,7 @@ let%expect_test "whole program data generation check" = } |} in - let str = print_data_prog ast in + let str = Test_utils.print_data_prog ast in print_string str ; [%expect {| @@ -97,7 +92,7 @@ let%expect_test "whole program data generation check" = } |} in - let str = print_data_prog ast in + let str = Test_utils.print_data_prog ast in print_string str ; [%expect {| @@ -134,7 +129,7 @@ let%expect_test "whole program data generation check" = } |} in - let str = print_data_prog ast in + let str = Test_utils.print_data_prog ast in print_string str ; [%expect {| @@ -256,7 +251,7 @@ let%expect_test "whole program data generation check" = } |} in - let str = print_data_prog ast in + let str = Test_utils.print_data_prog ast in print_string str ; [%expect {| @@ -515,7 +510,7 @@ let%expect_test "whole program data generation check" = } |} in - let str = print_data_prog ast in + let str = Test_utils.print_data_prog ast in print_string str ; [%expect {| @@ -647,7 +642,7 @@ let%expect_test "whole program data generation check" = } |} in - let str = print_data_prog ast in + let str = Test_utils.print_data_prog ast in print_string str ; [%expect {| @@ -666,7 +661,7 @@ let%expect_test "Complex numbers program" = } |} in - let str = print_data_prog ast in + let str = Test_utils.print_data_prog ast in print_string str ; [%expect {| diff --git a/test/unit/Desugar_test.ml b/test/unit/Desugar_test.ml index 525fbfe24b..1bf13ae28b 100644 --- a/test/unit/Desugar_test.ml +++ b/test/unit/Desugar_test.ml @@ -1,6 +1,9 @@ open Core_kernel open Analysis_and_optimization +module Partial_evaluator = + Partial_evaluation.Make (Stan_math_backend.Stan_math_library) + let print_tdata Middle.Program.{prepare_data; _} = Fmt.(str "@[%a@]@," (list ~sep:cut Middle.Stmt.Located.pp) prepare_data) |> print_endline diff --git a/test/unit/Optimize.ml b/test/unit/Optimize.ml index 67154bbbe0..d61c372f82 100644 --- a/test/unit/Optimize.ml +++ b/test/unit/Optimize.ml @@ -1,9 +1,13 @@ open Core_kernel -open Analysis_and_optimization.Optimize open Middle open Common open Analysis_and_optimization.Mir_utils +module Optimizer = + Analysis_and_optimization.Optimize.Make (Stan_math_backend.Stan_math_library) + +open Optimizer + let reset_and_mir_of_string s = Gensym.reset_danger_use_cautiously () ; Test_utils.mir_of_string s diff --git a/test/unit/Test_utils.ml b/test/unit/Test_utils.ml index 5522236e61..03838e7405 100644 --- a/test/unit/Test_utils.ml +++ b/test/unit/Test_utils.ml @@ -1,6 +1,12 @@ open Frontend open Core_kernel +module CppLibrary : Std_library_utils.Library = + Stan_math_backend.Stan_math_library + +module Typechecker = Typechecking.Make (CppLibrary) +module Ast2Mir = Ast_to_Mir.Make (CppLibrary) + let untyped_ast_of_string s = let res, warnings = Parse.parse_string Parser.Incremental.program s in Fmt.epr "%a" (Fmt.list ~sep:Fmt.nop Warnings.pp) warnings ; @@ -15,4 +21,8 @@ let typed_ast_of_string_exn s = |> Result.map_error ~f:Errors.to_string |> Result.ok_or_failwith |> fst -let mir_of_string s = typed_ast_of_string_exn s |> Ast_to_Mir.trans_prog "" +let mir_of_string s = typed_ast_of_string_exn s |> Ast2Mir.trans_prog "" + +let print_data_prog ast = + Analysis_and_optimization.Debug_data_generation.print_data_prog + (Ast2Mir.gather_data ast) From c7f1b144ed4bfdcc77c609fe794bf232635e31bc Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Wed, 27 Apr 2022 16:53:15 -0400 Subject: [PATCH 06/14] Variadic typechecking, details. Tests passing --- src/frontend/Ast.ml | 4 + src/frontend/Deprecation_analysis.ml | 38 +--- src/frontend/Semantic_error.ml | 10 +- src/frontend/SignatureMismatch.ml | 14 ++ src/frontend/SignatureMismatch.mli | 9 + src/frontend/Std_library_utils.ml | 31 ++- src/frontend/Typechecking.ml | 205 +----------------- src/frontend/Typechecking.mli | 9 + src/stan_math_backend/Stan_math_library.ml | 228 ++++++++++++++++++++- test/integration/bad/stanc.expected | 2 - 10 files changed, 292 insertions(+), 258 deletions(-) diff --git a/src/frontend/Ast.ml b/src/frontend/Ast.ml index 58502a85a5..c47e9a8daa 100644 --- a/src/frontend/Ast.ml +++ b/src/frontend/Ast.ml @@ -80,6 +80,10 @@ let mk_untyped_expression ~expr ~loc = {expr; emeta= {loc}} let mk_typed_expression ~expr ~loc ~type_ ~ad_level = {expr; emeta= {loc; type_; ad_level}} +let mk_fun_app ~is_cond_dist (kind, id, arguments) = + if is_cond_dist then CondDistApp (kind, id, arguments) + else FunApp (kind, id, arguments) + let expr_loc_lub exprs = match List.map ~f:(fun e -> e.emeta.loc) exprs with | [] -> Location_span.empty diff --git a/src/frontend/Deprecation_analysis.ml b/src/frontend/Deprecation_analysis.ml index 5f4d5cc2a8..2f99fa55a8 100644 --- a/src/frontend/Deprecation_analysis.ml +++ b/src/frontend/Deprecation_analysis.ml @@ -21,38 +21,6 @@ module type Deprecation_analizer = sig end module Make (StdLib : Std_library_utils.Library) : Deprecation_analizer = struct - (* String.Map.of_alist_exn - (List.map - ~f:(fun (x, y) -> (x, (y, "2.32.0"))) - (List.concat_map StdLib.distributions - ~f:(fun (fnkinds, name, _, _) -> - List.filter_map fnkinds ~f:(function - | Lpdf -> Some (name ^ "_log", name ^ "_lpdf") - | Lpmf -> Some (name ^ "_log", name ^ "_lpmf") - | Cdf -> Some (name ^ "_cdf_log", name ^ "_lcdf") - | Ccdf -> Some (name ^ "_ccdf_log", name ^ "_lccdf") - | Rng | UnaryVectorized -> None ) ) ) ) *) - (* String.Map.of_alist_exn - [ ("multiply_log", ("lmultiply", "2.32.0")) - ; ("binomial_coefficient_log", ("lchoose", "2.32.0")) - ; ("cov_exp_quad", ("gp_exp_quad_cov", "2.32.0")) ] - - + - This can be automatically changed using the \ - canonicalize flag for stanc - *) - - (* String.Map.of_alist_exn - [ ("integrate_ode", ("ode_rk45", "3.0")) - ; ("integrate_ode_rk45", ("ode_rk45", "3.0")) - ; ("integrate_ode_bdf", ("ode_bdf", "3.0")) - ; ("integrate_ode_adams", ("ode_adams", "3.0")) ] - - + - The new interface is slightly different, see: - https://mc-stan.org/users/documentation/case-studies/convert_odes.html - *) - let stan_lib_deprecations = Map.merge_skewed StdLib.deprecated_distributions StdLib.deprecated_functions ~combine:(fun ~key x y -> @@ -68,7 +36,9 @@ module Make (StdLib : Std_library_utils.Library) : Deprecation_analizer = struct let rename_deprecated map name = Map.find map name - |> Option.map ~f:(fun Std_library_utils.{replacement; _} -> replacement) + |> Option.map + ~f:(fun Std_library_utils.{replacement; canonicalize_away; _} -> + if canonicalize_away then replacement else name ) |> Option.value ~default:name let rename_deprecated_distribution = @@ -153,7 +123,7 @@ module Make (StdLib : Std_library_utils.Library) : Deprecation_analizer = struct | FunApp ((StanLib _ | UserDefined _), {name; _}, l) -> let w = match Map.find stan_lib_deprecations name with - | Some {replacement; version; extra_message} -> + | Some {replacement; version; extra_message; _} -> [ ( emeta.loc , name ^ " is deprecated and will be removed in Stan " ^ version ^ ". Use " ^ replacement ^ " instead. " ^ extra_message ) ] diff --git a/src/frontend/Semantic_error.ml b/src/frontend/Semantic_error.ml index 42c36a5b59..e56f5fdb19 100644 --- a/src/frontend/Semantic_error.ml +++ b/src/frontend/Semantic_error.ml @@ -221,21 +221,21 @@ module TypeError = struct | IllTypedBinaryOperator (op, lt, rt, sigs) -> Fmt.pf ppf "Ill-typed arguments supplied to infix operator %a. Available \ - signatures: @[%a@]@[Instead supplied arguments of \ + signatures: @[%a@.@]@[Instead supplied arguments of \ incompatible type: %a, %a.@]" Operator.pp op Std_library_utils.pp_math_sigs sigs UnsizedType.pp lt UnsizedType.pp rt | IllTypedPrefixOperator (op, ut, sigs) -> Fmt.pf ppf "Ill-typed arguments supplied to prefix operator %a. Available \ - signatures: @[%a@]@[Instead supplied argument of incompatible \ - type: %a.@]" + signatures: @[%a@.@]@[Instead supplied argument of \ + incompatible type: %a.@]" Operator.pp op Std_library_utils.pp_math_sigs sigs UnsizedType.pp ut | IllTypedPostfixOperator (op, ut, sigs) -> Fmt.pf ppf "Ill-typed arguments supplied to postfix operator %a. Available \ - signatures: @[%a@]@[Instead supplied argument of incompatible \ - type: %a.@]" + signatures: @[%a@.@]@[Instead supplied argument of \ + incompatible type: %a.@]" Operator.pp op Std_library_utils.pp_math_sigs sigs UnsizedType.pp ut end diff --git a/src/frontend/SignatureMismatch.ml b/src/frontend/SignatureMismatch.ml index 11bac9a347..5d4d4e08bf 100644 --- a/src/frontend/SignatureMismatch.ml +++ b/src/frontend/SignatureMismatch.ml @@ -283,6 +283,20 @@ let check_variadic_args allow_lpdf mandatory_arg_tys mandatory_fun_arg_tys | (_, x) :: _ -> TypeMismatch (minimal_func_type, x, None) |> wrap_err | [] -> Error ([], ArgNumMismatch (List.length mandatory_arg_tys, 0)) +let find_matching_first_order_fn tenv matches (fname : Ast.identifier) = + let candidates = + Utils.stdlib_distribution_name fname.name + |> Environment.find tenv |> List.map ~f:matches in + let ok, errs = List.partition_map candidates ~f:Result.to_either in + match unique_minimum_promotion ok with + | Ok a -> UniqueMatch a + | Error (Some promotions) -> + List.filter_map promotions ~f:(function + | UnsizedType.UFun (args, rt, _, _) -> Some (rt, args) + | _ -> None ) + |> AmbiguousMatch + | Error None -> SignatureErrors (List.hd_exn errs) + let pp_signature_mismatch ppf (name, arg_tys, (sigs, omitted)) = let open Fmt in let ctx = ref TypeMap.empty in diff --git a/src/frontend/SignatureMismatch.mli b/src/frontend/SignatureMismatch.mli index b1565ab744..6927e21716 100644 --- a/src/frontend/SignatureMismatch.mli +++ b/src/frontend/SignatureMismatch.mli @@ -69,6 +69,15 @@ val check_variadic_args : If none is found, returns [Error] of the list of args and a function_mismatch. *) +val find_matching_first_order_fn : + Environment.t + -> (Environment.info -> (UnsizedType.t * Promotion.t list, 'a) result) + -> Ast.identifier + -> (UnsizedType.t * Promotion.t list, 'a) generic_match_result +(** Given a constraint function [matches], find any signature which exists + Returns the first [Ok] if any exist, or else [Error] +*) + val pp_signature_mismatch : Format.formatter -> string diff --git a/src/frontend/Std_library_utils.ml b/src/frontend/Std_library_utils.ml index 09c2989dfe..8c45e2d954 100644 --- a/src/frontend/Std_library_utils.ml +++ b/src/frontend/Std_library_utils.ml @@ -9,17 +9,11 @@ type fun_arg = UnsizedType.autodifftype * UnsizedType.t type signature = UnsizedType.returntype * fun_arg list * Common.Helpers.mem_pattern -type variadic_checker = - is_cond_dist:bool - -> Location_span.t - -> Environment.originblock - -> Environment.t - -> Ast.identifier - -> Ast.typed_expression list - -> Ast.typed_expression - type deprecation_info = - {replacement: string; version: string; extra_message: string} + { replacement: string + ; version: string + ; extra_message: string + ; canonicalize_away: bool } [@@deriving sexp] module type Library = sig @@ -39,6 +33,17 @@ module type Library = sig val get_assignment_operator_signatures : Operator.t -> signature list val is_not_overloadable : string -> bool val is_variadic_function_name : string -> bool + val variadic_function_returntype : string -> UnsizedType.returntype option + + val check_variadic_fn : + Ast.identifier + -> is_cond_dist:bool + -> Location_span.t + -> Environment.originblock + -> Environment.t + -> Ast.typed_expression list + -> Ast.typed_expression + val operator_to_function_names : Operator.t -> string list val string_operator_to_function_name : string -> string val deprecated_distributions : deprecation_info String.Map.t @@ -60,6 +65,12 @@ module NullLibrary : Library = struct let get_operator_signatures _ = [] let is_not_overloadable _ = false let is_variadic_function_name _ = false + let variadic_function_returntype _ = None + + let check_variadic_fn _ ~is_cond_dist _ _ _ _ : Ast.typed_expression = + ignore (is_cond_dist : bool) ; + Common.FatalError.fatal_error_msg [%message "Impossible"] + let operator_to_function_names _ = [] let string_operator_to_function_name s = s let deprecated_distributions = String.Map.empty diff --git a/src/frontend/Typechecking.ml b/src/frontend/Typechecking.ml index 9d981c020e..3920d898b1 100644 --- a/src/frontend/Typechecking.ml +++ b/src/frontend/Typechecking.ml @@ -213,7 +213,8 @@ module Make (StdLibrary : Std_library_utils.Library) : Typechecker = struct let library_function_return_type name arg_tys = match name with - | x when StdLibrary.is_variadic_function_name x -> Some (failwith "TODO") + | x when StdLibrary.is_variadic_function_name x -> + StdLibrary.variadic_function_returntype x | _ -> SignatureMismatch.matching_function std_library_tenv name arg_tys |> match_to_rt_option @@ -437,7 +438,6 @@ module Make (StdLibrary : Std_library_utils.Library) : Typechecker = struct uindices ) (* function checking *) - let verify_conddist_name loc id = if List.exists @@ -486,9 +486,6 @@ module Make (StdLibrary : Std_library_utils.Library) : Typechecker = struct && not ((cf.in_fun_def && cf.in_udf_dist_def) || cf.current_block = Model) then Semantic_error.invalid_unnormalized_fn loc |> error - let mk_fun_app ~is_cond_dist (x, y, z) = - if is_cond_dist then CondDistApp (x, y, z) else FunApp (x, y, z) - let check_normal_fn ~is_cond_dist loc tenv id es = match Env.find tenv (Utils.normalized_name id.name) with | {kind= `Variable _; _} :: _ @@ -558,205 +555,13 @@ module Make (StdLibrary : Std_library_utils.Library) : Typechecker = struct |> Semantic_error.illtyped_fn_app loc id.name (l, b) |> error ) - (** Given a constraint function [matches], find any signature which exists - Returns the first [Ok] if any exist, or else [Error] -*) - (* let find_matching_first_order_fn tenv matches fname = - let candidates = - Utils.stdlib_distribution_name fname.name - |> Env.find tenv |> List.map ~f:matches in - let ok, errs = List.partition_map candidates ~f:Result.to_either in - match SignatureMismatch.unique_minimum_promotion ok with - | Ok a -> SignatureMismatch.UniqueMatch a - | Error (Some promotions) -> - List.filter_map promotions ~f:(function - | UnsizedType.UFun (args, rt, _, _) -> Some (rt, args) - | _ -> None ) - |> AmbiguousMatch - | Error None -> SignatureMismatch.SignatureErrors (List.hd_exn errs) *) - - (* let make_function_variable current_block loc id = function - | UnsizedType.UFun (args, rt, FnLpdf _, mem_pattern) -> - let type_ = - UnsizedType.UFun - (args, rt, Fun_kind.suffix_from_name id.name, mem_pattern) in - mk_typed_expression ~expr:(Variable id) - ~ad_level:(calculate_autodifftype current_block Functions type_) - ~type_ ~loc - | UnsizedType.UFun _ as type_ -> - mk_typed_expression ~expr:(Variable id) - ~ad_level:(calculate_autodifftype current_block Functions type_) - ~type_ ~loc - | type_ -> - Common.FatalError.fatal_error_msg - [%message - "Attempting to create function variable out of " - (type_ : UnsizedType.t)] *) - let rec check_fn ~is_cond_dist loc cf tenv id (tes : Ast.typed_expression list) = - if StdLibrary.is_variadic_function_name id.name then ( - Stdlib.ignore cf ; - failwith "TODO" - (* if StdLibrary.is_reduce_sum_fn id.name then - check_reduce_sum ~is_cond_dist loc cf tenv id tes - else if StdLibrary.is_variadic_ode_fn id.name then - check_variadic_ode ~is_cond_dist loc cf tenv id tes - else if StdLibrary.is_variadic_dae_fn id.name then - check_variadic_dae ~is_cond_dist loc cf tenv id tes *) ) + if StdLibrary.is_variadic_function_name id.name then + StdLibrary.check_variadic_fn id ~is_cond_dist loc cf.current_block tenv + tes else check_normal_fn ~is_cond_dist loc tenv id tes - (* and check_reduce_sum ~is_cond_dist loc cf tenv id tes = - let basic_mismatch () = - let mandatory_args = - UnsizedType.[(AutoDiffable, UArray UReal); (AutoDiffable, UInt)] in - let mandatory_fun_args = - UnsizedType. - [(AutoDiffable, UArray UReal); (DataOnly, UInt); (DataOnly, UInt)] - in - SignatureMismatch.check_variadic_args true mandatory_args - mandatory_fun_args UReal (get_arg_types tes) in - let fail () = - let expected_args, err = - basic_mismatch () |> Result.error |> Option.value_exn in - Semantic_error.illtyped_reduce_sum_generic loc id.name - (List.map ~f:type_of_expr_typed tes) - expected_args err - |> error in - let matching remaining_es fn = - match fn with - | Env. - { type_= - UnsizedType.UFun - (((_, sliced_arg_fun_type) as sliced_arg_fun) :: _, _, _, _) as - ftype - ; _ } - when List.mem StdLibrary.reduce_sum_slice_types sliced_arg_fun_type - ~equal:( = ) -> - let mandatory_args = [sliced_arg_fun; (AutoDiffable, UInt)] in - let mandatory_fun_args = - [sliced_arg_fun; (DataOnly, UInt); (DataOnly, UInt)] in - let arg_types = - (calculate_autodifftype cf.current_block Functions ftype, ftype) - :: get_arg_types remaining_es in - SignatureMismatch.check_variadic_args true mandatory_args - mandatory_fun_args UReal arg_types - | _ -> basic_mismatch () in - match tes with - | {expr= Variable fname; _} :: remaining_es -> ( - match find_matching_first_order_fn tenv (matching remaining_es) fname with - | SignatureMismatch.UniqueMatch (ftype, promotions) -> - (* a valid signature exists *) - let tes = make_function_variable cf loc fname ftype :: remaining_es in - mk_typed_expression - ~expr: - (mk_fun_app ~is_cond_dist - (StanLib FnPlain, id, Promotion.promote_list tes promotions) ) - ~ad_level:(expr_ad_lub tes) ~type_:UnsizedType.UReal ~loc - | AmbiguousMatch ps -> - Semantic_error.ambiguous_function_promotion loc fname.name None ps - |> error - | SignatureErrors (expected_args, err) -> - Semantic_error.illtyped_reduce_sum loc id.name - (List.map ~f:type_of_expr_typed tes) - expected_args err - |> error ) - | _ -> fail () - - and check_variadic_ode ~is_cond_dist loc cf tenv id tes = - let optional_tol_mandatory_args = - if StdLibrary.variadic_ode_adjoint_fn = id.name then - StdLibrary.variadic_ode_adjoint_ctl_tol_arg_types - else if StdLibrary.is_variadic_ode_nonadjoint_tol_fn id.name then - StdLibrary.variadic_ode_tol_arg_types - else [] in - let mandatory_arg_types = - StdLibrary.variadic_ode_mandatory_arg_types @ optional_tol_mandatory_args - in - let fail () = - let expected_args, err = - SignatureMismatch.check_variadic_args false mandatory_arg_types - StdLibrary.variadic_ode_mandatory_fun_args - StdLibrary.variadic_ode_fun_return_type (get_arg_types tes) - |> Result.error |> Option.value_exn in - Semantic_error.illtyped_variadic_ode loc id.name - (List.map ~f:type_of_expr_typed tes) - expected_args err - |> error in - let matching remaining_es Env.{type_= ftype; _} = - let arg_types = - (calculate_autodifftype cf.current_block Functions ftype, ftype) - :: get_arg_types remaining_es in - SignatureMismatch.check_variadic_args false mandatory_arg_types - StdLibrary.variadic_ode_mandatory_fun_args - StdLibrary.variadic_ode_fun_return_type arg_types in - match tes with - | {expr= Variable fname; _} :: remaining_es -> ( - match find_matching_first_order_fn tenv (matching remaining_es) fname with - | SignatureMismatch.UniqueMatch (ftype, promotions) -> - let tes = make_function_variable cf loc fname ftype :: remaining_es in - mk_typed_expression - ~expr: - (mk_fun_app ~is_cond_dist - (StanLib FnPlain, id, Promotion.promote_list tes promotions) ) - ~ad_level:(expr_ad_lub tes) - ~type_:StdLibrary.variadic_ode_return_type ~loc - | AmbiguousMatch ps -> - Semantic_error.ambiguous_function_promotion loc fname.name None ps - |> error - | SignatureErrors (expected_args, err) -> - Semantic_error.illtyped_variadic_ode loc id.name - (List.map ~f:type_of_expr_typed tes) - expected_args err - |> error ) - | _ -> fail () - - and check_variadic_dae ~is_cond_dist loc cf tenv id tes = - let optional_tol_mandatory_args = - if StdLibrary.is_variadic_dae_tol_fn id.name then - StdLibrary.variadic_dae_tol_arg_types - else [] in - let mandatory_arg_types = - StdLibrary.variadic_dae_mandatory_arg_types @ optional_tol_mandatory_args - in - let fail () = - let expected_args, err = - SignatureMismatch.check_variadic_args false mandatory_arg_types - StdLibrary.variadic_dae_mandatory_fun_args - StdLibrary.variadic_dae_fun_return_type (get_arg_types tes) - |> Result.error |> Option.value_exn in - Semantic_error.illtyped_variadic_dae loc id.name - (List.map ~f:type_of_expr_typed tes) - expected_args err - |> error in - let matching remaining_es Env.{type_= ftype; _} = - let arg_types = - (calculate_autodifftype cf.current_block Functions ftype, ftype) - :: get_arg_types remaining_es in - SignatureMismatch.check_variadic_args false mandatory_arg_types - StdLibrary.variadic_dae_mandatory_fun_args - StdLibrary.variadic_dae_fun_return_type arg_types in - match tes with - | {expr= Variable fname; _} :: remaining_es -> ( - match find_matching_first_order_fn tenv (matching remaining_es) fname with - | SignatureMismatch.UniqueMatch (ftype, promotions) -> - let tes = make_function_variable cf loc fname ftype :: remaining_es in - mk_typed_expression - ~expr: - (mk_fun_app ~is_cond_dist - (StanLib FnPlain, id, Promotion.promote_list tes promotions) ) - ~ad_level:(expr_ad_lub tes) - ~type_:StdLibrary.variadic_dae_return_type ~loc - | AmbiguousMatch ps -> - Semantic_error.ambiguous_function_promotion loc fname.name None ps - |> error - | SignatureErrors (expected_args, err) -> - Semantic_error.illtyped_variadic_dae loc id.name - (List.map ~f:type_of_expr_typed tes) - expected_args err - |> error ) - | _ -> fail () *) - and check_funapp loc cf tenv ~is_cond_dist id (es : Ast.typed_expression list) = let name_check = diff --git a/src/frontend/Typechecking.mli b/src/frontend/Typechecking.mli index 397b30b60e..97aacdcb15 100644 --- a/src/frontend/Typechecking.mli +++ b/src/frontend/Typechecking.mli @@ -22,6 +22,15 @@ val model_name : string ref val check_that_all_functions_have_definition : bool ref (** A switch to determine whether we check that all functions have a definition *) +val get_arg_types : typed_expression list -> Std_library_utils.fun_arg list +val type_of_expr_typed : typed_expression -> Middle.UnsizedType.t + +val calculate_autodifftype : + Environment.originblock + -> Environment.originblock + -> Middle.UnsizedType.t + -> Middle.UnsizedType.autodifftype + module type Typechecker = sig val check_program_exn : untyped_program -> typed_program * Warnings.t list (** diff --git a/src/stan_math_backend/Stan_math_library.ml b/src/stan_math_backend/Stan_math_library.ml index 07c526dad2..464a476d2d 100644 --- a/src/stan_math_backend/Stan_math_library.ml +++ b/src/stan_math_backend/Stan_math_library.ml @@ -239,7 +239,219 @@ let is_variadic_dae_tol_fn f = let is_variadic_function_name name = is_reduce_sum_fn name || is_variadic_dae_fn name || is_variadic_ode_fn name -let is_not_overloadable = is_variadic_dae_fn +let variadic_function_returntype name = + if is_reduce_sum_fn name then Some (UnsizedType.ReturnType UReal) + else if is_variadic_ode_fn name then + Some (UnsizedType.ReturnType variadic_ode_return_type) + else if is_variadic_dae_fn name then + Some (UnsizedType.ReturnType variadic_dae_return_type) + else None + +let is_not_overloadable = is_variadic_function_name + +module Variadic_typechecking = struct + (** This module serves as the backend-specific portion + of the typechecker. *) + + open Frontend + open Typechecking + open Ast + + let error e = raise (Errors.SemanticError e) + + let make_function_variable current_block loc id = function + | UnsizedType.UFun (args, rt, FnLpdf _, mem_pattern) -> + let type_ = + UnsizedType.UFun + (args, rt, Fun_kind.suffix_from_name id.name, mem_pattern) in + mk_typed_expression ~expr:(Variable id) + ~ad_level:(calculate_autodifftype current_block Functions type_) + ~type_ ~loc + | UnsizedType.UFun _ as type_ -> + mk_typed_expression ~expr:(Variable id) + ~ad_level:(calculate_autodifftype current_block Functions type_) + ~type_ ~loc + | type_ -> + Common.FatalError.fatal_error_msg + [%message + "Attempting to create function variable out of " + (type_ : UnsizedType.t)] + + let check_reduce_sum ~is_cond_dist loc current_block tenv id tes = + let basic_mismatch () = + let mandatory_args = + UnsizedType.[(AutoDiffable, UArray UReal); (AutoDiffable, UInt)] in + let mandatory_fun_args = + UnsizedType. + [(AutoDiffable, UArray UReal); (DataOnly, UInt); (DataOnly, UInt)] + in + SignatureMismatch.check_variadic_args true mandatory_args + mandatory_fun_args UReal (get_arg_types tes) in + let fail () = + let expected_args, err = + basic_mismatch () |> Result.error |> Option.value_exn in + Semantic_error.illtyped_variadic_fn loc id.name + (List.map ~f:type_of_expr_typed tes) + expected_args err UReal + |> error in + let matching remaining_es fn = + match fn with + | Environment. + { type_= + UnsizedType.UFun + (((_, sliced_arg_fun_type) as sliced_arg_fun) :: _, _, _, _) as + ftype + ; _ } + when List.mem reduce_sum_slice_types sliced_arg_fun_type ~equal:( = ) -> + let mandatory_args = [sliced_arg_fun; (AutoDiffable, UInt)] in + let mandatory_fun_args = + [sliced_arg_fun; (DataOnly, UInt); (DataOnly, UInt)] in + let arg_types = + (calculate_autodifftype current_block Functions ftype, ftype) + :: get_arg_types remaining_es in + SignatureMismatch.check_variadic_args true mandatory_args + mandatory_fun_args UReal arg_types + | _ -> basic_mismatch () in + match tes with + | {expr= Variable fname; _} :: remaining_es -> ( + match + SignatureMismatch.find_matching_first_order_fn tenv + (matching remaining_es) fname + with + | SignatureMismatch.UniqueMatch (ftype, promotions) -> + (* a valid signature exists *) + let tes = + make_function_variable current_block loc fname ftype :: remaining_es + in + mk_typed_expression + ~expr: + (mk_fun_app ~is_cond_dist + (StanLib FnPlain, id, Promotion.promote_list tes promotions) ) + ~ad_level:(expr_ad_lub tes) ~type_:UnsizedType.UReal ~loc + | AmbiguousMatch ps -> + Semantic_error.ambiguous_function_promotion loc fname.name None ps + |> error + | SignatureErrors (expected_args, err) -> + Semantic_error.illtyped_variadic_fn loc id.name + (List.map ~f:type_of_expr_typed tes) + expected_args err UReal + |> error ) + | _ -> fail () + + let check_variadic_ode ~is_cond_dist loc current_block tenv id tes = + let optional_tol_mandatory_args = + if variadic_ode_adjoint_fn = id.name then + variadic_ode_adjoint_ctl_tol_arg_types + else if is_variadic_ode_nonadjoint_tol_fn id.name then + variadic_ode_tol_arg_types + else [] in + let mandatory_arg_types = + variadic_ode_mandatory_arg_types @ optional_tol_mandatory_args in + let fail () = + let expected_args, err = + SignatureMismatch.check_variadic_args false mandatory_arg_types + variadic_ode_mandatory_fun_args variadic_ode_fun_return_type + (get_arg_types tes) + |> Result.error |> Option.value_exn in + Semantic_error.illtyped_variadic_fn loc id.name + (List.map ~f:type_of_expr_typed tes) + expected_args err variadic_ode_fun_return_type + |> error in + let matching remaining_es Environment.{type_= ftype; _} = + let arg_types = + (calculate_autodifftype current_block Functions ftype, ftype) + :: get_arg_types remaining_es in + SignatureMismatch.check_variadic_args false mandatory_arg_types + variadic_ode_mandatory_fun_args variadic_ode_fun_return_type arg_types + in + match tes with + | {expr= Variable fname; _} :: remaining_es -> ( + match + SignatureMismatch.find_matching_first_order_fn tenv + (matching remaining_es) fname + with + | SignatureMismatch.UniqueMatch (ftype, promotions) -> + let tes = + make_function_variable current_block loc fname ftype :: remaining_es + in + mk_typed_expression + ~expr: + (mk_fun_app ~is_cond_dist + (StanLib FnPlain, id, Promotion.promote_list tes promotions) ) + ~ad_level:(expr_ad_lub tes) ~type_:variadic_ode_return_type ~loc + | AmbiguousMatch ps -> + Semantic_error.ambiguous_function_promotion loc fname.name None ps + |> error + | SignatureErrors (expected_args, err) -> + Semantic_error.illtyped_variadic_fn loc id.name + (List.map ~f:type_of_expr_typed tes) + expected_args err variadic_ode_fun_return_type + |> error ) + | _ -> fail () + + let check_variadic_dae ~is_cond_dist loc current_block tenv id tes = + let optional_tol_mandatory_args = + if is_variadic_dae_tol_fn id.name then variadic_dae_tol_arg_types else [] + in + let mandatory_arg_types = + variadic_dae_mandatory_arg_types @ optional_tol_mandatory_args in + let fail () = + let expected_args, err = + SignatureMismatch.check_variadic_args false mandatory_arg_types + variadic_dae_mandatory_fun_args variadic_dae_fun_return_type + (get_arg_types tes) + |> Result.error |> Option.value_exn in + Semantic_error.illtyped_variadic_fn loc id.name + (List.map ~f:type_of_expr_typed tes) + expected_args err variadic_dae_fun_return_type + |> error in + let matching remaining_es Environment.{type_= ftype; _} = + let arg_types = + (calculate_autodifftype current_block Functions ftype, ftype) + :: get_arg_types remaining_es in + SignatureMismatch.check_variadic_args false mandatory_arg_types + variadic_dae_mandatory_fun_args variadic_dae_fun_return_type arg_types + in + match tes with + | {expr= Variable fname; _} :: remaining_es -> ( + match + SignatureMismatch.find_matching_first_order_fn tenv + (matching remaining_es) fname + with + | SignatureMismatch.UniqueMatch (ftype, promotions) -> + let tes = + make_function_variable current_block loc fname ftype :: remaining_es + in + mk_typed_expression + ~expr: + (mk_fun_app ~is_cond_dist + (StanLib FnPlain, id, Promotion.promote_list tes promotions) ) + ~ad_level:(expr_ad_lub tes) ~type_:variadic_dae_return_type ~loc + | AmbiguousMatch ps -> + Semantic_error.ambiguous_function_promotion loc fname.name None ps + |> error + | SignatureErrors (expected_args, err) -> + Semantic_error.illtyped_variadic_fn loc id.name + (List.map ~f:type_of_expr_typed tes) + expected_args err variadic_dae_fun_return_type + |> error ) + | _ -> fail () +end + +let check_variadic_fn id ~is_cond_dist loc current_block tenv tes = + if is_reduce_sum_fn id.Frontend.Ast.name then + Variadic_typechecking.check_reduce_sum ~is_cond_dist loc current_block tenv + id tes + else if is_variadic_ode_fn id.name then + Variadic_typechecking.check_variadic_ode ~is_cond_dist loc current_block + tenv id tes + else if is_variadic_dae_fn id.name then + Variadic_typechecking.check_variadic_dae ~is_cond_dist loc current_block + tenv id tes + else + Common.FatalError.fatal_error_msg + [%message + "Invalid variadic function for Stan Math backend" (id.name : string)] let distributions = [ ( full_lpmf @@ -499,21 +711,23 @@ let deprecated_distributions = ; version= "2.32.0" ; extra_message= "This can be automatically changed using the canonicalize flag \ - for stanc" } ) ) + for stanc" + ; canonicalize_away= true } ) ) |> String.Map.of_alist_exn let deprecated_functions = - let make extra_message version replacement = - {extra_message; replacement; version} in + let make extra_message version canonicalize_away replacement = + {extra_message; replacement; version; canonicalize_away} in let ode = make - "The new interface is slightly different, see: \n\ + "\n\ + The new interface is slightly different, see: \ https://mc-stan.org/users/documentation/case-studies/convert_odes.html" - "3.0" in + "3.0" false in let std = make "This can be automatically changed using the canonicalize flag for stanc" - "2.32" in + "2.32.0" true in String.Map.of_alist_exn [ ("multiply_log", std "lmultiply") ; ("binomial_coefficient_log", std "lchoose") diff --git a/test/integration/bad/stanc.expected b/test/integration/bad/stanc.expected index a4bfd12395..3263db8341 100644 --- a/test/integration/bad/stanc.expected +++ b/test/integration/bad/stanc.expected @@ -1531,7 +1531,6 @@ Ill-typed arguments supplied to infix operator /. Available signatures: (matrix, matrix) => matrix (complex_row_vector, complex_matrix) => complex_row_vector (complex_matrix, complex_matrix) => complex_matrix - (int, int) => int (real, real) => real (vector, real) => vector @@ -1554,7 +1553,6 @@ Ill-typed arguments supplied to infix operator /. Available signatures: (matrix, matrix) => matrix (complex_row_vector, complex_matrix) => complex_row_vector (complex_matrix, complex_matrix) => complex_matrix - (int, int) => int (real, real) => real (vector, real) => vector From 29d0bf7e4fd2bb8f57aa77f09a444fc07577f2e5 Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Thu, 28 Apr 2022 09:35:54 -0400 Subject: [PATCH 07/14] Rename to avoid ambiguity with Core.Stdlib --- src/analysis_and_optimization/Mem_pattern.ml | 8 ++++---- src/analysis_and_optimization/Optimize.ml | 7 ++++--- src/analysis_and_optimization/Optimize.mli | 2 +- .../Partial_evaluation.ml | 4 ++-- .../Partial_evaluation.mli | 2 +- src/frontend/Ast_to_Mir.ml | 6 +++--- src/frontend/Ast_to_Mir.mli | 2 +- src/frontend/Deprecation_analysis.ml | 14 ++++++++------ src/frontend/Deprecation_analysis.mli | 2 +- src/frontend/Info.ml | 4 ++-- src/frontend/Info.mli | 2 +- src/frontend/Std_library_utils.ml | 5 +++-- 12 files changed, 31 insertions(+), 27 deletions(-) diff --git a/src/analysis_and_optimization/Mem_pattern.ml b/src/analysis_and_optimization/Mem_pattern.ml index 5b252626d4..e212e10332 100644 --- a/src/analysis_and_optimization/Mem_pattern.ml +++ b/src/analysis_and_optimization/Mem_pattern.ml @@ -2,7 +2,7 @@ open Core_kernel open Core_kernel.Poly open Middle -module Make (StdLib : Frontend.Std_library_utils.Library) = struct +module Make (StdLibrary : Frontend.Std_library_utils.Library) = struct (** Return a Var expression of the name for each type containing an eigen matrix @@ -113,12 +113,12 @@ module Make (StdLib : Frontend.Std_library_utils.Library) = struct let query_stan_math_mem_pattern_support (name : string) (args : (UnsizedType.autodifftype * UnsizedType.t) list) = match name with - | x when StdLib.is_variadic_function_name x -> false + | x when StdLibrary.is_variadic_function_name x -> false | _ -> let name = - StdLib.string_operator_to_function_name + StdLibrary.string_operator_to_function_name (Utils.stdlib_distribution_name name) in - let namematches = StdLib.get_signatures name in + let namematches = StdLibrary.get_signatures name in let filteredmatches = List.filter ~f:(fun x -> diff --git a/src/analysis_and_optimization/Optimize.ml b/src/analysis_and_optimization/Optimize.ml index a3c8bc5380..02ce9f90ac 100644 --- a/src/analysis_and_optimization/Optimize.ml +++ b/src/analysis_and_optimization/Optimize.ml @@ -93,9 +93,10 @@ module type Optimizer = sig ?settings:optimization_settings -> Program.Typed.t -> Program.Typed.t end -module Make (StdLib : Frontend.Std_library_utils.Library) : Optimizer = struct - module Mem = Mem_pattern.Make (StdLib) - module Partial_evaluator = Partial_evaluation.Make (StdLib) +module Make (StdLibrary : Frontend.Std_library_utils.Library) : Optimizer = +struct + module Mem = Mem_pattern.Make (StdLibrary) + module Partial_evaluator = Partial_evaluation.Make (StdLibrary) (** Apply the transformation to each function body and to the rest of the program as one diff --git a/src/analysis_and_optimization/Optimize.mli b/src/analysis_and_optimization/Optimize.mli index 8d1bc28331..1f3ae7c018 100644 --- a/src/analysis_and_optimization/Optimize.mli +++ b/src/analysis_and_optimization/Optimize.mli @@ -94,4 +94,4 @@ module type Optimizer = sig (** Perform all optimizations in this module on the MIR in an appropriate order. *) end -module Make (StdLib : Frontend.Std_library_utils.Library) : Optimizer +module Make (StdLibrary : Frontend.Std_library_utils.Library) : Optimizer diff --git a/src/analysis_and_optimization/Partial_evaluation.ml b/src/analysis_and_optimization/Partial_evaluation.ml index fc3864fb94..cd78a69d1f 100644 --- a/src/analysis_and_optimization/Partial_evaluation.ml +++ b/src/analysis_and_optimization/Partial_evaluation.ml @@ -92,8 +92,8 @@ module type PartialEvaluator = sig val eval_prog : Program.Typed.t -> Program.Typed.t end -module Make (StdLib : Frontend.Std_library_utils.Library) = struct - module TC = Frontend.Typechecking.Make (StdLib) +module Make (StdLibrary : Frontend.Std_library_utils.Library) = struct + module TC = Frontend.Typechecking.Make (StdLibrary) let rec eval_expr ?(preserve_stability = false) (e : Expr.Typed.t) = { e with diff --git a/src/analysis_and_optimization/Partial_evaluation.mli b/src/analysis_and_optimization/Partial_evaluation.mli index a12d5db74a..2491a8a8dc 100644 --- a/src/analysis_and_optimization/Partial_evaluation.mli +++ b/src/analysis_and_optimization/Partial_evaluation.mli @@ -5,4 +5,4 @@ module type PartialEvaluator = sig val eval_prog : Program.Typed.t -> Program.Typed.t end -module Make (StdLib : Frontend.Std_library_utils.Library) : PartialEvaluator +module Make (StdLibrary : Frontend.Std_library_utils.Library) : PartialEvaluator diff --git a/src/frontend/Ast_to_Mir.ml b/src/frontend/Ast_to_Mir.ml index 413181ff73..09e199c984 100644 --- a/src/frontend/Ast_to_Mir.ml +++ b/src/frontend/Ast_to_Mir.ml @@ -10,7 +10,7 @@ module type Ast_Mir_translator = sig val trans_prog : string -> Ast.typed_program -> Program.Typed.t end -module Make (StdLib : Std_library_utils.Library) = struct +module Make (StdLibrary : Std_library_utils.Library) = struct let trans_fn_kind kind name = let fname = Utils.stdlib_distribution_name name in match kind with @@ -113,7 +113,7 @@ module Make (StdLib : Std_library_utils.Library) = struct | None -> ( Ast.StanLib FnPlain , Set.to_list possible_names |> List.hd_exn - , if StdLib.is_stdlib_function_name (id.name ^ "_lpmf") then + , if StdLibrary.is_stdlib_function_name (id.name ^ "_lpmf") then UnsizedType.UInt else UnsizedType.UReal (* close enough *) ) in let trunc cond_op (x : Ast.typed_expression) y = @@ -446,7 +446,7 @@ module Make (StdLib : Std_library_utils.Library) = struct | Ast.Tilde {arg; distribution; args; truncation} -> let suffix = Std_library_utils.dist_name_suffix - (module StdLib) + (module StdLibrary) ud_dists distribution.name in let name = distribution.name ^ suffix in let kind = diff --git a/src/frontend/Ast_to_Mir.mli b/src/frontend/Ast_to_Mir.mli index c045b9a5b0..3d43f2579b 100644 --- a/src/frontend/Ast_to_Mir.mli +++ b/src/frontend/Ast_to_Mir.mli @@ -9,4 +9,4 @@ module type Ast_Mir_translator = sig val trans_prog : string -> Ast.typed_program -> Program.Typed.t end -module Make (StdLib : Std_library_utils.Library) : Ast_Mir_translator +module Make (StdLibrary : Std_library_utils.Library) : Ast_Mir_translator diff --git a/src/frontend/Deprecation_analysis.ml b/src/frontend/Deprecation_analysis.ml index 2f99fa55a8..dbe0699f58 100644 --- a/src/frontend/Deprecation_analysis.ml +++ b/src/frontend/Deprecation_analysis.ml @@ -20,10 +20,11 @@ module type Deprecation_analizer = sig val collect_warnings : typed_program -> Warnings.t list end -module Make (StdLib : Std_library_utils.Library) : Deprecation_analizer = struct +module Make (StdLibrary : Std_library_utils.Library) : Deprecation_analizer = +struct let stan_lib_deprecations = - Map.merge_skewed StdLib.deprecated_distributions StdLib.deprecated_functions - ~combine:(fun ~key x y -> + Map.merge_skewed StdLibrary.deprecated_distributions + StdLibrary.deprecated_functions ~combine:(fun ~key x y -> Common.FatalError.fatal_error_msg [%message "Common key in deprecation map" @@ -32,7 +33,7 @@ module Make (StdLib : Std_library_utils.Library) : Deprecation_analizer = struct (y : Std_library_utils.deprecation_info)] ) let is_deprecated_distribution name = - Map.mem StdLib.deprecated_distributions name + Map.mem StdLibrary.deprecated_distributions name let rename_deprecated map name = Map.find map name @@ -42,9 +43,10 @@ module Make (StdLib : Std_library_utils.Library) : Deprecation_analizer = struct |> Option.value ~default:name let rename_deprecated_distribution = - rename_deprecated StdLib.deprecated_distributions + rename_deprecated StdLibrary.deprecated_distributions - let rename_deprecated_function = rename_deprecated StdLib.deprecated_functions + let rename_deprecated_function = + rename_deprecated StdLibrary.deprecated_functions let distribution_suffix name = let open String in diff --git a/src/frontend/Deprecation_analysis.mli b/src/frontend/Deprecation_analysis.mli index 90261e8d6f..34e80aa35a 100644 --- a/src/frontend/Deprecation_analysis.mli +++ b/src/frontend/Deprecation_analysis.mli @@ -23,4 +23,4 @@ module type Deprecation_analizer = sig val collect_warnings : typed_program -> Warnings.t list end -module Make (StdLib : Std_library_utils.Library) : Deprecation_analizer +module Make (StdLibrary : Std_library_utils.Library) : Deprecation_analizer diff --git a/src/frontend/Info.ml b/src/frontend/Info.ml index 399d9f7477..19770dbe5b 100644 --- a/src/frontend/Info.ml +++ b/src/frontend/Info.ml @@ -56,7 +56,7 @@ module type Information = sig val info : Ast.typed_program -> string end -module Make (StdLib : Std_library_utils.Library) : Information = struct +module Make (StdLibrary : Std_library_utils.Library) : Information = struct let rec get_function_calls_stmt ud_dists (funs, distrs) stmt = let acc = match stmt.stmt with @@ -70,7 +70,7 @@ module Make (StdLib : Std_library_utils.Library) : Information = struct else let suffix = Std_library_utils.dist_name_suffix - (module StdLib) + (module StdLibrary) ud_dists distribution.name in let name = distribution.name ^ Utils.unnormalized_suffix suffix in (funs, Set.add distrs name) diff --git a/src/frontend/Info.mli b/src/frontend/Info.mli index f9a7b57cda..7d111ef79e 100644 --- a/src/frontend/Info.mli +++ b/src/frontend/Info.mli @@ -20,4 +20,4 @@ module type Information = sig val info : Ast.typed_program -> string end -module Make (StdLib : Std_library_utils.Library) : Information +module Make (StdLibrary : Std_library_utils.Library) : Information diff --git a/src/frontend/Std_library_utils.ml b/src/frontend/Std_library_utils.ml index 8c45e2d954..f7c5c766e8 100644 --- a/src/frontend/Std_library_utils.ml +++ b/src/frontend/Std_library_utils.ml @@ -83,10 +83,11 @@ let pp_math_sig ppf (rt, args, mem_pattern) = let pp_math_sigs ppf sigs = (Fmt.list ~sep:Fmt.cut pp_math_sig) ppf sigs let pretty_print_math_sigs = Fmt.str "@[@,%a@]" pp_math_sigs -let dist_name_suffix (module StdLib : Library) udf_names name = +let dist_name_suffix (module StdLibrary : Library) udf_names name = let is_udf_name s = List.exists ~f:(fun (n, _) -> String.equal s n) udf_names in Utils.distribution_suffices |> List.filter ~f:(fun sfx -> - StdLib.is_stdlib_function_name (name ^ sfx) || is_udf_name (name ^ sfx) ) + StdLibrary.is_stdlib_function_name (name ^ sfx) + || is_udf_name (name ^ sfx) ) |> List.hd_exn From 79acb8192229d43e4c84e4a7f23b005691a07da3 Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Thu, 28 Apr 2022 09:42:38 -0400 Subject: [PATCH 08/14] Update Stancjs to use functors --- .../Pedantic_analysis.ml | 1 + src/stanc/stanc.ml | 17 +++++------ src/stancjs/stancjs.ml | 30 ++++++++++++------- 3 files changed, 27 insertions(+), 21 deletions(-) diff --git a/src/analysis_and_optimization/Pedantic_analysis.ml b/src/analysis_and_optimization/Pedantic_analysis.ml index 1c2a5cd7cb..61bed39fac 100644 --- a/src/analysis_and_optimization/Pedantic_analysis.ml +++ b/src/analysis_and_optimization/Pedantic_analysis.ml @@ -487,6 +487,7 @@ let settings_constant_prop = ; copy_propagation= true ; partial_evaluation= true } +(** Pedantic mode is only really valid for the Stan Math backend *) module Optimizer = Optimize.Make (Stan_math_backend.Stan_math_library) (* Collect all pedantic mode warnings, sorted, to stderr *) diff --git a/src/stanc/stanc.ml b/src/stanc/stanc.ml index efbf8da480..c1f4365d68 100644 --- a/src/stanc/stanc.ml +++ b/src/stanc/stanc.ml @@ -3,20 +3,17 @@ open Core_kernel open Core_kernel.Poly open Frontend -open Stan_math_backend -open Analysis_and_optimization open Middle +open Analysis_and_optimization +open Stan_math_backend (* Initialize functor modules with the Stan Math Library *) -module CppLibrary : Std_library_utils.Library = - Stan_math_backend.Stan_math_library - -module Typechecker = Typechecking.Make (CppLibrary) -module Deprecations = Deprecation_analysis.Make (CppLibrary) +module Typechecker = Typechecking.Make (Stan_math_library) +module Deprecations = Deprecation_analysis.Make (Stan_math_library) module Canonicalizer = Canonicalize.Make (Deprecations) -module ModelInfo = Info.Make (CppLibrary) -module Ast2Mir = Ast_to_Mir.Make (CppLibrary) -module Optimizer = Optimize.Make (CppLibrary) +module ModelInfo = Info.Make (Stan_math_library) +module Ast2Mir = Ast_to_Mir.Make (Stan_math_library) +module Optimizer = Optimize.Make (Stan_math_library) (** The main program. *) let version = "%%NAME%%3 %%VERSION%%" diff --git a/src/stancjs/stancjs.ml b/src/stancjs/stancjs.ml index 7277a88e2f..46ee6ef316 100644 --- a/src/stancjs/stancjs.ml +++ b/src/stancjs/stancjs.ml @@ -1,11 +1,19 @@ open Core_kernel open Core_kernel.Poly open Frontend -open Stan_math_backend -open Analysis_and_optimization open Middle +open Analysis_and_optimization +open Stan_math_backend open Js_of_ocaml +(* Initialize functors with Stan Math C++ signatures *) +module Typechecker = Typechecking.Make (Stan_math_library) +module Deprecations = Deprecation_analysis.Make (Stan_math_library) +module Canonicalizer = Canonicalize.Make (Deprecations) +module ModelInfo = Info.Make (Stan_math_library) +module Ast2Mir = Ast_to_Mir.Make (Stan_math_library) +module Optimizer = Optimize.Make (Stan_math_library) + let version = "%%NAME%% %%VERSION%%" let warn_uninitialized_msgs (uninit_vars : (Location_span.t * string) Set.Poly.t) @@ -21,8 +29,8 @@ let warn_uninitialized_msgs (uninit_vars : (Location_span.t * string) Set.Poly.t let stan2cpp model_name model_string is_flag_set flag_val = Common.Gensym.reset_danger_use_cautiously () ; - Typechecker.model_name := model_name ; - Typechecker.check_that_all_functions_have_definition := + Typechecking.model_name := model_name ; + Typechecking.check_that_all_functions_have_definition := not (is_flag_set "allow_undefined" || is_flag_set "allow-undefined") ; Transform_Mir.use_opencl := is_flag_set "use-opencl" ; Stan_math_code_gen.standalone_functions := @@ -45,7 +53,7 @@ let stan2cpp model_name model_string is_flag_set flag_val = >>| fun (typed_ast, type_warnings) -> let warnings = parser_warnings @ type_warnings in if is_flag_set "info" then - r.return (Result.Ok (Info.info typed_ast), warnings, []) ; + r.return (Result.Ok (ModelInfo.info typed_ast), warnings, []) ; let canonicalizer_settings = if is_flag_set "print-canonical" then Canonicalize.all else @@ -67,7 +75,7 @@ let stan2cpp model_name model_string is_flag_set flag_val = flag_val "max-line-length" |> Option.map ~f:int_of_string |> Option.value ~default:78 in - let mir = Ast_to_Mir.trans_prog model_name typed_ast in + let mir = Ast2Mir.trans_prog model_name typed_ast in let tx_mir = Transform_Mir.trans_prog mir in if is_flag_set "auto-format" || is_flag_set "print-canonical" then r.return @@ -76,7 +84,7 @@ let stan2cpp model_name model_string is_flag_set flag_val = ~bare_functions:(is_flag_set "functions-only") ~line_length ~inline_includes:canonicalizer_settings.inline_includes - (Canonicalize.canonicalize_program typed_ast + (Canonicalizer.canonicalize_program typed_ast canonicalizer_settings ) ) , warnings , [] ) ; @@ -92,7 +100,7 @@ let stan2cpp model_name model_string is_flag_set flag_val = r.return ( Result.Ok (Debug_data_generation.print_data_prog - (Ast_to_Mir.gather_data typed_ast) ) + (Ast2Mir.gather_data typed_ast) ) , warnings , [] ) ; let opt_mir = @@ -102,7 +110,7 @@ let stan2cpp model_name model_string is_flag_set flag_val = else if is_flag_set "Oexperimental" || is_flag_set "O" then Optimize.Oexperimental else Optimize.O0 in - Optimize.optimization_suite + Optimizer.optimization_suite ~settings:(Optimize.level_optimizations opt_lvl) tx_mir in if is_flag_set "debug-optimized-mir" then @@ -181,11 +189,11 @@ let stan2cpp_wrapped name code (flags : Js.string_array Js.t Js.opt) = wrap_result ?printed_filename ~code result ~warnings let dump_stan_math_signatures () = - Js.string @@ Fmt.str "%a" Stan_math_signatures.pretty_print_all_math_sigs () + Js.string @@ Fmt.str "%a" Stan_math_library.pretty_print_all_math_sigs () let dump_stan_math_distributions () = Js.string - @@ Fmt.str "%a" Stan_math_signatures.pretty_print_all_math_distributions () + @@ Fmt.str "%a" Stan_math_library.pretty_print_all_math_distributions () let () = Js.export "dump_stan_math_signatures" dump_stan_math_signatures ; From 72a10fe05d03f5f34693997818446d7095ae0258 Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Thu, 28 Apr 2022 10:01:09 -0400 Subject: [PATCH 09/14] Cleanup, move some signatures to _intf files, capitalize module types --- .../Dependence_analysis.ml | 7 +- .../Monotone_framework.ml | 22 ++--- ...rk_sigs.mli => Monotone_framework_intf.ml} | 0 src/analysis_and_optimization/Optimize.ml | 51 +--------- src/analysis_and_optimization/Optimize.mli | 94 ++----------------- .../Optimize_intf.ml | 87 +++++++++++++++++ .../Partial_evaluation.ml | 5 +- .../Partial_evaluation.mli | 5 +- src/analysis_and_optimization/dune | 2 - src/frontend/Ast_to_Mir.ml | 5 +- src/frontend/Ast_to_Mir.mli | 4 +- src/frontend/Canonicalize.ml | 5 +- src/frontend/Canonicalize.mli | 6 +- src/frontend/Deprecation_analysis.ml | 4 +- src/frontend/Deprecation_analysis.mli | 4 +- src/frontend/Info.ml | 4 +- src/frontend/Info.mli | 4 +- src/frontend/Typechecking.ml | 21 +---- src/frontend/Typechecking.mli | 34 ++----- src/frontend/Typechecking_intf.ml | 29 ++++++ 20 files changed, 171 insertions(+), 222 deletions(-) rename src/analysis_and_optimization/{Monotone_framework_sigs.mli => Monotone_framework_intf.ml} (100%) create mode 100644 src/analysis_and_optimization/Optimize_intf.ml create mode 100644 src/frontend/Typechecking_intf.ml diff --git a/src/analysis_and_optimization/Dependence_analysis.ml b/src/analysis_and_optimization/Dependence_analysis.ml index 0bc86aeeb8..7b4a28951b 100644 --- a/src/analysis_and_optimization/Dependence_analysis.ml +++ b/src/analysis_and_optimization/Dependence_analysis.ml @@ -4,7 +4,7 @@ open Middle open Dataflow_types open Mir_utils open Dataflow_utils -open Monotone_framework_sigs +open Monotone_framework_intf open Monotone_framework (***********************************) @@ -119,9 +119,8 @@ let mir_reaching_definitions (mir : Program.Typed.t) (stmt : Stmt.Located.t) : Map.Poly.map rd_map ~f:(fun {entry; exit} -> {entry= to_rd_set entry; exit= to_rd_set exit} ) -let all_labels - (module Flowgraph : Monotone_framework_sigs.FLOWGRAPH with type labels = int) - : int Set.Poly.t = +let all_labels (module Flowgraph : FLOWGRAPH with type labels = int) : + int Set.Poly.t = let step set = Set.Poly.union set (union_map set ~f:(fun l -> Map.Poly.find_exn Flowgraph.successors l)) diff --git a/src/analysis_and_optimization/Monotone_framework.ml b/src/analysis_and_optimization/Monotone_framework.ml index 381fa736d5..c607fd96b3 100644 --- a/src/analysis_and_optimization/Monotone_framework.ml +++ b/src/analysis_and_optimization/Monotone_framework.ml @@ -2,7 +2,7 @@ open Core_kernel open Core_kernel.Poly -open Monotone_framework_sigs +open Monotone_framework_intf open Mir_utils open Middle @@ -310,7 +310,7 @@ let minimal_variables_lattice initial_variables = (* The transfer function for a constant propagation analysis *) let constant_propagation_transfer - (module Partial_evaluator : Partial_evaluation.PartialEvaluator) + (module Partial_evaluator : Partial_evaluation.PARTIAL_EVALUATOR) ?(preserve_stability = false) (flowgraph_to_mir : (int, Stmt.Located.Non_recursive.t) Map.Poly.t) = ( module struct @@ -866,7 +866,7 @@ let rec declared_variables_stmt (List.map ~f:(fun x -> declared_variables_stmt x.pattern) l) let propagation_mfp (prog : Program.Typed.t) - (module Flowgraph : Monotone_framework_sigs.FLOWGRAPH with type labels = int) + (module Flowgraph : FLOWGRAPH with type labels = int) (flowgraph_to_mir : (int, Stmt.Located.Non_recursive.t) Map.Poly.t) (propagation_transfer : (int, Stmt.Located.Non_recursive.t) Map.Poly.t @@ -899,7 +899,7 @@ let propagation_mfp (prog : Program.Typed.t) Mf.mfp () let reaching_definitions_mfp (mir : Program.Typed.t) - (module Flowgraph : Monotone_framework_sigs.FLOWGRAPH with type labels = int) + (module Flowgraph : FLOWGRAPH with type labels = int) (flowgraph_to_mir : (int, Stmt.Located.Non_recursive.t) Map.Poly.t) = let variables = ( module struct @@ -920,7 +920,7 @@ let reaching_definitions_mfp (mir : Program.Typed.t) Mf.mfp () let initialized_vars_mfp (total : string Set.Poly.t) - (module Flowgraph : Monotone_framework_sigs.FLOWGRAPH with type labels = int) + (module Flowgraph : FLOWGRAPH with type labels = int) (flowgraph_to_mir : (int, Stmt.Located.Non_recursive.t) Map.Poly.t) = let (module Lattice) = dual_powerset_lattice_empty_initial @@ -951,8 +951,7 @@ let globals (prog : Program.Typed.t) = (** Monotone framework instance for live_variables analysis. Expects reverse flowgraph. *) let live_variables_mfp (prog : Program.Typed.t) - (module Rev_Flowgraph : Monotone_framework_sigs.FLOWGRAPH - with type labels = int ) + (module Rev_Flowgraph : FLOWGRAPH with type labels = int) (flowgraph_to_mir : (int, Stmt.Located.Non_recursive.t) Map.Poly.t) = let never_kill = globals prog in let variables = @@ -972,10 +971,8 @@ let live_variables_mfp (prog : Program.Typed.t) (** Instantiate all four instances of the monotone framework for lazy code motion, reusing code between them *) -let lazy_expressions_mfp - (module Flowgraph : Monotone_framework_sigs.FLOWGRAPH with type labels = int) - (module Rev_Flowgraph : Monotone_framework_sigs.FLOWGRAPH - with type labels = int ) +let lazy_expressions_mfp (module Flowgraph : FLOWGRAPH with type labels = int) + (module Rev_Flowgraph : FLOWGRAPH with type labels = int) (flowgraph_to_mir : (int, Stmt.Located.Non_recursive.t) Map.Poly.t) = let all_expressions = used_subexpressions_stmt @@ -1033,8 +1030,7 @@ let lazy_expressions_mfp * *) let minimal_variables_mfp - (module Circular_Fwd_Flowgraph : Monotone_framework_sigs.FLOWGRAPH - with type labels = int ) + (module Circular_Fwd_Flowgraph : FLOWGRAPH with type labels = int) (flowgraph_to_mir : (int, Stmt.Located.Non_recursive.t) Map.Poly.t) (initial_variables : string Set.Poly.t) (gen_variable : diff --git a/src/analysis_and_optimization/Monotone_framework_sigs.mli b/src/analysis_and_optimization/Monotone_framework_intf.ml similarity index 100% rename from src/analysis_and_optimization/Monotone_framework_sigs.mli rename to src/analysis_and_optimization/Monotone_framework_intf.ml diff --git a/src/analysis_and_optimization/Optimize.ml b/src/analysis_and_optimization/Optimize.ml index 02ce9f90ac..0ac3815b2e 100644 --- a/src/analysis_and_optimization/Optimize.ml +++ b/src/analysis_and_optimization/Optimize.ml @@ -5,23 +5,7 @@ open Core_kernel.Poly open Common open Middle open Mir_utils - -type optimization_settings = - { function_inlining: bool - ; static_loop_unrolling: bool - ; one_step_loop_unrolling: bool - ; list_collapsing: bool - ; block_fixing: bool - ; allow_uninitialized_decls: bool - ; constant_propagation: bool - ; expression_propagation: bool - ; copy_propagation: bool - ; dead_code_elimination: bool - ; partial_evaluation: bool - ; lazy_code_motion: bool - ; optimize_ad_levels: bool - ; preserve_stability: bool - ; optimize_soa: bool } +open Optimize_intf let settings_const b = { function_inlining= b @@ -66,34 +50,7 @@ let level_optimizations (lvl : optimization_level) : optimization_settings = ; optimize_soa= true } | Oexperimental -> all_optimizations -module type Optimizer = sig - val function_inlining : Program.Typed.t -> Program.Typed.t - val static_loop_unrolling : Program.Typed.t -> Program.Typed.t - val one_step_loop_unrolling : Program.Typed.t -> Program.Typed.t - val list_collapsing : Program.Typed.t -> Program.Typed.t - val block_fixing : Program.Typed.t -> Program.Typed.t - - val constant_propagation : - ?preserve_stability:bool -> Program.Typed.t -> Program.Typed.t - - val expression_propagation : - ?preserve_stability:bool -> Program.Typed.t -> Program.Typed.t - - val copy_propagation : Program.Typed.t -> Program.Typed.t - val dead_code_elimination : Program.Typed.t -> Program.Typed.t - val partial_evaluation : Program.Typed.t -> Program.Typed.t - - val lazy_code_motion : - ?preserve_stability:bool -> Program.Typed.t -> Program.Typed.t - - val optimize_ad_levels : Program.Typed.t -> Program.Typed.t - val allow_uninitialized_decls : Program.Typed.t -> Program.Typed.t - - val optimization_suite : - ?settings:optimization_settings -> Program.Typed.t -> Program.Typed.t -end - -module Make (StdLibrary : Frontend.Std_library_utils.Library) : Optimizer = +module Make (StdLibrary : Frontend.Std_library_utils.Library) : OPTIMIZER = struct module Mem = Mem_pattern.Make (StdLibrary) module Partial_evaluator = Partial_evaluation.Make (StdLibrary) @@ -739,7 +696,7 @@ struct let propagation (propagation_transfer : (int, Stmt.Located.Non_recursive.t) Map.Poly.t - -> (module Monotone_framework_sigs.TRANSFER_FUNCTION + -> (module Monotone_framework_intf.TRANSFER_FUNCTION with type labels = int and type properties = (string, Middle.Expr.Typed.t) Map.Poly.t option ) ) (mir : Program.Typed.t) = @@ -827,7 +784,7 @@ struct let dead_code_elim_stmt_base i stmt = (* NOTE: entry in the reverse flowgraph, so exit in the forward flowgraph *) let live_variables_s = - (Map.find_exn live_variables i).Monotone_framework_sigs.entry in + (Map.find_exn live_variables i).Monotone_framework_intf.entry in match stmt with | Stmt.Fixed.Pattern.Assignment ((x, _, []), rhs) -> if Set.Poly.mem live_variables_s x || cannot_remove_expr rhs then diff --git a/src/analysis_and_optimization/Optimize.mli b/src/analysis_and_optimization/Optimize.mli index 1f3ae7c018..665e5d3914 100644 --- a/src/analysis_and_optimization/Optimize.mli +++ b/src/analysis_and_optimization/Optimize.mli @@ -1,24 +1,5 @@ (* Code for optimization passes on the MIR *) -open Middle - -(** Interface for turning individual optimizations on/off. Useful for testing - and for top-level interface flags. *) -type optimization_settings = - { function_inlining: bool - ; static_loop_unrolling: bool - ; one_step_loop_unrolling: bool - ; list_collapsing: bool - ; block_fixing: bool - ; allow_uninitialized_decls: bool - ; constant_propagation: bool - ; expression_propagation: bool - ; copy_propagation: bool - ; dead_code_elimination: bool - ; partial_evaluation: bool - ; lazy_code_motion: bool - ; optimize_ad_levels: bool - ; preserve_stability: bool - ; optimize_soa: bool } +open Optimize_intf val all_optimizations : optimization_settings val no_optimizations : optimization_settings @@ -27,71 +8,8 @@ type optimization_level = O0 | O1 | Oexperimental val level_optimizations : optimization_level -> optimization_settings -module type Optimizer = sig - val function_inlining : Program.Typed.t -> Program.Typed.t - (** Inline all functions except for ones with forward declarations - (e.g. recursive functions, mutually recursive functions, and - functions without a definition *) - - val static_loop_unrolling : Program.Typed.t -> Program.Typed.t - (** Unroll all for-loops with constant bounds, as long as they do - not contain break or continue statements in their body at the - top level *) - - val one_step_loop_unrolling : Program.Typed.t -> Program.Typed.t - (** Unroll all loops for one iteration, as long as they do - not contain break or continue statements in their body at the - top level *) - - val list_collapsing : Program.Typed.t -> Program.Typed.t - (** Remove redundant SList constructors from the Mir that might have - been introduced by other optimizations *) - - val block_fixing : Program.Typed.t -> Program.Typed.t - (** Make sure that SList constructors directly under if, for, while or fundef - constructors are replaced with Block constructors. - This should probably be run before we generate code. *) - - val constant_propagation : - ?preserve_stability:bool -> Program.Typed.t -> Program.Typed.t - (** Propagate constant values through variable assignments *) - - val expression_propagation : - ?preserve_stability:bool -> Program.Typed.t -> Program.Typed.t - (** Propagate arbitrary expressions through variable assignments. - This can be useful for opening up new possibilities for partial evaluation. - It should be followed by some CSE or lazy code motion pass, however. *) - - val copy_propagation : Program.Typed.t -> Program.Typed.t - (** Propagate copies of variables through assignments. *) - - val dead_code_elimination : Program.Typed.t -> Program.Typed.t - (** Eliminate semantically redundant code branches. - This includes removing redundant assignments (because they will be overwritten) - and removing redundant code in program branches that will never be reached. *) - - val partial_evaluation : Program.Typed.t -> Program.Typed.t - (** Partially evaluate expressions in the program. This includes simplification using - algebraic identities of logical and arithmetic operators as well as Stan math functions. *) - - val lazy_code_motion : - ?preserve_stability:bool -> Program.Typed.t -> Program.Typed.t - (** Perform partial redundancy elmination using the lazy code motion algorithm. This - subsumes common subexpression elimination and loop-invariant code motion. *) - - val optimize_ad_levels : Program.Typed.t -> Program.Typed.t - (** Assign the optimal ad-levels to local variables. That means, make sure that - variables only ever get treated as autodiff variables if they have some - dependency on a parameter *) - - val allow_uninitialized_decls : Program.Typed.t -> Program.Typed.t - (** Marks Decl types such that, if the first assignment after the decl - assigns to the full object, allow the object to be constructed but - not uninitialized. *) - - val optimization_suite : - ?settings:optimization_settings -> Program.Typed.t -> Program.Typed.t - (** Perform all optimizations in this module on the MIR in an appropriate order. *) -end - -module Make (StdLibrary : Frontend.Std_library_utils.Library) : Optimizer +(** Produce an optimizer for the MIR which is parameterized by the + given library of functions. These are used in the partial evaluator + and memory optimizations + *) +module Make (StdLibrary : Frontend.Std_library_utils.Library) : OPTIMIZER diff --git a/src/analysis_and_optimization/Optimize_intf.ml b/src/analysis_and_optimization/Optimize_intf.ml new file mode 100644 index 0000000000..d3b4fac4c0 --- /dev/null +++ b/src/analysis_and_optimization/Optimize_intf.ml @@ -0,0 +1,87 @@ +open Middle + +(** Interface for turning individual optimizations on/off. Useful for testing + and for top-level interface flags. *) +type optimization_settings = + { function_inlining: bool + ; static_loop_unrolling: bool + ; one_step_loop_unrolling: bool + ; list_collapsing: bool + ; block_fixing: bool + ; allow_uninitialized_decls: bool + ; constant_propagation: bool + ; expression_propagation: bool + ; copy_propagation: bool + ; dead_code_elimination: bool + ; partial_evaluation: bool + ; lazy_code_motion: bool + ; optimize_ad_levels: bool + ; preserve_stability: bool + ; optimize_soa: bool } + +module type OPTIMIZER = sig + val function_inlining : Program.Typed.t -> Program.Typed.t + (** Inline all functions except for ones with forward declarations + (e.g. recursive functions, mutually recursive functions, and + functions without a definition *) + + val static_loop_unrolling : Program.Typed.t -> Program.Typed.t + (** Unroll all for-loops with constant bounds, as long as they do + not contain break or continue statements in their body at the + top level *) + + val one_step_loop_unrolling : Program.Typed.t -> Program.Typed.t + (** Unroll all loops for one iteration, as long as they do + not contain break or continue statements in their body at the + top level *) + + val list_collapsing : Program.Typed.t -> Program.Typed.t + (** Remove redundant SList constructors from the Mir that might have + been introduced by other optimizations *) + + val block_fixing : Program.Typed.t -> Program.Typed.t + (** Make sure that SList constructors directly under if, for, while or fundef + constructors are replaced with Block constructors. + This should probably be run before we generate code. *) + + val constant_propagation : + ?preserve_stability:bool -> Program.Typed.t -> Program.Typed.t + (** Propagate constant values through variable assignments *) + + val expression_propagation : + ?preserve_stability:bool -> Program.Typed.t -> Program.Typed.t + (** Propagate arbitrary expressions through variable assignments. + This can be useful for opening up new possibilities for partial evaluation. + It should be followed by some CSE or lazy code motion pass, however. *) + + val copy_propagation : Program.Typed.t -> Program.Typed.t + (** Propagate copies of variables through assignments. *) + + val dead_code_elimination : Program.Typed.t -> Program.Typed.t + (** Eliminate semantically redundant code branches. + This includes removing redundant assignments (because they will be overwritten) + and removing redundant code in program branches that will never be reached. *) + + val partial_evaluation : Program.Typed.t -> Program.Typed.t + (** Partially evaluate expressions in the program. This includes simplification using + algebraic identities of logical and arithmetic operators as well as Stan math functions. *) + + val lazy_code_motion : + ?preserve_stability:bool -> Program.Typed.t -> Program.Typed.t + (** Perform partial redundancy elmination using the lazy code motion algorithm. This + subsumes common subexpression elimination and loop-invariant code motion. *) + + val optimize_ad_levels : Program.Typed.t -> Program.Typed.t + (** Assign the optimal ad-levels to local variables. That means, make sure that + variables only ever get treated as autodiff variables if they have some + dependency on a parameter *) + + val allow_uninitialized_decls : Program.Typed.t -> Program.Typed.t + (** Marks Decl types such that, if the first assignment after the decl + assigns to the full object, allow the object to be constructed but + not uninitialized. *) + + val optimization_suite : + ?settings:optimization_settings -> Program.Typed.t -> Program.Typed.t + (** Perform all optimizations in this module on the MIR in an appropriate order. *) +end diff --git a/src/analysis_and_optimization/Partial_evaluation.ml b/src/analysis_and_optimization/Partial_evaluation.ml index cd78a69d1f..78cc289528 100644 --- a/src/analysis_and_optimization/Partial_evaluation.ml +++ b/src/analysis_and_optimization/Partial_evaluation.ml @@ -87,12 +87,13 @@ let is_multi_index = function | Index.MultiIndex _ | Upfrom _ | Between _ | All -> true | Single _ -> false -module type PartialEvaluator = sig +module type PARTIAL_EVALUATOR = sig val try_eval_expr : Expr.Typed.t -> Expr.Typed.t val eval_prog : Program.Typed.t -> Program.Typed.t end -module Make (StdLibrary : Frontend.Std_library_utils.Library) = struct +module Make (StdLibrary : Frontend.Std_library_utils.Library) : + PARTIAL_EVALUATOR = struct module TC = Frontend.Typechecking.Make (StdLibrary) let rec eval_expr ?(preserve_stability = false) (e : Expr.Typed.t) = diff --git a/src/analysis_and_optimization/Partial_evaluation.mli b/src/analysis_and_optimization/Partial_evaluation.mli index 2491a8a8dc..d64d8bfe55 100644 --- a/src/analysis_and_optimization/Partial_evaluation.mli +++ b/src/analysis_and_optimization/Partial_evaluation.mli @@ -1,8 +1,9 @@ open Middle -module type PartialEvaluator = sig +module type PARTIAL_EVALUATOR = sig val try_eval_expr : Expr.Typed.t -> Expr.Typed.t val eval_prog : Program.Typed.t -> Program.Typed.t end -module Make (StdLibrary : Frontend.Std_library_utils.Library) : PartialEvaluator +module Make (StdLibrary : Frontend.Std_library_utils.Library) : + PARTIAL_EVALUATOR diff --git a/src/analysis_and_optimization/dune b/src/analysis_and_optimization/dune index dd915065a2..c7f4a7dd05 100644 --- a/src/analysis_and_optimization/dune +++ b/src/analysis_and_optimization/dune @@ -3,7 +3,5 @@ (public_name stanc.analysis) (libraries core_kernel str fmt common middle frontend stan_math_backend) (inline_tests) - ;; TODO: Not sure what's going on but it's throwing an error that this module has no implementation - (modules_without_implementation monotone_framework_sigs) (preprocess (pps ppx_jane ppx_deriving.map ppx_deriving.fold))) diff --git a/src/frontend/Ast_to_Mir.ml b/src/frontend/Ast_to_Mir.ml index 09e199c984..964dc6f1e7 100644 --- a/src/frontend/Ast_to_Mir.ml +++ b/src/frontend/Ast_to_Mir.ml @@ -2,7 +2,7 @@ open Core_kernel open Core_kernel.Poly open Middle -module type Ast_Mir_translator = sig +module type AST_MIR_TRANSLATOR = sig val gather_data : Ast.typed_program -> (Expr.Typed.t SizedType.t * Expr.Typed.t Transformation.t * string) list @@ -10,7 +10,8 @@ module type Ast_Mir_translator = sig val trans_prog : string -> Ast.typed_program -> Program.Typed.t end -module Make (StdLibrary : Std_library_utils.Library) = struct +module Make (StdLibrary : Std_library_utils.Library) : AST_MIR_TRANSLATOR = +struct let trans_fn_kind kind name = let fname = Utils.stdlib_distribution_name name in match kind with diff --git a/src/frontend/Ast_to_Mir.mli b/src/frontend/Ast_to_Mir.mli index 3d43f2579b..bc39d6f323 100644 --- a/src/frontend/Ast_to_Mir.mli +++ b/src/frontend/Ast_to_Mir.mli @@ -1,7 +1,7 @@ (** Translate from the AST to the MIR *) open Middle -module type Ast_Mir_translator = sig +module type AST_MIR_TRANSLATOR = sig val gather_data : Ast.typed_program -> (Expr.Typed.t SizedType.t * Expr.Typed.t Transformation.t * string) list @@ -9,4 +9,4 @@ module type Ast_Mir_translator = sig val trans_prog : string -> Ast.typed_program -> Program.Typed.t end -module Make (StdLibrary : Std_library_utils.Library) : Ast_Mir_translator +module Make (StdLibrary : Std_library_utils.Library) : AST_MIR_TRANSLATOR diff --git a/src/frontend/Canonicalize.ml b/src/frontend/Canonicalize.ml index c26b900015..59156cf4ee 100644 --- a/src/frontend/Canonicalize.ml +++ b/src/frontend/Canonicalize.ml @@ -13,7 +13,7 @@ let none = ; inline_includes= false ; braces= false } -module type Canonicalizer = sig +module type CANONICALIZER = sig val repair_syntax : untyped_program -> canonicalizer_settings -> untyped_program @@ -21,7 +21,8 @@ module type Canonicalizer = sig typed_program -> canonicalizer_settings -> typed_program end -module Make (Deprecation : Deprecation_analysis.Deprecation_analizer) = struct +module Make (Deprecation : Deprecation_analysis.DEPRECATION_ANALYZER) : + CANONICALIZER = struct let rec repair_syntax_stmt user_dists {stmt; smeta} = match stmt with | Tilde {arg; distribution= {name; id_loc}; args; truncation} -> diff --git a/src/frontend/Canonicalize.mli b/src/frontend/Canonicalize.mli index b2282c0e91..f9fb4933a8 100644 --- a/src/frontend/Canonicalize.mli +++ b/src/frontend/Canonicalize.mli @@ -13,7 +13,7 @@ type canonicalizer_settings = val all : canonicalizer_settings val none : canonicalizer_settings -module type Canonicalizer = sig +module type CANONICALIZER = sig val repair_syntax : untyped_program -> canonicalizer_settings -> untyped_program (** When deprecation canonicalization is enabled, this runs before typechecking @@ -25,5 +25,5 @@ module type Canonicalizer = sig and braces, etc. *) end -module Make (Deprecation : Deprecation_analysis.Deprecation_analizer) : - Canonicalizer +module Make (Deprecation : Deprecation_analysis.DEPRECATION_ANALYZER) : + CANONICALIZER diff --git a/src/frontend/Deprecation_analysis.ml b/src/frontend/Deprecation_analysis.ml index dbe0699f58..3cd8f23e27 100644 --- a/src/frontend/Deprecation_analysis.ml +++ b/src/frontend/Deprecation_analysis.ml @@ -2,7 +2,7 @@ open Core_kernel open Ast open Middle -module type Deprecation_analizer = sig +module type DEPRECATION_ANALYZER = sig val find_udf_log_suffix : typed_statement -> (string * Middle.UnsizedType.t) option @@ -20,7 +20,7 @@ module type Deprecation_analizer = sig val collect_warnings : typed_program -> Warnings.t list end -module Make (StdLibrary : Std_library_utils.Library) : Deprecation_analizer = +module Make (StdLibrary : Std_library_utils.Library) : DEPRECATION_ANALYZER = struct let stan_lib_deprecations = Map.merge_skewed StdLibrary.deprecated_distributions diff --git a/src/frontend/Deprecation_analysis.mli b/src/frontend/Deprecation_analysis.mli index 34e80aa35a..9c59fa5234 100644 --- a/src/frontend/Deprecation_analysis.mli +++ b/src/frontend/Deprecation_analysis.mli @@ -5,7 +5,7 @@ open Core_kernel open Ast -module type Deprecation_analizer = sig +module type DEPRECATION_ANALYZER = sig val find_udf_log_suffix : typed_statement -> (string * Middle.UnsizedType.t) option @@ -23,4 +23,4 @@ module type Deprecation_analizer = sig val collect_warnings : typed_program -> Warnings.t list end -module Make (StdLibrary : Std_library_utils.Library) : Deprecation_analizer +module Make (StdLibrary : Std_library_utils.Library) : DEPRECATION_ANALYZER diff --git a/src/frontend/Info.ml b/src/frontend/Info.ml index 19770dbe5b..66d0808b64 100644 --- a/src/frontend/Info.ml +++ b/src/frontend/Info.ml @@ -52,11 +52,11 @@ let includes_json () = ( List.rev !Preprocessor.included_files |> List.map ~f:(fun str -> `String str) ) ) ] -module type Information = sig +module type INFO = sig val info : Ast.typed_program -> string end -module Make (StdLibrary : Std_library_utils.Library) : Information = struct +module Make (StdLibrary : Std_library_utils.Library) : INFO = struct let rec get_function_calls_stmt ud_dists (funs, distrs) stmt = let acc = match stmt.stmt with diff --git a/src/frontend/Info.mli b/src/frontend/Info.mli index 7d111ef79e..1d07aab8fd 100644 --- a/src/frontend/Info.mli +++ b/src/frontend/Info.mli @@ -16,8 +16,8 @@ distributions used. *) -module type Information = sig +module type INFO = sig val info : Ast.typed_program -> string end -module Make (StdLibrary : Std_library_utils.Library) : Information +module Make (StdLibrary : Std_library_utils.Library) : INFO diff --git a/src/frontend/Typechecking.ml b/src/frontend/Typechecking.ml index 3920d898b1..10617e9530 100644 --- a/src/frontend/Typechecking.ml +++ b/src/frontend/Typechecking.ml @@ -17,6 +17,7 @@ open Core_kernel open Core_kernel.Poly open Middle open Ast +open Typechecking_intf module Env = Environment (* we only allow errors raised by this function *) @@ -83,25 +84,7 @@ let reserved_keywords = ; "get_lp"; "print"; "reject"; "typedef"; "struct"; "var"; "export"; "extern" ; "static"; "auto" ] -module type Typechecker = sig - val check_program_exn : untyped_program -> typed_program * Warnings.t list - - val check_program : - untyped_program - -> (typed_program * Warnings.t list, Semantic_error.t) result - - val operator_return_type : - Middle.Operator.t - -> (Middle.UnsizedType.autodifftype * Middle.UnsizedType.t) list - -> (Middle.UnsizedType.returntype * Promotion.t list) option - - val library_function_return_type : - string - -> (Middle.UnsizedType.autodifftype * Middle.UnsizedType.t) list - -> Middle.UnsizedType.returntype option -end - -module Make (StdLibrary : Std_library_utils.Library) : Typechecker = struct +module Make (StdLibrary : Std_library_utils.Library) : TYPECHECKER = struct let std_library_tenv : Env.t = Env.make_from_library StdLibrary.function_signatures diff --git a/src/frontend/Typechecking.mli b/src/frontend/Typechecking.mli index 97aacdcb15..f4288b7f9d 100644 --- a/src/frontend/Typechecking.mli +++ b/src/frontend/Typechecking.mli @@ -11,9 +11,14 @@ A type environment {!val:Environment.t} is used to hold variables and functions, including Stan math functions. This is a functional map, meaning it is handled immutably. + + This module is parameterized over a Standard Library of function signatures, See + [Std_library_utils.Library]. For the main compiler, this is + [Stan_math_backend.Stan_math_library] *) open Ast +open Typechecking_intf val model_name : string ref (** A reference to hold the model name. Relevant for checking variable @@ -31,31 +36,4 @@ val calculate_autodifftype : -> Middle.UnsizedType.t -> Middle.UnsizedType.autodifftype -module type Typechecker = sig - val check_program_exn : untyped_program -> typed_program * Warnings.t list - (** - Type check a full Stan program. - Can raise [Errors.SemanticError] - *) - - val check_program : - untyped_program - -> (typed_program * Warnings.t list, Semantic_error.t) result - (** - The safe version of [check_program_exn]. This catches - all [Errors.SemanticError] exceptions and converts them - into a [Result.t] - *) - - val operator_return_type : - Middle.Operator.t - -> (Middle.UnsizedType.autodifftype * Middle.UnsizedType.t) list - -> (Middle.UnsizedType.returntype * Promotion.t list) option - - val library_function_return_type : - string - -> (Middle.UnsizedType.autodifftype * Middle.UnsizedType.t) list - -> Middle.UnsizedType.returntype option -end - -module Make (StdLibrary : Std_library_utils.Library) : Typechecker +module Make (StdLibrary : Std_library_utils.Library) : TYPECHECKER diff --git a/src/frontend/Typechecking_intf.ml b/src/frontend/Typechecking_intf.ml new file mode 100644 index 0000000000..ffdd57ba84 --- /dev/null +++ b/src/frontend/Typechecking_intf.ml @@ -0,0 +1,29 @@ +open Ast + +(** Signature for a Stan typechecker *) +module type TYPECHECKER = sig + val check_program_exn : untyped_program -> typed_program * Warnings.t list + (** + Type check a full Stan program. + Can raise [Errors.SemanticError] + *) + + val check_program : + untyped_program + -> (typed_program * Warnings.t list, Semantic_error.t) result + (** + The safe version of [check_program_exn]. This catches + all [Errors.SemanticError] exceptions and converts them + into a [Result.t] + *) + + val operator_return_type : + Middle.Operator.t + -> (Middle.UnsizedType.autodifftype * Middle.UnsizedType.t) list + -> (Middle.UnsizedType.returntype * Promotion.t list) option + + val library_function_return_type : + string + -> (Middle.UnsizedType.autodifftype * Middle.UnsizedType.t) list + -> Middle.UnsizedType.returntype option +end From 0b6955193f05db31cfc4176633a3717e329ea3ef Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Thu, 28 Apr 2022 10:58:24 -0400 Subject: [PATCH 10/14] Clean up Stan_math_library, document --- docs/core_ideas.mld | 25 ++++++++++++++++++ docs/exposing_new_functions.mld | 7 +++--- src/stan_math_backend/Stan_math_library.mli | 28 ++++++++------------- 3 files changed, 39 insertions(+), 21 deletions(-) diff --git a/docs/core_ideas.mld b/docs/core_ideas.mld index 111cd474c0..b05bfec243 100644 --- a/docs/core_ideas.mld +++ b/docs/core_ideas.mld @@ -65,6 +65,31 @@ This takes some getting used to, and also can lead to some unhelpful type signat VSCode, because abbreviations are not always used in hover-over text. For example, [Expr.Typed.t], the MIR's typed expression type, actually has a signature of [Expr.Typed.Meta.t Expr.Fixed.t]. +{1 The [Library] interface and functors} + +Many modules of stanc are modeled as OCaml {{:https://ocaml.org/learn/tutorials/functors.html}functors}, +which take in another module as input and produce a module as output. For the most part, +these functors expect an instance of the [Library] interface defined in +[src/frontend/Std_library_utils.ml]. + +This module primarily contains signatures for the Stan standard library. For most users, +you can assume this will be filled in with [src/stan_math_backend/Stan_math_library.ml], +the object representing the {{:https://github.com/stan-dev/math}stan-dev/math} C++ library. + +Usages of these functors are rather simple, e.g. in the core stanc driver the line + +{[ +module Typechecker = Typechecking.Make (Stan_math_library) +]} + +defines a module [Typechecker] by supplying the functor [Typechecking.Make] with +the Stan C++ library module. After this, [Typechecker.check_program] will typecheck +an AST against those specific functions. + +As noted in the above tutorial link, the syntax of functors is often the hardest part +of using and understanding them. The functors which accept [Library] are all relatively +simple, and should serve as good examples to beginners with the concept. + {1 The [Fmt] library and pretty-printing} We extensively use the {{:https://erratique.ch/software/fmt}Fmt} library for our pretty-printing and code diff --git a/docs/exposing_new_functions.mld b/docs/exposing_new_functions.mld index 2d1ff34b32..087d5bffa6 100644 --- a/docs/exposing_new_functions.mld +++ b/docs/exposing_new_functions.mld @@ -7,7 +7,7 @@ For a function to be built into Stan, it has to be included in the Stan Math library and its signature has to be exposed to the compiler. -To do the latter, we have to add a corresponding line in [src/middle/Stan_math_signatures.ml]. +To do the latter, we have to add a corresponding line in [src/stan_math_backend/Stan_math_library.ml]. The compiler uses the signatures defined there to do type checking. @@ -130,8 +130,9 @@ For example, the following line defines the signature [add(real, matrix) => matr Functions such as the ODE integrators or [reduce_sum], which take in user-functions and a variable-length list of arguments, are {b NOT} added to this list. -These are instead treated as special cases in the [Typechecker] module in the frontend folder. It -it best to consult an existing example of how these are done before proceeding. +These are instead handled by special functions like [is_variadic_function_name]. They +must also be given custom typechecking rules in the private sub-module [Variadic_typechecking]. +It is best to consult an existing example of how these are done before proceeding. {1 Testing} diff --git a/src/stan_math_backend/Stan_math_library.mli b/src/stan_math_backend/Stan_math_library.mli index f22e5f4296..f999d9486c 100644 --- a/src/stan_math_backend/Stan_math_library.mli +++ b/src/stan_math_backend/Stan_math_library.mli @@ -3,39 +3,31 @@ functions for dealing with those signatures. *) -open Middle -open Frontend.Std_library_utils -include Library +include Frontend.Std_library_utils.Library + +(** These functions are used by the drivers to display + all available functions and distributions. They are + not part of the Library interface since different drivers + for different backends would likely want different behavior + here *) val pretty_print_all_math_sigs : unit Fmt.t val pretty_print_all_math_distributions : unit Fmt.t -(* TODO: We should think of a better encapsulization for these, - this doesn't scale well. -*) +(** These functions related to variadic functions + are specific to this backend and used + during code generation *) (* reduce_sum helpers *) val is_reduce_sum_fn : string -> bool -val reduce_sum_slice_types : UnsizedType.t list (* variadic ODE helpers *) val is_variadic_ode_fn : string -> bool val is_variadic_ode_nonadjoint_tol_fn : string -> bool val ode_tolerances_suffix : string val variadic_ode_adjoint_fn : string -val variadic_ode_mandatory_arg_types : fun_arg list -val variadic_ode_mandatory_fun_args : fun_arg list -val variadic_ode_tol_arg_types : fun_arg list -val variadic_ode_adjoint_ctl_tol_arg_types : fun_arg list -val variadic_ode_fun_return_type : UnsizedType.t -val variadic_ode_return_type : UnsizedType.t (* variadic DAE helpers *) val is_variadic_dae_fn : string -> bool val is_variadic_dae_tol_fn : string -> bool val dae_tolerances_suffix : string -val variadic_dae_mandatory_arg_types : fun_arg list -val variadic_dae_mandatory_fun_args : fun_arg list -val variadic_dae_tol_arg_types : fun_arg list -val variadic_dae_fun_return_type : UnsizedType.t -val variadic_dae_return_type : UnsizedType.t From d3db3e06e3be3ee6367d79f031a471ba09b73270 Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Thu, 28 Apr 2022 16:05:40 -0400 Subject: [PATCH 11/14] Comments --- src/frontend/Std_library_utils.ml | 27 +++++++++++++++++++++------ 1 file changed, 21 insertions(+), 6 deletions(-) diff --git a/src/frontend/Std_library_utils.ml b/src/frontend/Std_library_utils.ml index f7c5c766e8..176db30852 100644 --- a/src/frontend/Std_library_utils.ml +++ b/src/frontend/Std_library_utils.ml @@ -16,6 +16,15 @@ type deprecation_info = ; canonicalize_away: bool } [@@deriving sexp] +(* We could consider breaking up this module more, so we would have + more type-level guarantees about what each Functor is able to do + with the library. Most of them only need is_stdlib_function_name, + maybe get_signatures. + + The Stan_math_library could still satisfy all of them by + using [include] +*) + module type Library = sig (** This module is used as a parameter for many functors which rely on information about a backend-specific Stan library. *) @@ -26,9 +35,11 @@ module type Library = sig val distribution_families : string list val is_stdlib_function_name : string -> bool - (** Equivalent to [Hashtbl.mem stan_math_signatures s]*) + (** Equivalent to [Hashtbl.mem function_signatures s]*) val get_signatures : string -> signature list + (** Equivalent to [Hashtbl.find_multi function_signatures s]*) + val get_operator_signatures : Operator.t -> signature list val get_assignment_operator_signatures : Operator.t -> signature list val is_not_overloadable : string -> bool @@ -43,6 +54,9 @@ module type Library = sig -> Environment.t -> Ast.typed_expression list -> Ast.typed_expression + (** This function is responsible for typechecking varadic function + calls. It needs to live in the Library since this is usually + bespoke per-function. *) val operator_to_function_names : Operator.t -> string list val string_operator_to_function_name : string -> string @@ -50,11 +64,10 @@ module type Library = sig val deprecated_functions : deprecation_info String.Map.t end -module NullLibrary : Library = struct - (** A "standard library" for stan which contains no functions. +(** A "standard library" for Stan which contains no functions. Useful only for testing *) - +module NullLibrary : Library = struct let function_signatures : (string, signature list) Hashtbl.t = String.Table.create () @@ -77,10 +90,12 @@ module NullLibrary : Library = struct let deprecated_functions = String.Map.empty end -let pp_math_sig ppf (rt, args, mem_pattern) = +let pp_math_sig ppf ((rt, args, mem_pattern) : signature) = UnsizedType.pp ppf (UFun (args, rt, FnPlain, mem_pattern)) -let pp_math_sigs ppf sigs = (Fmt.list ~sep:Fmt.cut pp_math_sig) ppf sigs +let pp_math_sigs ppf (sigs : signature list) = + (Fmt.list ~sep:Fmt.cut pp_math_sig) ppf sigs + let pretty_print_math_sigs = Fmt.str "@[@,%a@]" pp_math_sigs let dist_name_suffix (module StdLibrary : Library) udf_names name = From dd4772fdcb84c76dfc99aa7bbb8053f9383dd187 Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Thu, 12 May 2022 12:35:20 -0400 Subject: [PATCH 12/14] Cleanups ported from virtual library attempt --- dune-project | 4 ++-- src/frontend/Deprecation_analysis.ml | 8 -------- src/frontend/Typechecking.ml | 12 ++++++------ src/stan_math_backend/Stan_math_library.ml | 3 ++- src/stanc/stanc.ml | 6 ++---- 5 files changed, 12 insertions(+), 21 deletions(-) diff --git a/dune-project b/dune-project index 1afc3dfd55..4b11f611ec 100644 --- a/dune-project +++ b/dune-project @@ -13,8 +13,8 @@ (ppx_deriving (= 5.2.1)) (fmt (= 0.8.8)) (yojson (= 1.7.0)) - (ocamlformat (and :with-test (= 0.19))) - (merlin (and :with-test (= 4.3.1))) + (ocamlformat (and :with-test (= 0.19.0))) + (merlin :with-test) (utop :with-test) (ocp-indent :with-test) (patdiff :with-test) diff --git a/src/frontend/Deprecation_analysis.ml b/src/frontend/Deprecation_analysis.ml index 3cd8f23e27..5df128c679 100644 --- a/src/frontend/Deprecation_analysis.ml +++ b/src/frontend/Deprecation_analysis.ml @@ -114,14 +114,6 @@ struct , "Use of the `abs` function with real-valued arguments is \ deprecated; use function `fabs` instead." ) ] ) e - | FunApp (StanLib FnPlain, {name= "if_else"; _}, l) -> - acc - @ [ ( emeta.loc - , "The function `if_else` is deprecated and will be removed in \ - Stan 2.32.0. Use the conditional operator (x ? y : z) instead; \ - this can be automatically changed using the canonicalize flag \ - for stanc" ) ] - @ List.concat_map l ~f:(fun e -> collect_deprecated_expr [] e) | FunApp ((StanLib _ | UserDefined _), {name; _}, l) -> let w = match Map.find stan_lib_deprecations name with diff --git a/src/frontend/Typechecking.ml b/src/frontend/Typechecking.ml index 10617e9530..1c9e36dfd5 100644 --- a/src/frontend/Typechecking.ml +++ b/src/frontend/Typechecking.ml @@ -88,6 +88,9 @@ module Make (StdLibrary : Std_library_utils.Library) : TYPECHECKER = struct let std_library_tenv : Env.t = Env.make_from_library StdLibrary.function_signatures + let matching_library_function = + SignatureMismatch.matching_function std_library_tenv + let verify_identifier id : unit = if id.name = !model_name then Semantic_error.ident_is_model_name id.id_loc id.name |> error @@ -198,9 +201,7 @@ module Make (StdLibrary : Std_library_utils.Library) : TYPECHECKER = struct match name with | x when StdLibrary.is_variadic_function_name x -> StdLibrary.variadic_function_returntype x - | _ -> - SignatureMismatch.matching_function std_library_tenv name arg_tys - |> match_to_rt_option + | _ -> matching_library_function name arg_tys |> match_to_rt_option let operator_return_type op arg_tys = match (op, arg_tys) with @@ -211,7 +212,7 @@ module Make (StdLibrary : Std_library_utils.Library) : TYPECHECKER = struct | _ -> StdLibrary.operator_to_function_names op |> List.filter_map ~f:(fun name -> - SignatureMismatch.matching_function std_library_tenv name arg_tys + matching_library_function name arg_tys |> function | SignatureMismatch.UniqueMatch (rt, _, p) -> Some (rt, p) | _ -> None ) @@ -220,8 +221,7 @@ module Make (StdLibrary : Std_library_utils.Library) : TYPECHECKER = struct let assignmentoperator_return_type assop arg_tys = ( match assop with | Operator.Divide -> - SignatureMismatch.matching_function std_library_tenv "divide" arg_tys - |> match_to_rt_option + matching_library_function "divide" arg_tys |> match_to_rt_option | Plus | Minus | Times | EltTimes | EltDivide -> operator_return_type assop arg_tys |> Option.map ~f:fst | _ -> None ) diff --git a/src/stan_math_backend/Stan_math_library.ml b/src/stan_math_backend/Stan_math_library.ml index e815df3585..6af0117f11 100644 --- a/src/stan_math_backend/Stan_math_library.ml +++ b/src/stan_math_backend/Stan_math_library.ml @@ -763,7 +763,8 @@ let deprecated_functions = ; ("cov_exp_quad", std "gp_exp_quad_cov") (* ode integrators *) ; ("integrate_ode_rk45", ode "ode_rk45"); ("integrate_ode", ode "ode_rk45") ; ("integrate_ode_bdf", ode "ode_bdf") - ; ("integrate_ode_adams", ode "ode_adams") ] + ; ("integrate_ode_adams", ode "ode_adams") + ; ("if_else", std "the conditional operator (x ? y : z)") ] (* -- Some helper definitions to populate stan_math_signatures -- *) let bare_types = diff --git a/src/stanc/stanc.ml b/src/stanc/stanc.ml index 7c42fa2ae6..feed86a848 100644 --- a/src/stanc/stanc.ml +++ b/src/stanc/stanc.ml @@ -199,16 +199,14 @@ let options = ) ; ( "--include-paths" , Arg.String - (fun str -> - Preprocessor.include_paths := String.split_on_chars ~on:[','] str ) + (fun str -> Preprocessor.include_paths := String.split ~on:',' str) , " Takes a comma-separated list of directories that may contain a file \ in an #include directive (default = \"\")" ) ; ( "--include_paths" , Arg.String (fun str -> Preprocessor.include_paths := - !Preprocessor.include_paths @ String.split_on_chars ~on:[','] str - ) + !Preprocessor.include_paths @ String.split ~on:',' str ) , " Deprecated. Same as --include-paths. Will be removed in Stan 2.32.0" ) ; ( "--use-opencl" From dad6788b089a5a24f8f1292e69331b7a4f0fb8b4 Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Thu, 12 May 2022 12:43:24 -0400 Subject: [PATCH 13/14] Dune promote --- test/integration/good/warning/pretty.expected | 104 +++++++++--------- 1 file changed, 52 insertions(+), 52 deletions(-) diff --git a/test/integration/good/warning/pretty.expected b/test/integration/good/warning/pretty.expected index 84a4e467df..82d695eb0c 100644 --- a/test/integration/good/warning/pretty.expected +++ b/test/integration/good/warning/pretty.expected @@ -1712,58 +1712,58 @@ model { y_p ~ normal(0, 1); } -Warning in 'if_else.stan', line 9, column 26: The function `if_else` is - deprecated and will be removed in Stan 2.32.0. Use the conditional - operator (x ? y : z) instead; this can be automatically changed using the - canonicalize flag for stanc -Warning in 'if_else.stan', line 10, column 26: The function `if_else` is - deprecated and will be removed in Stan 2.32.0. Use the conditional - operator (x ? y : z) instead; this can be automatically changed using the - canonicalize flag for stanc -Warning in 'if_else.stan', line 11, column 26: The function `if_else` is - deprecated and will be removed in Stan 2.32.0. Use the conditional - operator (x ? y : z) instead; this can be automatically changed using the - canonicalize flag for stanc -Warning in 'if_else.stan', line 12, column 26: The function `if_else` is - deprecated and will be removed in Stan 2.32.0. Use the conditional - operator (x ? y : z) instead; this can be automatically changed using the - canonicalize flag for stanc -Warning in 'if_else.stan', line 21, column 28: The function `if_else` is - deprecated and will be removed in Stan 2.32.0. Use the conditional - operator (x ? y : z) instead; this can be automatically changed using the - canonicalize flag for stanc -Warning in 'if_else.stan', line 22, column 28: The function `if_else` is - deprecated and will be removed in Stan 2.32.0. Use the conditional - operator (x ? y : z) instead; this can be automatically changed using the - canonicalize flag for stanc -Warning in 'if_else.stan', line 23, column 28: The function `if_else` is - deprecated and will be removed in Stan 2.32.0. Use the conditional - operator (x ? y : z) instead; this can be automatically changed using the - canonicalize flag for stanc -Warning in 'if_else.stan', line 24, column 28: The function `if_else` is - deprecated and will be removed in Stan 2.32.0. Use the conditional - operator (x ? y : z) instead; this can be automatically changed using the - canonicalize flag for stanc -Warning in 'if_else.stan', line 26, column 28: The function `if_else` is - deprecated and will be removed in Stan 2.32.0. Use the conditional - operator (x ? y : z) instead; this can be automatically changed using the - canonicalize flag for stanc -Warning in 'if_else.stan', line 27, column 28: The function `if_else` is - deprecated and will be removed in Stan 2.32.0. Use the conditional - operator (x ? y : z) instead; this can be automatically changed using the - canonicalize flag for stanc -Warning in 'if_else.stan', line 28, column 28: The function `if_else` is - deprecated and will be removed in Stan 2.32.0. Use the conditional - operator (x ? y : z) instead; this can be automatically changed using the - canonicalize flag for stanc -Warning in 'if_else.stan', line 29, column 28: The function `if_else` is - deprecated and will be removed in Stan 2.32.0. Use the conditional - operator (x ? y : z) instead; this can be automatically changed using the - canonicalize flag for stanc -Warning in 'if_else.stan', line 30, column 28: The function `if_else` is - deprecated and will be removed in Stan 2.32.0. Use the conditional - operator (x ? y : z) instead; this can be automatically changed using the - canonicalize flag for stanc +Warning in 'if_else.stan', line 9, column 26: if_else is deprecated and will + be removed in Stan 2.32.0. Use the conditional operator (x ? y : z) + instead. This can be automatically changed using the canonicalize flag + for stanc +Warning in 'if_else.stan', line 10, column 26: if_else is deprecated and will + be removed in Stan 2.32.0. Use the conditional operator (x ? y : z) + instead. This can be automatically changed using the canonicalize flag + for stanc +Warning in 'if_else.stan', line 11, column 26: if_else is deprecated and will + be removed in Stan 2.32.0. Use the conditional operator (x ? y : z) + instead. This can be automatically changed using the canonicalize flag + for stanc +Warning in 'if_else.stan', line 12, column 26: if_else is deprecated and will + be removed in Stan 2.32.0. Use the conditional operator (x ? y : z) + instead. This can be automatically changed using the canonicalize flag + for stanc +Warning in 'if_else.stan', line 21, column 28: if_else is deprecated and will + be removed in Stan 2.32.0. Use the conditional operator (x ? y : z) + instead. This can be automatically changed using the canonicalize flag + for stanc +Warning in 'if_else.stan', line 22, column 28: if_else is deprecated and will + be removed in Stan 2.32.0. Use the conditional operator (x ? y : z) + instead. This can be automatically changed using the canonicalize flag + for stanc +Warning in 'if_else.stan', line 23, column 28: if_else is deprecated and will + be removed in Stan 2.32.0. Use the conditional operator (x ? y : z) + instead. This can be automatically changed using the canonicalize flag + for stanc +Warning in 'if_else.stan', line 24, column 28: if_else is deprecated and will + be removed in Stan 2.32.0. Use the conditional operator (x ? y : z) + instead. This can be automatically changed using the canonicalize flag + for stanc +Warning in 'if_else.stan', line 26, column 28: if_else is deprecated and will + be removed in Stan 2.32.0. Use the conditional operator (x ? y : z) + instead. This can be automatically changed using the canonicalize flag + for stanc +Warning in 'if_else.stan', line 27, column 28: if_else is deprecated and will + be removed in Stan 2.32.0. Use the conditional operator (x ? y : z) + instead. This can be automatically changed using the canonicalize flag + for stanc +Warning in 'if_else.stan', line 28, column 28: if_else is deprecated and will + be removed in Stan 2.32.0. Use the conditional operator (x ? y : z) + instead. This can be automatically changed using the canonicalize flag + for stanc +Warning in 'if_else.stan', line 29, column 28: if_else is deprecated and will + be removed in Stan 2.32.0. Use the conditional operator (x ? y : z) + instead. This can be automatically changed using the canonicalize flag + for stanc +Warning in 'if_else.stan', line 30, column 28: if_else is deprecated and will + be removed in Stan 2.32.0. Use the conditional operator (x ? y : z) + instead. This can be automatically changed using the canonicalize flag + for stanc $ ../../../../../install/default/bin/stanc --auto-format increment_log_prob.stan transformed data { int n; From 9d684d9f25b11d660c4e70f1387d048511cc7242 Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Fri, 23 Sep 2022 14:59:25 -0400 Subject: [PATCH 14/14] Empty commit