Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor: encapsulate variadic functions in typechecking #1259

Merged
merged 3 commits into from
Oct 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions docs/exposing_new_functions.mld
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,15 @@ For example, the following line defines the signature [add(real, matrix) => matr
Functions such as the ODE integrators or [reduce_sum], which take in user-functions and a variable-length
list of arguments, are {b NOT} added to this list.

These are instead treated as special cases in the [Typechecker] module in the frontend folder. It
it best to consult an existing example of how these are done before proceeding.
"Nice" variadic functions are added to the hashtable [Stan_math_signatures.stan_math_variadic_signatures].
This is probably sufficient for most variadic functions, e.g. all the ODE solvers and DAE solvers are done
via this method.
[reduce_sum] is not "nice", since it is both variadic and {e polymorphic}, requiring certain arguments to have the same
(but {e not predetermined}) type. Therefore, [reduce_sum] is treated as special case in the [Typechecker]
module in the frontend folder.

Note that higher-order functions also usually require changes to the C++ code generation to work properly.
It is best to consult an existing example of how these are done before proceeding.

{1 Testing}

Expand Down
3 changes: 1 addition & 2 deletions src/analysis_and_optimization/Memory_patterns.ml
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,8 @@ let query_stan_math_mem_pattern_support (name : string)
(args : (UnsizedType.autodifftype * UnsizedType.t) list) =
let open Stan_math_signatures in
match name with
| x when is_stan_math_variadic_function_name x -> false
| x when is_reduce_sum_fn x -> false
| x when is_variadic_ode_fn x -> false
| x when is_variadic_dae_fn x -> false
| _ ->
let name =
string_operator_to_stan_math_fns (Utils.stdlib_distribution_name name)
Expand Down
25 changes: 4 additions & 21 deletions src/frontend/Semantic_error.ml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ module TypeError = struct
* UnsizedType.t list
* (UnsizedType.autodifftype * UnsizedType.t) list
* SignatureMismatch.function_mismatch
| IllTypedVariadicDE of
| IllTypedVariadic of
string
* UnsizedType.t list
* (UnsizedType.autodifftype * UnsizedType.t) list
Expand Down Expand Up @@ -131,7 +131,7 @@ module TypeError = struct
| IllTypedReduceSumGeneric (name, arg_tys, expected_args, error) ->
SignatureMismatch.pp_signature_mismatch ppf
(name, arg_tys, ([((ReturnType UReal, expected_args), error)], false))
| IllTypedVariadicDE (name, arg_tys, args, error, return_type) ->
| IllTypedVariadic (name, arg_tys, args, error, return_type) ->
SignatureMismatch.pp_signature_mismatch ppf
( name
, arg_tys
Expand Down Expand Up @@ -550,25 +550,8 @@ let illtyped_reduce_sum_generic loc name arg_tys expected_args error =
, TypeError.IllTypedReduceSumGeneric (name, arg_tys, expected_args, error)
)

let illtyped_variadic_ode loc name arg_tys args error =
TypeError
( loc
, TypeError.IllTypedVariadicDE
( name
, arg_tys
, args
, error
, Stan_math_signatures.variadic_ode_fun_return_type ) )

let illtyped_variadic_dae loc name arg_tys args error =
TypeError
( loc
, TypeError.IllTypedVariadicDE
( name
, arg_tys
, args
, error
, Stan_math_signatures.variadic_dae_fun_return_type ) )
let illtyped_variadic loc name arg_tys args fn_rt error =
TypeError (loc, TypeError.IllTypedVariadic (name, arg_tys, args, error, fn_rt))

let ambiguous_function_promotion loc name arg_tys signatures =
TypeError
Expand Down
11 changes: 2 additions & 9 deletions src/frontend/Semantic_error.mli
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,6 @@ val illtyped_reduce_sum_generic :
-> SignatureMismatch.function_mismatch
-> t

val illtyped_variadic_ode :
Location_span.t
-> string
-> UnsizedType.t list
-> (UnsizedType.autodifftype * UnsizedType.t) list
-> SignatureMismatch.function_mismatch
-> t

val ambiguous_function_promotion :
Location_span.t
-> string
Expand All @@ -74,11 +66,12 @@ val ambiguous_function_promotion :
list
-> t

val illtyped_variadic_dae :
val illtyped_variadic :
Location_span.t
-> string
-> UnsizedType.t list
-> (UnsizedType.autodifftype * UnsizedType.t) list
-> UnsizedType.t
-> SignatureMismatch.function_mismatch
-> t

Expand Down
2 changes: 1 addition & 1 deletion src/frontend/SignatureMismatch.ml
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ let matching_function env name args =
let matching_stanlib_function =
matching_function Environment.stan_math_environment

let check_variadic_args allow_lpdf mandatory_arg_tys mandatory_fun_arg_tys
let check_variadic_args ~allow_lpdf mandatory_arg_tys mandatory_fun_arg_tys
fun_return args =
let minimal_func_type =
UnsizedType.UFun (mandatory_fun_arg_tys, ReturnType fun_return, FnPlain, AoS)
Expand Down
2 changes: 1 addition & 1 deletion src/frontend/SignatureMismatch.mli
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ val matching_stanlib_function :
*)

val check_variadic_args :
bool
allow_lpdf:bool
-> (UnsizedType.autodifftype * UnsizedType.t) list
-> (UnsizedType.autodifftype * UnsizedType.t) list
-> UnsizedType.t
Expand Down
144 changes: 44 additions & 100 deletions src/frontend/Typechecker.ml
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,7 @@ let verify_name_fresh_udf loc tenv name =
(* variadic functions are currently not in math sigs and aren't
overloadable due to their separate typechecking *)
Stan_math_signatures.is_reduce_sum_fn name
|| Stan_math_signatures.is_variadic_ode_fn name
|| Stan_math_signatures.is_variadic_dae_fn name
|| Stan_math_signatures.is_stan_math_variadic_function_name name
then Semantic_error.ident_is_stanmath_name loc name |> error
else if Utils.is_unnormalized_distribution name then
Semantic_error.udf_is_unnormalized_fn loc name |> error
Expand Down Expand Up @@ -191,14 +190,13 @@ let match_to_rt_option = function
| _ -> None

let stan_math_return_type name arg_tys =
match name with
| x when Stan_math_signatures.is_reduce_sum_fn x ->
match
Hashtbl.find Stan_math_signatures.stan_math_variadic_signatures name
with
| Some {return_type; _} -> Some (UnsizedType.ReturnType return_type)
| None when Stan_math_signatures.is_reduce_sum_fn name ->
Some (UnsizedType.ReturnType UReal)
| x when Stan_math_signatures.is_variadic_ode_fn x ->
Some (UnsizedType.ReturnType (UArray UVector))
| x when Stan_math_signatures.is_variadic_dae_fn x ->
Some (UnsizedType.ReturnType (UArray UVector))
| _ ->
| None ->
SignatureMismatch.matching_stanlib_function name arg_tys
|> match_to_rt_option

Expand Down Expand Up @@ -571,30 +569,25 @@ let make_function_variable cf loc id = function

let rec check_fn ~is_cond_dist loc cf tenv id (tes : Ast.typed_expression list)
=
if Stan_math_signatures.is_reduce_sum_fn id.name then
if Stan_math_signatures.is_stan_math_variadic_function_name id.name then
check_variadic ~is_cond_dist loc cf tenv id tes
else if Stan_math_signatures.is_reduce_sum_fn id.name then
check_reduce_sum ~is_cond_dist loc cf tenv id tes
else if Stan_math_signatures.is_variadic_ode_fn id.name then
check_variadic_ode ~is_cond_dist loc cf tenv id tes
else if Stan_math_signatures.is_variadic_dae_fn id.name then
check_variadic_dae ~is_cond_dist loc cf tenv id tes
else check_normal_fn ~is_cond_dist loc tenv id tes

(** Reduce sum is a special case, even compared to the other
variadic functions, because it is polymorphic in the type of the
first argument. The first, fourth, and fifth arguments must agree,
which is too complicated to be captured declaratively. *)
and check_reduce_sum ~is_cond_dist loc cf tenv id tes =
let basic_mismatch () =
let mandatory_args =
UnsizedType.[(AutoDiffable, UArray UReal); (AutoDiffable, UInt)] in
let mandatory_fun_args =
UnsizedType.
[(AutoDiffable, UArray UReal); (DataOnly, UInt); (DataOnly, UInt)] in
SignatureMismatch.check_variadic_args true mandatory_args mandatory_fun_args
UReal (get_arg_types tes) in
let fail () =
let expected_args, err =
basic_mismatch () |> Result.error |> Option.value_exn in
Semantic_error.illtyped_reduce_sum_generic loc id.name
(List.map ~f:type_of_expr_typed tes)
expected_args err
|> error in
SignatureMismatch.check_variadic_args ~allow_lpdf:true mandatory_args
mandatory_fun_args UReal (get_arg_types tes) in
let matching remaining_es fn =
match fn with
| Env.
Expand All @@ -611,7 +604,7 @@ and check_reduce_sum ~is_cond_dist loc cf tenv id tes =
let arg_types =
(calculate_autodifftype cf Functions ftype, ftype)
:: get_arg_types remaining_es in
SignatureMismatch.check_variadic_args true mandatory_args
SignatureMismatch.check_variadic_args ~allow_lpdf:true mandatory_args
mandatory_fun_args UReal arg_types
| _ -> basic_mismatch () in
match tes with
Expand All @@ -633,81 +626,25 @@ and check_reduce_sum ~is_cond_dist loc cf tenv id tes =
(List.map ~f:type_of_expr_typed tes)
expected_args err
|> error )
| _ -> fail ()

and check_variadic_ode ~is_cond_dist loc cf tenv id tes =
let optional_tol_mandatory_args =
if Stan_math_signatures.variadic_ode_adjoint_fn = id.name then
Stan_math_signatures.variadic_ode_adjoint_ctl_tol_arg_types
else if Stan_math_signatures.is_variadic_ode_nonadjoint_tol_fn id.name then
Stan_math_signatures.variadic_ode_tol_arg_types
else [] in
let mandatory_arg_types =
Stan_math_signatures.variadic_ode_mandatory_arg_types
@ optional_tol_mandatory_args in
let fail () =
let expected_args, err =
SignatureMismatch.check_variadic_args false mandatory_arg_types
Stan_math_signatures.variadic_ode_mandatory_fun_args
Stan_math_signatures.variadic_ode_fun_return_type (get_arg_types tes)
|> Result.error |> Option.value_exn in
Semantic_error.illtyped_variadic_ode loc id.name
(List.map ~f:type_of_expr_typed tes)
expected_args err
|> error in
let matching remaining_es Env.{type_= ftype; _} =
let arg_types =
(calculate_autodifftype cf Functions ftype, ftype)
:: get_arg_types remaining_es in
SignatureMismatch.check_variadic_args false mandatory_arg_types
Stan_math_signatures.variadic_ode_mandatory_fun_args
Stan_math_signatures.variadic_ode_fun_return_type arg_types in
match tes with
| {expr= Variable fname; _} :: remaining_es -> (
match find_matching_first_order_fn tenv (matching remaining_es) fname with
| SignatureMismatch.UniqueMatch (ftype, promotions) ->
let tes = make_function_variable cf loc fname ftype :: remaining_es in
mk_typed_expression
~expr:
(mk_fun_app ~is_cond_dist
(StanLib FnPlain, id, Promotion.promote_list tes promotions) )
~ad_level:(expr_ad_lub tes)
~type_:Stan_math_signatures.variadic_ode_return_type ~loc
| AmbiguousMatch ps ->
Semantic_error.ambiguous_function_promotion loc fname.name None ps
|> error
| SignatureErrors (expected_args, err) ->
Semantic_error.illtyped_variadic_ode loc id.name
(List.map ~f:type_of_expr_typed tes)
expected_args err
|> error )
| _ -> fail ()

and check_variadic_dae ~is_cond_dist loc cf tenv id tes =
let optional_tol_mandatory_args =
if Stan_math_signatures.is_variadic_dae_tol_fn id.name then
Stan_math_signatures.variadic_dae_tol_arg_types
else [] in
let mandatory_arg_types =
Stan_math_signatures.variadic_dae_mandatory_arg_types
@ optional_tol_mandatory_args in
let fail () =
let expected_args, err =
SignatureMismatch.check_variadic_args false mandatory_arg_types
Stan_math_signatures.variadic_dae_mandatory_fun_args
Stan_math_signatures.variadic_dae_fun_return_type (get_arg_types tes)
|> Result.error |> Option.value_exn in
Semantic_error.illtyped_variadic_dae loc id.name
(List.map ~f:type_of_expr_typed tes)
expected_args err
|> error in
| _ ->
let expected_args, err =
basic_mismatch () |> Result.error |> Option.value_exn in
Semantic_error.illtyped_reduce_sum_generic loc id.name
(List.map ~f:type_of_expr_typed tes)
expected_args err
|> error

and check_variadic ~is_cond_dist loc cf tenv id tes =
let Stan_math_signatures.
{control_args; required_fn_args; required_fn_rt; return_type} =
Hashtbl.find_exn Stan_math_signatures.stan_math_variadic_signatures id.name
in
let matching remaining_es Env.{type_= ftype; _} =
let arg_types =
(calculate_autodifftype cf Functions ftype, ftype)
:: get_arg_types remaining_es in
SignatureMismatch.check_variadic_args false mandatory_arg_types
Stan_math_signatures.variadic_dae_mandatory_fun_args
Stan_math_signatures.variadic_dae_fun_return_type arg_types in
SignatureMismatch.check_variadic_args ~allow_lpdf:false control_args
required_fn_args required_fn_rt arg_types in
match tes with
| {expr= Variable fname; _} :: remaining_es -> (
match find_matching_first_order_fn tenv (matching remaining_es) fname with
Expand All @@ -717,17 +654,24 @@ and check_variadic_dae ~is_cond_dist loc cf tenv id tes =
~expr:
(mk_fun_app ~is_cond_dist
(StanLib FnPlain, id, Promotion.promote_list tes promotions) )
~ad_level:(expr_ad_lub tes)
~type_:Stan_math_signatures.variadic_dae_return_type ~loc
~ad_level:(expr_ad_lub tes) ~type_:return_type ~loc
| AmbiguousMatch ps ->
Semantic_error.ambiguous_function_promotion loc fname.name None ps
|> error
| SignatureErrors (expected_args, err) ->
Semantic_error.illtyped_variadic_dae loc id.name
Semantic_error.illtyped_variadic loc id.name
(List.map ~f:type_of_expr_typed tes)
expected_args err
expected_args required_fn_rt err
|> error )
| _ -> fail ()
| _ ->
let expected_args, err =
SignatureMismatch.check_variadic_args ~allow_lpdf:false control_args
required_fn_args required_fn_rt (get_arg_types tes)
|> Result.error |> Option.value_exn in
Semantic_error.illtyped_variadic loc id.name
(List.map ~f:type_of_expr_typed tes)
expected_args required_fn_rt err
|> error

and check_funapp loc cf tenv ~is_cond_dist id (es : Ast.typed_expression list) =
let name_check =
Expand Down
Loading