Skip to content

Commit

Permalink
Merge pull request #549 from rybern/optimization-levels
Browse files Browse the repository at this point in the history
[WIP] Optimization level interface
  • Loading branch information
SteveBronder authored Dec 24, 2021
2 parents 35af5e5 + e3f4e1b commit f9be294
Show file tree
Hide file tree
Showing 9 changed files with 20,531 additions and 54 deletions.
8 changes: 3 additions & 5 deletions src/analysis_and_optimization/Monotone_framework.ml
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@ open Monotone_framework_sigs
open Mir_utils
open Middle

let preserve_stability = false

(** Debugging tool to print out MFP sets **)
let print_mfp to_string (mfp : (int, 'a entry_exit) Map.Poly.t)
(flowgraph_to_mir : (int, Stmt.Located.Non_recursive.t) Map.Poly.t) : unit =
Expand Down Expand Up @@ -309,8 +307,8 @@ let minimal_variables_lattice initial_variables =
let initial = initial_variables
end )

(** The transfer function for a constant propagation analysis *)
let constant_propagation_transfer
(* The transfer function for a constant propagation analysis *)
let constant_propagation_transfer ?(preserve_stability = false)
(flowgraph_to_mir : (int, Stmt.Located.Non_recursive.t) Map.Poly.t) =
( module struct
type labels = int
Expand Down Expand Up @@ -354,7 +352,7 @@ let label_top_decls

(** The transfer function for an expression propagation analysis,
AKA forward substitution (see page 396 of Muchnick) *)
let expression_propagation_transfer
let expression_propagation_transfer ?(preserve_stability = false)
(can_side_effect_expr : Middle.Expr.Typed.t -> bool)
(flowgraph_to_mir : (int, Middle.Stmt.Located.Non_recursive.t) Map.Poly.t) =
( module struct
Expand Down
60 changes: 41 additions & 19 deletions src/analysis_and_optimization/Optimize.ml
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@ open Common
open Middle
open Mir_utils

let preserve_stability = false

(**
Apply the transformation to each function body and to the rest of the program as one
block.
Expand Down Expand Up @@ -612,8 +610,9 @@ let propagation
propagate_stmt (Map.find_exn flowgraph_to_mir 1) in
transform_program mir transform

let constant_propagation =
propagation Monotone_framework.constant_propagation_transfer
let constant_propagation ?(preserve_stability = false) =
propagation
(Monotone_framework.constant_propagation_transfer ~preserve_stability)

let rec expr_any pred (e : Expr.Typed.t) =
match e.pattern with
Expand All @@ -632,7 +631,7 @@ let can_side_effect_top_expr (e : Expr.Typed.t) =
Internal_fun.can_side_effect internal_fn
| _ -> false

let cannot_duplicate_expr (e : Expr.Typed.t) =
let cannot_duplicate_expr ?(preserve_stability = false) (e : Expr.Typed.t) =
let pred e =
can_side_effect_top_expr e
|| ( match e.pattern with
Expand All @@ -643,9 +642,10 @@ let cannot_duplicate_expr (e : Expr.Typed.t) =

let cannot_remove_expr (e : Expr.Typed.t) = expr_any can_side_effect_top_expr e

let expression_propagation mir =
let expression_propagation ?(preserve_stability = false) mir =
propagation
(Monotone_framework.expression_propagation_transfer cannot_duplicate_expr)
(Monotone_framework.expression_propagation_transfer ~preserve_stability
(cannot_duplicate_expr ~preserve_stability) )
mir

let copy_propagation mir =
Expand Down Expand Up @@ -840,7 +840,7 @@ let transform_mir_blocks (mir : (Expr.Typed.t, Stmt.Located.t) Program.t)
let allow_uninitialized_decls mir =
transform_mir_blocks mir unenforce_initialize

let lazy_code_motion (mir : Program.Typed.t) =
let lazy_code_motion ?(preserve_stability = false) (mir : Program.Typed.t) =
(* TODO: clean up this code. It is not very pretty. *)
(* TODO: make lazy code motion operate on transformed parameters and models blocks
simultaneously *)
Expand Down Expand Up @@ -892,7 +892,7 @@ let lazy_code_motion (mir : Program.Typed.t) =
match e.pattern with
| Lit (_, _) -> accum
| Var _ -> accum
| _ when cannot_duplicate_expr e ->
| _ when cannot_duplicate_expr ~preserve_stability e ->
(* Immovable expressions might have movable subexpressions *)
Expr.Fixed.Pattern.fold collect_expressions accum e.pattern
| _ -> Map.set accum ~key:e ~data:(Gensym.generate ~prefix:"lcm_" ())
Expand Down Expand Up @@ -1133,7 +1133,8 @@ type optimization_settings =
; dead_code_elimination: bool
; partial_evaluation: bool
; lazy_code_motion: bool
; optimize_ad_levels: bool }
; optimize_ad_levels: bool
; preserve_stability: bool }

let settings_const b =
{ function_inlining= b
Expand All @@ -1148,43 +1149,64 @@ let settings_const b =
; dead_code_elimination= b
; partial_evaluation= b
; lazy_code_motion= b
; optimize_ad_levels= b }
; optimize_ad_levels= b
; preserve_stability= not b }

let all_optimizations : optimization_settings = settings_const true
let no_optimizations : optimization_settings = settings_const false

let settings_default : optimization_settings =
let xx = settings_const false in
{xx with allow_uninitialized_decls= false}
type optimization_level = O0 | O1 | Oexperimental

let level_optimizations (lvl : optimization_level) : optimization_settings =
match lvl with
| O0 -> {no_optimizations with allow_uninitialized_decls= false}
| O1 ->
{ function_inlining= false
; static_loop_unrolling= false
; one_step_loop_unrolling= false
; list_collapsing= true
; block_fixing= true
; constant_propagation= true
; expression_propagation= false
; copy_propagation= true
; dead_code_elimination= true
; partial_evaluation= true
; lazy_code_motion= false
; allow_uninitialized_decls= false
; optimize_ad_levels= true
; preserve_stability= false }
| Oexperimental -> all_optimizations

let optimization_suite ?(settings = all_optimizations) mir =
let preserve_stability = settings.preserve_stability in
let maybe_optimizations =
[ (* Phase order. See phase-ordering-nodes.org for details *)
(* Book section A *)
(* Book section B *)
(* Book: Procedure integration *)
(function_inlining, settings.function_inlining)
(* Book: Sparse conditional constant propagation *)
; (constant_propagation, settings.constant_propagation)
; (constant_propagation ~preserve_stability, settings.constant_propagation)
(* Book section C *)
(* Book: Local and global copy propagation *)
; (copy_propagation, settings.copy_propagation)
(* Book: Sparse conditional constant propagation *)
; (constant_propagation, settings.constant_propagation)
; (constant_propagation ~preserve_stability, settings.constant_propagation)
(* Book: Dead-code elimination *)
; (dead_code_elimination, settings.dead_code_elimination)
(* Matthijs: Before lazy code motion to get loop-invariant code motion *)
; (one_step_loop_unrolling, settings.one_step_loop_unrolling)
(* Matthjis: expression_propagation < partial_evaluation *)
; (expression_propagation, settings.expression_propagation)
; ( expression_propagation ~preserve_stability
, settings.expression_propagation )
(* Matthjis: partial_evaluation < lazy_code_motion *)
; (partial_evaluation, settings.partial_evaluation)
(* Book: Loop-invariant code motion *)
; (lazy_code_motion, settings.lazy_code_motion)
; (lazy_code_motion ~preserve_stability, settings.lazy_code_motion)
(* Matthijs: lazy_code_motion < copy_propagation TODO: Check if this is necessary *)
; (copy_propagation, settings.copy_propagation)
(* Matthijs: Constant propagation before static loop unrolling *)
; (constant_propagation, settings.constant_propagation)
; (constant_propagation ~preserve_stability, settings.constant_propagation)
(* Book: Loop simplification *)
; (static_loop_unrolling, settings.static_loop_unrolling)
(*Remove decls immediately assigned to*)
Expand Down
17 changes: 12 additions & 5 deletions src/analysis_and_optimization/Optimize.mli
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,12 @@ val block_fixing : Program.Typed.t -> Program.Typed.t
constructors are replaced with Block constructors.
This should probably be run before we generate code. *)

val constant_propagation : Program.Typed.t -> Program.Typed.t
val constant_propagation :
?preserve_stability:bool -> Program.Typed.t -> Program.Typed.t
(** Propagate constant values through variable assignments *)

val expression_propagation : Program.Typed.t -> Program.Typed.t
val expression_propagation :
?preserve_stability:bool -> Program.Typed.t -> Program.Typed.t
(** Propagate arbitrary expressions through variable assignments.
This can be useful for opening up new possibilities for partial evaluation.
It should be followed by some CSE or lazy code motion pass, however. *)
Expand All @@ -45,7 +47,8 @@ val partial_evaluation : Program.Typed.t -> Program.Typed.t
(** Partially evaluate expressions in the program. This includes simplification using
algebraic identities of logical and arithmetic operators as well as Stan math functions. *)

val lazy_code_motion : Program.Typed.t -> Program.Typed.t
val lazy_code_motion :
?preserve_stability:bool -> Program.Typed.t -> Program.Typed.t
(** Perform partial redundancy elmination using the lazy code motion algorithm. This
subsumes common subexpression elimination and loop-invariant code motion. *)

Expand Down Expand Up @@ -74,11 +77,15 @@ type optimization_settings =
; dead_code_elimination: bool
; partial_evaluation: bool
; lazy_code_motion: bool
; optimize_ad_levels: bool }
; optimize_ad_levels: bool
; preserve_stability: bool }

val all_optimizations : optimization_settings
val no_optimizations : optimization_settings
val settings_default : optimization_settings

type optimization_level = O0 | O1 | Oexperimental

val level_optimizations : optimization_level -> optimization_settings

val optimization_suite :
?settings:optimization_settings -> Program.Typed.t -> Program.Typed.t
Expand Down
20 changes: 13 additions & 7 deletions src/analysis_and_optimization/Partial_evaluator.ml
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@ open Middle

exception Rejected of Location_span.t * string

let preserve_stability = false

let is_int i Expr.Fixed.{pattern; _} =
let nums = List.map ~f:(fun s -> string_of_int i ^ s) [""; "."; ".0"] in
match pattern with
Expand Down Expand Up @@ -90,13 +88,13 @@ let is_multi_index = function
| Index.MultiIndex _ | Upfrom _ | Between _ | All -> true
| Single _ -> false

let rec eval_expr (e : Expr.Typed.t) =
let rec eval_expr ?(preserve_stability = false) (e : Expr.Typed.t) =
{ e with
pattern=
( match e.pattern with
| Var _ | Lit (_, _) -> e.pattern
| FunApp (kind, l) -> (
let l = List.map ~f:eval_expr l in
let l = List.map ~f:(eval_expr ~preserve_stability) l in
match kind with
| UserDefined _ | CompilerInternal _ -> FunApp (kind, l)
| StanLib (f, suffix, mem_type) ->
Expand Down Expand Up @@ -961,12 +959,18 @@ let rec eval_expr (e : Expr.Typed.t) =
| _ -> FunApp (kind, l) )
| _ -> FunApp (kind, l) ) )
| TernaryIf (e1, e2, e3) -> (
match (eval_expr e1, eval_expr e2, eval_expr e3) with
match
( eval_expr ~preserve_stability e1
, eval_expr ~preserve_stability e2
, eval_expr ~preserve_stability e3 )
with
| x, _, e3' when is_int 0 x -> e3'.pattern
| {pattern= Lit (Int, _); _}, e2', _ -> e2'.pattern
| e1', e2', e3' -> TernaryIf (e1', e2', e3') )
| EAnd (e1, e2) -> (
match (eval_expr e1, eval_expr e2) with
match
(eval_expr ~preserve_stability e1, eval_expr ~preserve_stability e2)
with
| {pattern= Lit (Int, s1); _}, {pattern= Lit (Int, s2); _} ->
let i1, i2 = (Int.of_string s1, Int.of_string s2) in
Lit (Int, Int.to_string (Bool.to_int (i1 <> 0 && i2 <> 0)))
Expand All @@ -975,7 +979,9 @@ let rec eval_expr (e : Expr.Typed.t) =
Lit (Int, Int.to_string (Bool.to_int (r1 <> 0. && r2 <> 0.)))
| e1', e2' -> EAnd (e1', e2') )
| EOr (e1, e2) -> (
match (eval_expr e1, eval_expr e2) with
match
(eval_expr ~preserve_stability e1, eval_expr ~preserve_stability e2)
with
| {pattern= Lit (Int, s1); _}, {pattern= Lit (Int, s2); _} ->
let i1, i2 = (Int.of_string s1, Int.of_string s2) in
Lit (Int, Int.to_string (Bool.to_int (i1 <> 0 || i2 <> 0)))
Expand Down
41 changes: 26 additions & 15 deletions src/stanc/stanc.ml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ let dump_tx_mir_pretty = ref false
let dump_opt_mir = ref false
let dump_opt_mir_pretty = ref false
let dump_stan_math_sigs = ref false
let optimize = ref false
let opt_lvl = ref Optimize.O0
let output_file = ref ""
let generate_data = ref false
let warn_uninitialized = ref false
Expand Down Expand Up @@ -138,9 +138,22 @@ let options =
, Arg.Set_string Typechecker.model_name
, " Take a string to set the model name (default = \
\"$model_filename_model\")" )
; ( "-O0"
, Arg.Unit (fun () -> opt_lvl := Optimize.O0)
, "\t(Default) Do not apply optimizations to the Stan code." )
; ( "-O1"
, Arg.Unit (fun () -> opt_lvl := Optimize.O1)
, "\tApply level 1 compiler optimizations (only basic optimizations)." )
; ( "-Oexperimental"
, Arg.Unit (fun () -> opt_lvl := Optimize.Oexperimental)
, "\t(Experimental) Apply all compiler optimizations. Some of these are \
not thorougly tested and may not always improve a programs \
performance." )
; ( "--O"
, Arg.Set optimize
, " Allow the compiler to apply all optimizations to the Stan code." )
, Arg.Unit (fun () -> opt_lvl := Optimize.Oexperimental)
, "\t(Experimental) Same as -Oexperimental. Apply all compiler \
optimizations. Some of these are not thorougly tested and may not \
always improve a programs performance." )
; ( "--o"
, Arg.Set_string output_file
, " Take the path to an output file for generated C++ code (default = \
Expand Down Expand Up @@ -263,23 +276,21 @@ let use_file filename =
else if !warn_uninitialized then
Pedantic_analysis.warn_uninitialized mir
|> pp_stderr (Warnings.pp_warnings ?printed_filename) ;
let tx_mir =
Optimize.optimization_suite ~settings:Optimize.settings_default
(Transform_Mir.trans_prog mir) in
let tx_mir = Transform_Mir.trans_prog mir in
if !dump_tx_mir then
Sexp.pp_hum Format.std_formatter [%sexp (tx_mir : Middle.Program.Typed.t)] ;
if !dump_tx_mir_pretty then Program.Typed.pp Format.std_formatter tx_mir ;
let opt_mir =
if !optimize then (
let opt = Optimize.optimization_suite tx_mir in
if !dump_opt_mir then
Sexp.pp_hum Format.std_formatter
[%sexp (opt : Middle.Program.Typed.t)] ;
if !dump_opt_mir_pretty then Program.Typed.pp Format.std_formatter opt ;
opt )
else tx_mir in
let opt =
Optimize.optimization_suite
~settings:(Optimize.level_optimizations !opt_lvl)
tx_mir in
if !dump_opt_mir then
Sexp.pp_hum Format.std_formatter [%sexp (opt : Middle.Program.Typed.t)] ;
if !dump_opt_mir_pretty then Program.Typed.pp Format.std_formatter opt ;
opt in
if !output_file = "" then output_file := remove_dotstan !model_file ^ ".hpp" ;
let cpp = Fmt.str "%a" Stan_math_code_gen.pp_prog opt_mir in
let cpp = Fmt.strf "%a" Stan_math_code_gen.pp_prog opt_mir in
Out_channel.write_all !output_file ~data:cpp ;
if !print_model_cpp then print_endline cpp )

Expand Down
10 changes: 8 additions & 2 deletions test/integration/cli-args/canonicalize/canonicalize.t
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@ Test that a nonsense argument is caught
--print-canonical Prints the canonicalized program. Equivalent to --auto-format --canonicalize [all options]
--version Display stanc version number
--name Take a string to set the model name (default = "$model_filename_model")
--O Allow the compiler to apply all optimizations to the Stan code.
-O0 (Default) Do not apply optimizations to the Stan code.
-O1 Apply level 1 compiler optimizations (only basic optimizations).
-Oexperimental (Experimental) Apply all compiler optimizations. Some of these are not thorougly tested and may not always improve a programs performance.
--O (Experimental) Same as -Oexperimental. Apply all compiler optimizations. Some of these are not thorougly tested and may not always improve a programs performance.
--o Take the path to an output file for generated C++ code (default = "$name.hpp") or auto-formatting output (default: no file/print to stdout)
--print-cpp If set, output the generated C++ Stan model class to stdout.
--allow-undefined Do not fail if a function is declared but not defined
Expand Down Expand Up @@ -62,7 +65,10 @@ Test capitalization - this should fail due to the lack of model_name, not the ca
--print-canonical Prints the canonicalized program. Equivalent to --auto-format --canonicalize [all options]
--version Display stanc version number
--name Take a string to set the model name (default = "$model_filename_model")
--O Allow the compiler to apply all optimizations to the Stan code.
-O0 (Default) Do not apply optimizations to the Stan code.
-O1 Apply level 1 compiler optimizations (only basic optimizations).
-Oexperimental (Experimental) Apply all compiler optimizations. Some of these are not thorougly tested and may not always improve a programs performance.
--O (Experimental) Same as -Oexperimental. Apply all compiler optimizations. Some of these are not thorougly tested and may not always improve a programs performance.
--o Take the path to an output file for generated C++ code (default = "$name.hpp") or auto-formatting output (default: no file/print to stdout)
--print-cpp If set, output the generated C++ Stan model class to stdout.
--allow-undefined Do not fail if a function is declared but not defined
Expand Down
5 changes: 4 additions & 1 deletion test/integration/cli-args/stanc.t
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@ Show help
--print-canonical Prints the canonicalized program. Equivalent to --auto-format --canonicalize [all options]
--version Display stanc version number
--name Take a string to set the model name (default = "$model_filename_model")
--O Allow the compiler to apply all optimizations to the Stan code.
-O0 (Default) Do not apply optimizations to the Stan code.
-O1 Apply level 1 compiler optimizations (only basic optimizations).
-Oexperimental (Experimental) Apply all compiler optimizations. Some of these are not thorougly tested and may not always improve a programs performance.
--O (Experimental) Same as -Oexperimental. Apply all compiler optimizations. Some of these are not thorougly tested and may not always improve a programs performance.
--o Take the path to an output file for generated C++ code (default = "$name.hpp") or auto-formatting output (default: no file/print to stdout)
--print-cpp If set, output the generated C++ Stan model class to stdout.
--allow-undefined Do not fail if a function is declared but not defined
Expand Down
Loading

0 comments on commit f9be294

Please sign in to comment.