Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor "standard library" as a virtual module #1184

Closed
wants to merge 39 commits into from
Closed
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
5f9a000
Start refactor
WardBrian Apr 27, 2022
654810c
Virtual libraries working
WardBrian May 3, 2022
8ffd93e
Cleanup
WardBrian May 3, 2022
95412b4
Update docs
WardBrian May 3, 2022
1a03bf7
Cleanup
WardBrian May 4, 2022
6af15ae
Add docstrings to Library interface
WardBrian May 4, 2022
8c7690a
Merge branch 'master' into virtual-library-experiment
WardBrian May 4, 2022
140f969
Move if_else deprecation
WardBrian May 5, 2022
183274d
Merge branch 'master' into virtual-library-experiment
WardBrian May 5, 2022
fe7a140
Remove unused dune stanza
WardBrian May 5, 2022
0c8f5a6
Merge branch 'master' into virtual-library-experiment
WardBrian May 6, 2022
8969a82
Merge branch 'master' into virtual-library-experiment
WardBrian May 9, 2022
ce4184a
Merge branch 'master' into virtual-library-experiment
WardBrian May 10, 2022
4c143c2
Simplify include path splitting
WardBrian May 10, 2022
851b94b
Merge branch 'master' into virtual-library-experiment
WardBrian May 12, 2022
e6d7bf6
Update dune project to prevent opam file from changing
WardBrian May 12, 2022
9f10094
Merge branch 'master' into virtual-library-experiment
WardBrian May 17, 2022
a9ae3b1
Merge branch 'master' into virtual-library-experiment
WardBrian May 24, 2022
5edcca9
Merge branch 'master' into virtual-library-experiment
WardBrian May 25, 2022
69bfa1a
Merge branch 'master' into virtual-library-experiment
WardBrian May 26, 2022
e87cd7a
Merge branch 'master' into virtual-library-experiment
WardBrian Jun 1, 2022
66e6565
Merge branch 'master' into virtual-library-experiment
WardBrian Jun 3, 2022
872b466
Merge branch 'master' into virtual-library-experiment
WardBrian Jun 13, 2022
598f608
Merge branch 'master' into virtual-library-experiment
WardBrian Jun 15, 2022
6846e38
Merge branch 'master' into virtual-library-experiment
WardBrian Jul 11, 2022
7c4c39a
Merge branch 'master' into virtual-library-experiment
WardBrian Jul 26, 2022
7f3bac2
Merge branch 'master' into virtual-library-experiment
WardBrian Sep 22, 2022
ec5ab3b
Empty commit
WardBrian Sep 22, 2022
42d0351
Merge branch 'master' into virtual-library-experiment
WardBrian Oct 12, 2022
6dcab97
Merge branch 'master'
WardBrian Oct 14, 2022
56efb2d
Empty commit
WardBrian Oct 14, 2022
37aa03e
Force rename try?
WardBrian Oct 14, 2022
72b64de
Try to force different stash
WardBrian Oct 14, 2022
438320e
More debug
WardBrian Oct 14, 2022
58b3e03
Delete files after ocaml tests
WardBrian Oct 14, 2022
0f44e26
Restore jenkinsfile
WardBrian Oct 14, 2022
20dcf0a
Ensure ocaml tests proper cleanup
serban-nicusor-toptal Oct 14, 2022
f30e7d9
Change order of --root in dune calls
serban-nicusor-toptal Oct 14, 2022
6142213
Merge branch 'master' into experiment/virtual-library
WardBrian Oct 24, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions docs/core_ideas.mld
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,15 @@ This takes some getting used to, and also can lead to some unhelpful type signat
VSCode, because abbreviations are not always used in hover-over text. For example, [Expr.Typed.t], the MIR's typed
expression type, actually has a signature of [Expr.Typed.Meta.t Expr.Fixed.t].

{1 The [Library] virtual module}

The [Frontend] library is a {{:https://dune.readthedocs.io/en/latest/variants.html}virtual library}
where the module [Library] is unimplemented. This allows the rest of the library to operate without
making backend-specific assumptions about any one library.

This must be supplied when the executable is built in the [dune] file.
For Stanc, we supply the [stan_math_library] instantiation defined in [src/stan_math_backend/stan_math_library].

{1 The [Fmt] library and pretty-printing}

We extensively use the {{:https://erratique.ch/software/fmt}Fmt} library for our pretty-printing and code
Expand Down
7 changes: 4 additions & 3 deletions docs/exposing_new_functions.mld
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
For a function to be built into Stan, it has to be included in the Stan Math library and its signature has to be
exposed to the compiler.

To do the latter, we have to add a corresponding line in [src/middle/Stan_math_signatures.ml].
To do the latter, we have to add a corresponding line in [src/stan_math_backend/Stan_math_signatures.ml].
The compiler uses the signatures defined there to do type checking.


Expand Down Expand Up @@ -130,8 +130,9 @@ For example, the following line defines the signature [add(real, matrix) => matr
Functions such as the ODE integrators or [reduce_sum], which take in user-functions and a variable-length
list of arguments, are {b NOT} added to this list.

These are instead treated as special cases in the [Typechecker] module in the frontend folder. It
it best to consult an existing example of how these are done before proceeding.
These are instead handled by special functions like [is_variadic_function_name]. They
must also be given custom typechecking rules in the private sub-module [Variadic_typechecking].
It is best to consult an existing example of how these are done before proceeding.

{1 Testing}

Expand Down
2 changes: 1 addition & 1 deletion src/analysis_and_optimization/Debug_data_generation.ml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ let rec vect_to_mat l m =

let eval_expr m e =
let e = Mir_utils.subst_expr m e in
let e = Partial_evaluator.eval_expr e in
let e = Partial_evaluator.try_eval_expr e in
let rec strip_promotions (e : Middle.Expr.Typed.t) =
match e.pattern with Promotion (e, _, _) -> strip_promotions e | _ -> e
in
Expand Down
7 changes: 3 additions & 4 deletions src/analysis_and_optimization/Dependence_analysis.ml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ open Middle
open Dataflow_types
open Mir_utils
open Dataflow_utils
open Monotone_framework_sigs
open Monotone_framework_intf
open Monotone_framework

(***********************************)
Expand Down Expand Up @@ -119,9 +119,8 @@ let mir_reaching_definitions (mir : Program.Typed.t) (stmt : Stmt.Located.t) :
Map.Poly.map rd_map ~f:(fun {entry; exit} ->
{entry= to_rd_set entry; exit= to_rd_set exit} )

let all_labels
(module Flowgraph : Monotone_framework_sigs.FLOWGRAPH with type labels = int)
: int Set.Poly.t =
let all_labels (module Flowgraph : FLOWGRAPH with type labels = int) :
int Set.Poly.t =
let step set =
Set.Poly.union set
(union_map set ~f:(fun l -> Map.Poly.find_exn Flowgraph.successors l))
Expand Down
11 changes: 4 additions & 7 deletions src/analysis_and_optimization/Mem_pattern.ml
Original file line number Diff line number Diff line change
Expand Up @@ -111,16 +111,13 @@ let rec is_uni_eigen_loop_indexing in_loop (ut : UnsizedType.t)

let query_stan_math_mem_pattern_support (name : string)
(args : (UnsizedType.autodifftype * UnsizedType.t) list) =
let open Stan_math_signatures in
match name with
| x when is_reduce_sum_fn x -> false
| x when is_variadic_ode_fn x -> false
| x when is_variadic_dae_fn x -> false
| x when Frontend.Library.is_variadic_function_name x -> false
| _ ->
let name =
string_operator_to_stan_math_fns (Utils.stdlib_distribution_name name)
in
let namematches = Hashtbl.find_multi stan_math_signatures name in
Frontend.Library.string_operator_to_function_name
(Utils.stdlib_distribution_name name) in
let namematches = Frontend.Library.get_signatures name in
let filteredmatches =
List.filter
~f:(fun x ->
Expand Down
16 changes: 8 additions & 8 deletions src/analysis_and_optimization/Mir_utils.ml
Original file line number Diff line number Diff line change
Expand Up @@ -464,16 +464,16 @@ let cleanup_empty_stmts stmts =

(**
* Convert a Type.Unsized to a Type.Sized.
* This function is useful in the inlining scheme as
* the Mem_patterns optimization cannot work with decl types
* for unsized types. (Steve: tmk the inline optimization is the only place
* we create Decl's with unsized types.)
* This function is useful in the inlining scheme as
* the Mem_patterns optimization cannot work with decl types
* for unsized types. (Steve: tmk the inline optimization is the only place
* we create Decl's with unsized types.)
*
* Note that there is no true mapping from Sized types to Unsized types.
* Any sizes are set to 0 and it is assumed that the intent
* of Types.Unsized with inner UFun types is to size the return
* type of the UFun. Any Decl that uses this type should
* have initialize set to false.
* Any sizes are set to 0 and it is assumed that the intent
* of Types.Unsized with inner UFun types is to size the return
* type of the UFun. Any Decl that uses this type should
* have initialize set to false.
*)
let unsafe_unsized_to_sized_type (rt : Expr.Typed.t Type.t) =
match rt with
Expand Down
20 changes: 8 additions & 12 deletions src/analysis_and_optimization/Monotone_framework.ml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

open Core_kernel
open Core_kernel.Poly
open Monotone_framework_sigs
open Monotone_framework_intf
open Mir_utils
open Middle

Expand Down Expand Up @@ -864,7 +864,7 @@ let rec declared_variables_stmt
(List.map ~f:(fun x -> declared_variables_stmt x.pattern) l)

let propagation_mfp (prog : Program.Typed.t)
(module Flowgraph : Monotone_framework_sigs.FLOWGRAPH with type labels = int)
(module Flowgraph : FLOWGRAPH with type labels = int)
(flowgraph_to_mir : (int, Stmt.Located.Non_recursive.t) Map.Poly.t)
(propagation_transfer :
(int, Stmt.Located.Non_recursive.t) Map.Poly.t
Expand Down Expand Up @@ -897,7 +897,7 @@ let propagation_mfp (prog : Program.Typed.t)
Mf.mfp ()

let reaching_definitions_mfp (mir : Program.Typed.t)
(module Flowgraph : Monotone_framework_sigs.FLOWGRAPH with type labels = int)
(module Flowgraph : FLOWGRAPH with type labels = int)
(flowgraph_to_mir : (int, Stmt.Located.Non_recursive.t) Map.Poly.t) =
let variables =
( module struct
Expand All @@ -918,7 +918,7 @@ let reaching_definitions_mfp (mir : Program.Typed.t)
Mf.mfp ()

let initialized_vars_mfp (total : string Set.Poly.t)
(module Flowgraph : Monotone_framework_sigs.FLOWGRAPH with type labels = int)
(module Flowgraph : FLOWGRAPH with type labels = int)
(flowgraph_to_mir : (int, Stmt.Located.Non_recursive.t) Map.Poly.t) =
let (module Lattice) =
dual_powerset_lattice_empty_initial
Expand Down Expand Up @@ -949,8 +949,7 @@ let globals (prog : Program.Typed.t) =
(** Monotone framework instance for live_variables analysis. Expects reverse
flowgraph. *)
let live_variables_mfp (prog : Program.Typed.t)
(module Rev_Flowgraph : Monotone_framework_sigs.FLOWGRAPH
with type labels = int )
(module Rev_Flowgraph : FLOWGRAPH with type labels = int)
(flowgraph_to_mir : (int, Stmt.Located.Non_recursive.t) Map.Poly.t) =
let never_kill = globals prog in
let variables =
Expand All @@ -970,10 +969,8 @@ let live_variables_mfp (prog : Program.Typed.t)

(** Instantiate all four instances of the monotone framework for lazy
code motion, reusing code between them *)
let lazy_expressions_mfp
(module Flowgraph : Monotone_framework_sigs.FLOWGRAPH with type labels = int)
(module Rev_Flowgraph : Monotone_framework_sigs.FLOWGRAPH
with type labels = int )
let lazy_expressions_mfp (module Flowgraph : FLOWGRAPH with type labels = int)
(module Rev_Flowgraph : FLOWGRAPH with type labels = int)
(flowgraph_to_mir : (int, Stmt.Located.Non_recursive.t) Map.Poly.t) =
let all_expressions =
used_subexpressions_stmt
Expand Down Expand Up @@ -1031,8 +1028,7 @@ let lazy_expressions_mfp
*
*)
let minimal_variables_mfp
(module Circular_Fwd_Flowgraph : Monotone_framework_sigs.FLOWGRAPH
with type labels = int )
(module Circular_Fwd_Flowgraph : FLOWGRAPH with type labels = int)
(flowgraph_to_mir : (int, Stmt.Located.Non_recursive.t) Map.Poly.t)
(initial_variables : string Set.Poly.t)
(gen_variable :
Expand Down
11 changes: 2 additions & 9 deletions src/analysis_and_optimization/Optimize.ml
Original file line number Diff line number Diff line change
Expand Up @@ -629,7 +629,7 @@ let list_collapsing (mir : Program.Typed.t) =
let propagation
(propagation_transfer :
(int, Stmt.Located.Non_recursive.t) Map.Poly.t
-> (module Monotone_framework_sigs.TRANSFER_FUNCTION
-> (module Monotone_framework_intf.TRANSFER_FUNCTION
with type labels = int
and type properties = (string, Middle.Expr.Typed.t) Map.Poly.t
option ) ) (mir : Program.Typed.t) =
Expand Down Expand Up @@ -712,7 +712,7 @@ let dead_code_elimination (mir : Program.Typed.t) =
let dead_code_elim_stmt_base i stmt =
(* NOTE: entry in the reverse flowgraph, so exit in the forward flowgraph *)
let live_variables_s =
(Map.find_exn live_variables i).Monotone_framework_sigs.entry in
(Map.find_exn live_variables i).Monotone_framework_intf.entry in
match stmt with
| Stmt.Fixed.Pattern.Assignment ((x, _, []), rhs) ->
if Set.Poly.mem live_variables_s x || cannot_remove_expr rhs then stmt
Expand Down Expand Up @@ -1184,11 +1184,6 @@ let optimize_soa (mir : Program.Typed.t) =
List.fold ~init:Set.Poly.empty
~f:(Mem_pattern.query_initial_demotable_stmt false)
mir.log_prob in
(*
let print_set s =
Set.Poly.iter ~f:print_endline s in
let () = print_set initial_variables in
*)
let mod_exprs aos_exits mod_expr =
Mir_utils.map_rec_expr (Mem_pattern.modify_expr_pattern aos_exits) mod_expr
in
Expand All @@ -1207,8 +1202,6 @@ let optimize_soa (mir : Program.Typed.t) =
in
{mir with log_prob= transform' mir.log_prob}

(* Apparently you need to completely copy/paste type definitions between
ml and mli files?*)
type optimization_settings =
{ function_inlining: bool
; static_loop_unrolling: bool
Expand Down
3 changes: 2 additions & 1 deletion src/analysis_and_optimization/Optimize.mli
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
(* Code for optimization passes on the MIR *)

open Middle

val function_inlining : Program.Typed.t -> Program.Typed.t
Expand Down Expand Up @@ -59,7 +60,7 @@ val optimize_ad_levels : Program.Typed.t -> Program.Typed.t

val allow_uninitialized_decls : Program.Typed.t -> Program.Typed.t
(** Marks Decl types such that, if the first assignment after the decl
assigns to the full object, allow the object to be constructed but
assigns to the full object, allow the object to be constructed but
not uninitialized. *)

(** Interface for turning individual optimizations on/off. Useful for testing
Expand Down
4 changes: 2 additions & 2 deletions src/analysis_and_optimization/Partial_evaluator.ml
Original file line number Diff line number Diff line change
Expand Up @@ -106,11 +106,11 @@ let rec eval_expr ?(preserve_stability = false) (e : Expr.Typed.t) =
Operator.of_string_opt name
|> Option.value_map
~f:(fun op ->
Frontend.Typechecker.operator_stan_math_return_type op
Frontend.Typechecker.operator_return_type op
argument_types
|> Option.map ~f:fst )
~default:
(Frontend.Typechecker.stan_math_return_type name
(Frontend.Typechecker.library_function_return_type name
argument_types ) in
let try_partially_evaluate_stanlib e =
Expr.Fixed.Pattern.(
Expand Down
7 changes: 3 additions & 4 deletions src/analysis_and_optimization/dune
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
(library
(name analysis_and_optimization)
(public_name stanc.analysis)
(libraries core_kernel str fmt common middle frontend)
(inline_tests)
;; TODO: Not sure what's going on but it's throwing an error that this module has no implementation
(modules_without_implementation monotone_framework_sigs)
(libraries core_kernel str fmt common middle frontend stan_math_backend)
(inline_tests
(libraries stan_math_library))
(preprocess
(pps ppx_jane ppx_deriving.map ppx_deriving.fold)))
7 changes: 4 additions & 3 deletions src/frontend/Ast_to_Mir.ml
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,8 @@ let truncate_dist ud_dists (id : Ast.identifier) ast_obs ast_args t =
| None ->
( Ast.StanLib FnPlain
, Set.to_list possible_names |> List.hd_exn
, if Stan_math_signatures.is_stan_math_function_name (id.name ^ "_lpmf")
then UnsizedType.UInt
, if Library.is_stdlib_function_name (id.name ^ "_lpmf") then
UnsizedType.UInt
else UnsizedType.UReal (* close enough *) ) in
let trunc cond_op (x : Ast.typed_expression) y =
let smeta = x.Ast.emeta.loc in
Expand Down Expand Up @@ -431,7 +431,8 @@ let rec trans_stmt ud_dists (declc : decl_context) (ts : Ast.typed_statement) =
| Ast.IncrementLogProb e | Ast.TargetPE e -> TargetPE (trans_expr e) |> swrap
| Ast.Tilde {arg; distribution; args; truncation} ->
let suffix =
Stan_math_signatures.dist_name_suffix ud_dists distribution.name in
Std_library_utils.dist_name_suffix Library.is_stdlib_function_name
ud_dists distribution.name in
let name = distribution.name ^ suffix in
let kind =
let possible_names =
Expand Down
4 changes: 2 additions & 2 deletions src/frontend/Canonicalize.ml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ let rec replace_deprecated_expr
if is_deprecated_distribution name then
CondDistApp
( StanLib suffix
, {name= rename_deprecated deprecated_distributions name; id_loc}
, {name= rename_deprecated_distribution name; id_loc}
, List.map ~f:(replace_deprecated_expr deprecated_userdefined) e )
else if String.is_suffix name ~suffix:"_cdf" then
CondDistApp
Expand All @@ -59,7 +59,7 @@ let rec replace_deprecated_expr
else
FunApp
( StanLib suffix
, {name= rename_deprecated deprecated_functions name; id_loc}
, {name= rename_deprecated_function name; id_loc}
, List.map ~f:(replace_deprecated_expr deprecated_userdefined) e )
| FunApp (UserDefined suffix, {name; id_loc}, e) -> (
match String.Map.find deprecated_userdefined name with
Expand Down
62 changes: 16 additions & 46 deletions src/frontend/Deprecation_analysis.ml
Original file line number Diff line number Diff line change
Expand Up @@ -2,47 +2,29 @@ open Core_kernel
open Ast
open Middle

let deprecated_functions =
String.Map.of_alist_exn
[ ("multiply_log", ("lmultiply", "2.32.0"))
; ("binomial_coefficient_log", ("lchoose", "2.32.0"))
; ("cov_exp_quad", ("gp_exp_quad_cov", "2.32.0")) ]

let deprecated_odes =
String.Map.of_alist_exn
[ ("integrate_ode", ("ode_rk45", "3.0"))
; ("integrate_ode_rk45", ("ode_rk45", "3.0"))
; ("integrate_ode_bdf", ("ode_bdf", "3.0"))
; ("integrate_ode_adams", ("ode_adams", "3.0")) ]

let deprecated_distributions =
String.Map.of_alist_exn
(List.map
~f:(fun (x, y) -> (x, (y, "2.32.0")))
(List.concat_map Middle.Stan_math_signatures.distributions
~f:(fun (fnkinds, name, _, _) ->
List.filter_map fnkinds ~f:(function
| Lpdf -> Some (name ^ "_log", name ^ "_lpdf")
| Lpmf -> Some (name ^ "_log", name ^ "_lpmf")
| Cdf -> Some (name ^ "_cdf_log", name ^ "_lcdf")
| Ccdf -> Some (name ^ "_ccdf_log", name ^ "_lccdf")
| Rng | UnaryVectorized -> None ) ) ) )

let stan_lib_deprecations =
Map.merge_skewed deprecated_distributions deprecated_functions
Map.merge_skewed Library.deprecated_distributions Library.deprecated_functions
~combine:(fun ~key x y ->
Common.FatalError.fatal_error_msg
[%message
"Common key in deprecation map"
(key : string)
(x : string * string)
(y : string * string)] )
(x : Std_library_utils.deprecation_info)
(y : Std_library_utils.deprecation_info)] )

let is_deprecated_distribution name =
Option.is_some (Map.find deprecated_distributions name)
Map.mem Library.deprecated_distributions name

let rename_deprecated map name =
Map.find map name |> Option.map ~f:fst |> Option.value ~default:name
Map.find map name
|> Option.map ~f:(fun Std_library_utils.{replacement; canonicalize_away; _} ->
if canonicalize_away then replacement else name )
|> Option.value ~default:name

let rename_deprecated_distribution =
rename_deprecated Library.deprecated_distributions

let rename_deprecated_function = rename_deprecated Library.deprecated_functions

let distribution_suffix name =
let open String in
Expand Down Expand Up @@ -119,30 +101,18 @@ let rec collect_deprecated_expr (acc : (Location_span.t * string) list)
| FunApp ((StanLib _ | UserDefined _), {name; _}, l) ->
let w =
match Map.find stan_lib_deprecations name with
| Some (rename, version) ->
| Some {replacement; version; extra_message; _} ->
[ ( emeta.loc
, name ^ " is deprecated and will be removed in Stan " ^ version
^ ". Use " ^ rename
^ " instead. This can be automatically changed using the \
canonicalize flag for stanc" ) ]
^ ". Use " ^ replacement ^ " instead. " ^ extra_message ) ]
| _ when String.is_suffix name ~suffix:"_cdf" ->
[ ( emeta.loc
, "Use of " ^ name
^ " without a vertical bar (|) between the first two arguments \
of a CDF is deprecated and will be removed in Stan 2.32.0. \
This can be automatically changed using the canonicalize \
flag for stanc" ) ]
| _ -> (
match Map.find deprecated_odes name with
| Some (rename, version) ->
[ ( emeta.loc
, name ^ " is deprecated and will be removed in Stan " ^ version
^ ". Use " ^ rename
^ " instead. \n\
The new interface is slightly different, see: \
https://mc-stan.org/users/documentation/case-studies/convert_odes.html"
) ]
| _ -> [] ) in
| _ -> [] in
acc @ w @ List.concat_map l ~f:(fun e -> collect_deprecated_expr [] e)
| _ -> fold_expression collect_deprecated_expr (fun l _ -> l) acc expr

Expand Down
Loading