Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor: encapsulate variadic functions in typechecking #1259

Merged
merged 3 commits into from
Oct 14, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions docs/exposing_new_functions.mld
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,15 @@ For example, the following line defines the signature [add(real, matrix) => matr
Functions such as the ODE integrators or [reduce_sum], which take in user-functions and a variable-length
list of arguments, are {b NOT} added to this list.

These are instead treated as special cases in the [Typechecker] module in the frontend folder. It
it best to consult an existing example of how these are done before proceeding.
"Nice" variadic functions are added to the hashtable [Stan_math_signatures.stan_math_variadic_signatures].
This is probably sufficient for most variadic functions, e.g. all the ODE solvers and DAE solvers are done
via this method.
[reduce_sum] is not "nice", since it is both variadic and {e polymorphic}, requiring certain arguments to have the same
(but {e not predetermined}) type. Therefore, [reduce_sum] is treated as special case in the [Typechecker]
module in the frontend folder.

Note that higher-order functions also usually require changes to the C++ code generation to work properly.
It is best to consult an existing example of how these are done before proceeding.

{1 Testing}

Expand Down
3 changes: 1 addition & 2 deletions src/analysis_and_optimization/Memory_patterns.ml
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,8 @@ let query_stan_math_mem_pattern_support (name : string)
(args : (UnsizedType.autodifftype * UnsizedType.t) list) =
let open Stan_math_signatures in
match name with
| x when is_stan_math_variadic_function_name x -> false
| x when is_reduce_sum_fn x -> false
| x when is_variadic_ode_fn x -> false
| x when is_variadic_dae_fn x -> false
| _ ->
let name =
string_operator_to_stan_math_fns (Utils.stdlib_distribution_name name)
Expand Down
25 changes: 4 additions & 21 deletions src/frontend/Semantic_error.ml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ module TypeError = struct
* UnsizedType.t list
* (UnsizedType.autodifftype * UnsizedType.t) list
* SignatureMismatch.function_mismatch
| IllTypedVariadicDE of
| IllTypedVariadic of
string
* UnsizedType.t list
* (UnsizedType.autodifftype * UnsizedType.t) list
Expand Down Expand Up @@ -131,7 +131,7 @@ module TypeError = struct
| IllTypedReduceSumGeneric (name, arg_tys, expected_args, error) ->
SignatureMismatch.pp_signature_mismatch ppf
(name, arg_tys, ([((ReturnType UReal, expected_args), error)], false))
| IllTypedVariadicDE (name, arg_tys, args, error, return_type) ->
| IllTypedVariadic (name, arg_tys, args, error, return_type) ->
SignatureMismatch.pp_signature_mismatch ppf
( name
, arg_tys
Expand Down Expand Up @@ -550,25 +550,8 @@ let illtyped_reduce_sum_generic loc name arg_tys expected_args error =
, TypeError.IllTypedReduceSumGeneric (name, arg_tys, expected_args, error)
)

let illtyped_variadic_ode loc name arg_tys args error =
TypeError
( loc
, TypeError.IllTypedVariadicDE
( name
, arg_tys
, args
, error
, Stan_math_signatures.variadic_ode_fun_return_type ) )

let illtyped_variadic_dae loc name arg_tys args error =
TypeError
( loc
, TypeError.IllTypedVariadicDE
( name
, arg_tys
, args
, error
, Stan_math_signatures.variadic_dae_fun_return_type ) )
let illtyped_variadic loc name arg_tys args fn_rt error =
TypeError (loc, TypeError.IllTypedVariadic (name, arg_tys, args, error, fn_rt))

let ambiguous_function_promotion loc name arg_tys signatures =
TypeError
Expand Down
11 changes: 2 additions & 9 deletions src/frontend/Semantic_error.mli
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,6 @@ val illtyped_reduce_sum_generic :
-> SignatureMismatch.function_mismatch
-> t

val illtyped_variadic_ode :
Location_span.t
-> string
-> UnsizedType.t list
-> (UnsizedType.autodifftype * UnsizedType.t) list
-> SignatureMismatch.function_mismatch
-> t

val ambiguous_function_promotion :
Location_span.t
-> string
Expand All @@ -74,11 +66,12 @@ val ambiguous_function_promotion :
list
-> t

val illtyped_variadic_dae :
val illtyped_variadic :
Location_span.t
-> string
-> UnsizedType.t list
-> (UnsizedType.autodifftype * UnsizedType.t) list
-> UnsizedType.t
-> SignatureMismatch.function_mismatch
-> t

Expand Down
111 changes: 32 additions & 79 deletions src/frontend/Typechecker.ml
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,7 @@ let verify_name_fresh_udf loc tenv name =
(* variadic functions are currently not in math sigs and aren't
overloadable due to their separate typechecking *)
Stan_math_signatures.is_reduce_sum_fn name
|| Stan_math_signatures.is_variadic_ode_fn name
|| Stan_math_signatures.is_variadic_dae_fn name
|| Stan_math_signatures.is_stan_math_variadic_function_name name
then Semantic_error.ident_is_stanmath_name loc name |> error
else if Utils.is_unnormalized_distribution name then
Semantic_error.udf_is_unnormalized_fn loc name |> error
Expand Down Expand Up @@ -194,10 +193,12 @@ let stan_math_return_type name arg_tys =
match name with
| x when Stan_math_signatures.is_reduce_sum_fn x ->
Some (UnsizedType.ReturnType UReal)
| x when Stan_math_signatures.is_variadic_ode_fn x ->
Some (UnsizedType.ReturnType (UArray UVector))
| x when Stan_math_signatures.is_variadic_dae_fn x ->
Some (UnsizedType.ReturnType (UArray UVector))
| x when Stan_math_signatures.is_stan_math_variadic_function_name x ->
Some
(UnsizedType.ReturnType
(Hashtbl.find_exn Stan_math_signatures.stan_math_variadic_signatures
x )
.return_type )
| _ ->
SignatureMismatch.matching_stanlib_function name arg_tys
|> match_to_rt_option
WardBrian marked this conversation as resolved.
Show resolved Hide resolved
Expand Down Expand Up @@ -571,14 +572,16 @@ let make_function_variable cf loc id = function

let rec check_fn ~is_cond_dist loc cf tenv id (tes : Ast.typed_expression list)
=
if Stan_math_signatures.is_reduce_sum_fn id.name then
if Stan_math_signatures.is_stan_math_variadic_function_name id.name then
check_variadic ~is_cond_dist loc cf tenv id tes
else if Stan_math_signatures.is_reduce_sum_fn id.name then
check_reduce_sum ~is_cond_dist loc cf tenv id tes
else if Stan_math_signatures.is_variadic_ode_fn id.name then
check_variadic_ode ~is_cond_dist loc cf tenv id tes
else if Stan_math_signatures.is_variadic_dae_fn id.name then
check_variadic_dae ~is_cond_dist loc cf tenv id tes
else check_normal_fn ~is_cond_dist loc tenv id tes

(** Reduce sum is a special case, even compared to the other
variadic functions, because it is polymorphic in the type of the
first argument. The first, fourth, and fifth arguments must agree,
which is too complicated to be captured declaratively. *)
and check_reduce_sum ~is_cond_dist loc cf tenv id tes =
let basic_mismatch () =
let mandatory_args =
Expand Down Expand Up @@ -635,79 +638,30 @@ and check_reduce_sum ~is_cond_dist loc cf tenv id tes =
|> error )
| _ -> fail ()

and check_variadic_ode ~is_cond_dist loc cf tenv id tes =
let optional_tol_mandatory_args =
if Stan_math_signatures.variadic_ode_adjoint_fn = id.name then
Stan_math_signatures.variadic_ode_adjoint_ctl_tol_arg_types
else if Stan_math_signatures.is_variadic_ode_nonadjoint_tol_fn id.name then
Stan_math_signatures.variadic_ode_tol_arg_types
else [] in
let mandatory_arg_types =
Stan_math_signatures.variadic_ode_mandatory_arg_types
@ optional_tol_mandatory_args in
let fail () =
let expected_args, err =
SignatureMismatch.check_variadic_args false mandatory_arg_types
Stan_math_signatures.variadic_ode_mandatory_fun_args
Stan_math_signatures.variadic_ode_fun_return_type (get_arg_types tes)
|> Result.error |> Option.value_exn in
Semantic_error.illtyped_variadic_ode loc id.name
(List.map ~f:type_of_expr_typed tes)
expected_args err
|> error in
let matching remaining_es Env.{type_= ftype; _} =
let arg_types =
(calculate_autodifftype cf Functions ftype, ftype)
:: get_arg_types remaining_es in
SignatureMismatch.check_variadic_args false mandatory_arg_types
Stan_math_signatures.variadic_ode_mandatory_fun_args
Stan_math_signatures.variadic_ode_fun_return_type arg_types in
match tes with
| {expr= Variable fname; _} :: remaining_es -> (
match find_matching_first_order_fn tenv (matching remaining_es) fname with
| SignatureMismatch.UniqueMatch (ftype, promotions) ->
let tes = make_function_variable cf loc fname ftype :: remaining_es in
mk_typed_expression
~expr:
(mk_fun_app ~is_cond_dist
(StanLib FnPlain, id, Promotion.promote_list tes promotions) )
~ad_level:(expr_ad_lub tes)
~type_:Stan_math_signatures.variadic_ode_return_type ~loc
| AmbiguousMatch ps ->
Semantic_error.ambiguous_function_promotion loc fname.name None ps
|> error
| SignatureErrors (expected_args, err) ->
Semantic_error.illtyped_variadic_ode loc id.name
(List.map ~f:type_of_expr_typed tes)
expected_args err
|> error )
| _ -> fail ()

and check_variadic_dae ~is_cond_dist loc cf tenv id tes =
let optional_tol_mandatory_args =
if Stan_math_signatures.is_variadic_dae_tol_fn id.name then
Stan_math_signatures.variadic_dae_tol_arg_types
else [] in
let mandatory_arg_types =
Stan_math_signatures.variadic_dae_mandatory_arg_types
@ optional_tol_mandatory_args in
and check_variadic ~is_cond_dist loc cf tenv id tes =
let Stan_math_signatures.
{ control_args
; required_fn_args
; required_fn_rt
; allow_fn_lpdf
; return_type } =
Hashtbl.find_exn Stan_math_signatures.stan_math_variadic_signatures id.name
in
let fail () =
WardBrian marked this conversation as resolved.
Show resolved Hide resolved
let expected_args, err =
SignatureMismatch.check_variadic_args false mandatory_arg_types
Stan_math_signatures.variadic_dae_mandatory_fun_args
Stan_math_signatures.variadic_dae_fun_return_type (get_arg_types tes)
SignatureMismatch.check_variadic_args allow_fn_lpdf control_args
required_fn_args required_fn_rt (get_arg_types tes)
|> Result.error |> Option.value_exn in
Semantic_error.illtyped_variadic_dae loc id.name
Semantic_error.illtyped_variadic loc id.name
(List.map ~f:type_of_expr_typed tes)
expected_args err
expected_args required_fn_rt err
|> error in
let matching remaining_es Env.{type_= ftype; _} =
let arg_types =
(calculate_autodifftype cf Functions ftype, ftype)
:: get_arg_types remaining_es in
SignatureMismatch.check_variadic_args false mandatory_arg_types
Stan_math_signatures.variadic_dae_mandatory_fun_args
Stan_math_signatures.variadic_dae_fun_return_type arg_types in
SignatureMismatch.check_variadic_args allow_fn_lpdf control_args
required_fn_args required_fn_rt arg_types in
match tes with
| {expr= Variable fname; _} :: remaining_es -> (
match find_matching_first_order_fn tenv (matching remaining_es) fname with
Expand All @@ -717,15 +671,14 @@ and check_variadic_dae ~is_cond_dist loc cf tenv id tes =
~expr:
(mk_fun_app ~is_cond_dist
(StanLib FnPlain, id, Promotion.promote_list tes promotions) )
~ad_level:(expr_ad_lub tes)
~type_:Stan_math_signatures.variadic_dae_return_type ~loc
~ad_level:(expr_ad_lub tes) ~type_:return_type ~loc
| AmbiguousMatch ps ->
Semantic_error.ambiguous_function_promotion loc fname.name None ps
|> error
| SignatureErrors (expected_args, err) ->
Semantic_error.illtyped_variadic_dae loc id.name
Semantic_error.illtyped_variadic loc id.name
(List.map ~f:type_of_expr_typed tes)
expected_args err
expected_args required_fn_rt err
|> error )
| _ -> fail ()

Expand Down
Loading