Skip to content

Commit a295010

Browse files
authored
flambda2-types: New n-way join algorithm (#3538)
The existing join algorithm suffers from several drawbacks:

- It can be slow due to the use of a quadratic algorithm, taking up to 60% of the total compilation time in -O3 mode in pathological cases (lambda_to_flambda_primitives.ml). See also #3300.
- It is inefficient, as it computes the join of all types appearing in *any* joined environment prior to filtering out the types that are not needed, instead of first computing the types whose join will be needed.
- It is sensitive to the names of local variables that only exist in some of the joined environments but not in the target environment.
- It relies on a global binding time of variables across all joined environments and the target environment, which does not exist, as discovered in #3278. Consequently, it can lose aliasing information, and it breaks typing env invariants by recording the same variable as defined multiple times (with dubious semantics).

This patch implements a new join algorithm, based on an n-way join of types. The new algorithm is:

- Faster, as it avoids quadratic complexity (outside of complex nesting of env extensions). Compared to the existing join algorithm (with advanced meet), on my machine, the new join algorithm is 30x faster on the pathological lambda_to_flambda_primitives.ml, taking only around 10% of the total compilation time and speeding up the compilation of the file by 3.5x. On camlinternalFormat.ml, the new join is about 2.5-3x faster, reducing the time spent in the join from 20% to less than 10% and speeding up the total compilation time by about 20%.
- More efficient, as it only computes a join if it can possibly result in a more precise type, i.e. if the variable has been assigned a new type in all joined environments (otherwise the existing type in the target environment is already the most precise).
- Independent of the names of local variables.
- Only dependent on a consistent binding time *order* of the shared variables (defined in both the target environment and all joined environments), which is respected. Since the result is independent of the binding times of local / existential variables, the typing env invariants are respected.
1 parent 865e207 commit a295010

22 files changed

+5443
-312
lines changed

.github/workflows/build.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,14 +58,14 @@ jobs:
5858
config: --enable-middle-end=flambda2 --enable-frame-pointers --enable-runtime5 --enable-poll-insertion --enable-flambda-invariants
5959
os: ubuntu-latest
6060
build_ocamlparam: ''
61-
ocamlparam: '_,O3=1,flambda2-expert-cont-lifting-budget=200'
61+
ocamlparam: '_,O3=1,flambda2-expert-cont-lifting-budget=200,flambda2-join-algorithm=n-way'
6262

6363
- name: flambda2_o3_advanced_meet_frame_pointers_runtime5_debug
6464
config: --enable-middle-end=flambda2 --enable-frame-pointers --enable-runtime5
6565
os: ubuntu-latest
6666
build_ocamlparam: ''
6767
use_runtime: d
68-
ocamlparam: '_,O3=1,flambda2-expert-cont-lifting-budget=200,cfg-invariants=1,cfg-eliminate-dead-trap-handlers=1'
68+
ocamlparam: '_,O3=1,flambda2-expert-cont-lifting-budget=200,cfg-invariants=1,cfg-eliminate-dead-trap-handlers=1,flambda2-join-algorithm=n-way'
6969

7070
- name: flambda2_frame_pointers_oclassic_polling
7171
config: --enable-middle-end=flambda2 --enable-frame-pointers --enable-poll-insertion --enable-flambda-invariants

driver/flambda_backend_args.ml

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,15 @@ let mk_flambda2_advanced_meet f =
268268
Printf.sprintf " Use an advanced meet algorithm (deprecated) (Flambda 2 only)"
269269
;;
270270

271+
(* Builds the command-line entry for the -flambda2-join-algorithm flag as a
   (flag name, Arg spec, help text) triple. The [Arg.Symbol] spec restricts
   the accepted values to [binary], [n-way] and [checked]; [f] is the callback
   given to [Arg.Symbol] and receives the chosen value as a string. *)
let mk_flambda2_join_algorithm f =
272+
"-flambda2-join-algorithm", Arg.Symbol (["binary"; "n-way"; "checked"], f),
273+
Printf.sprintf " Select the join algorithm to use (Flambda 2 only)\n \
274+
\ Valid values are: \n\
275+
\ \"binary\" is the legacy binary join;\n\
276+
\ \"n-way\" is the new n-way join;\n\
277+
\ \"checked\" runs both algorithms and compares them (use for \
278+
debugging)."
279+
;;
271280

272281
let mk_flambda2_join_points f =
273282
"-flambda2-join-points", Arg.Unit f,
@@ -777,6 +786,7 @@ module type Flambda_backend_options = sig
777786
val no_flambda2_result_types : unit -> unit
778787
val flambda2_basic_meet : unit -> unit
779788
val flambda2_advanced_meet : unit -> unit
789+
val flambda2_join_algorithm : string -> unit
780790
val flambda2_unbox_along_intra_function_control_flow : unit -> unit
781791
val no_flambda2_unbox_along_intra_function_control_flow : unit -> unit
782792
val flambda2_backend_cse_at_toplevel : unit -> unit
@@ -916,6 +926,7 @@ struct
916926
F.no_flambda2_result_types;
917927
mk_flambda2_basic_meet F.flambda2_basic_meet;
918928
mk_flambda2_advanced_meet F.flambda2_advanced_meet;
929+
mk_flambda2_join_algorithm F.flambda2_join_algorithm;
919930
mk_flambda2_unbox_along_intra_function_control_flow
920931
F.flambda2_unbox_along_intra_function_control_flow;
921932
mk_no_flambda2_unbox_along_intra_function_control_flow
@@ -1126,6 +1137,15 @@ module Flambda_backend_options_impl = struct
11261137
Flambda2.function_result_types := Flambda_backend_flags.Set Flambda_backend_flags.Never
11271138
let flambda2_basic_meet () = ()
11281139
let flambda2_advanced_meet () = ()
1140+
(* Records the join algorithm named by [algorithm] by storing the matching
   [Flambda_backend_flags.Set] value in [Flambda2.join_algorithm].
   Recognised names are [binary], [n-way] and [checked].
   NOTE(review): any other string is silently ignored by the final catch-all
   case rather than reported; confirm that every caller validates its input
   before calling this function. *)
let flambda2_join_algorithm algorithm =
1141+
match algorithm with
1142+
| "binary" ->
1143+
Flambda2.join_algorithm := Flambda_backend_flags.Set Flambda_backend_flags.Binary
1144+
| "n-way" ->
1145+
Flambda2.join_algorithm := Flambda_backend_flags.Set Flambda_backend_flags.N_way
1146+
| "checked" ->
1147+
Flambda2.join_algorithm := Flambda_backend_flags.Set Flambda_backend_flags.Checked
1148+
| _ -> () (* This should not occur as we use Arg.Symbol *)
11291149
let flambda2_unbox_along_intra_function_control_flow =
11301150
set Flambda2.unbox_along_intra_function_control_flow
11311151
let no_flambda2_unbox_along_intra_function_control_flow =
@@ -1456,6 +1476,13 @@ module Extra_params = struct
14561476
| _ ->
14571477
Misc.fatal_error "Syntax: flambda2-meet_algorithm=basic|advanced");
14581478
true
1479+
| "flambda2-join-algorithm" ->
1480+
(match String.lowercase_ascii v with
1481+
| "binary" | "n-way" | "checked" as v ->
1482+
Flambda_backend_options_impl.flambda2_join_algorithm v
1483+
| _ ->
1484+
Misc.fatal_error "Syntax: flambda2-join-algorithm=binary|n-way|checked");
1485+
true
14591486
| "flambda2-unbox-along-intra-function-control-flow" ->
14601487
set Flambda2.unbox_along_intra_function_control_flow
14611488
| "flambda2-backend-cse-at-toplevel" ->

driver/flambda_backend_args.mli

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ module type Flambda_backend_options = sig
9191
val no_flambda2_result_types : unit -> unit
9292
val flambda2_basic_meet : unit -> unit
9393
val flambda2_advanced_meet : unit -> unit
94+
val flambda2_join_algorithm : string -> unit
9495
val flambda2_unbox_along_intra_function_control_flow : unit -> unit
9596
val no_flambda2_unbox_along_intra_function_control_flow : unit -> unit
9697
val flambda2_backend_cse_at_toplevel : unit -> unit

driver/flambda_backend_flags.ml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ let long_frames_threshold = ref max_long_frames_threshold (* -debug-long-frames-
9696
let caml_apply_inline_fast_path = ref false (* -caml-apply-inline-fast-path *)
9797

9898
type function_result_types = Never | Functors_only | All_functions
99+
(* Join algorithm selected through the flambda2-join-algorithm option:
   [Binary] is the legacy binary join, [N_way] is the new n-way join, and
   [Checked] runs both algorithms and compares them (for debugging). *)
type join_algorithm = Binary | N_way | Checked
99100
type opt_level = Oclassic | O2 | O3
100101
type 'a or_default = Set of 'a | Default
101102

@@ -128,6 +129,7 @@ module Flambda2 = struct
128129
let backend_cse_at_toplevel = false
129130
let cse_depth = 2
130131
let join_depth = 5
132+
let join_algorithm = Binary
131133
let function_result_types = Never
132134
let enable_reaper = false
133135
let unicode = true
@@ -141,6 +143,7 @@ module Flambda2 = struct
141143
backend_cse_at_toplevel : bool;
142144
cse_depth : int;
143145
join_depth : int;
146+
join_algorithm : join_algorithm;
144147
function_result_types : function_result_types;
145148
enable_reaper : bool;
146149
unicode : bool;
@@ -154,6 +157,7 @@ module Flambda2 = struct
154157
backend_cse_at_toplevel = Default.backend_cse_at_toplevel;
155158
cse_depth = Default.cse_depth;
156159
join_depth = Default.join_depth;
160+
join_algorithm = Default.join_algorithm;
157161
function_result_types = Default.function_result_types;
158162
enable_reaper = Default.enable_reaper;
159163
unicode = Default.unicode;
@@ -187,6 +191,7 @@ module Flambda2 = struct
187191
let backend_cse_at_toplevel = ref Default
188192
let cse_depth = ref Default
189193
let join_depth = ref Default
194+
let join_algorithm = ref Default
190195
let unicode = ref Default
191196
let kind_checks = ref Default
192197
let function_result_types = ref Default

driver/flambda_backend_flags.mli

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ val long_frames_threshold : int ref
8383
val caml_apply_inline_fast_path : bool ref
8484

8585
type function_result_types = Never | Functors_only | All_functions
86+
(** Join algorithm selected through the flambda2-join-algorithm option:
    [Binary] is the legacy binary join, [N_way] is the new n-way join, and
    [Checked] runs both algorithms and compares them (for debugging). *)
type join_algorithm = Binary | N_way | Checked
8687
type opt_level = Oclassic | O2 | O3
8788
type 'a or_default = Set of 'a | Default
8889

@@ -109,6 +110,7 @@ module Flambda2 : sig
109110
val backend_cse_at_toplevel : bool
110111
val cse_depth : int
111112
val join_depth : int
113+
val join_algorithm : join_algorithm
112114
val function_result_types : function_result_types
113115
val enable_reaper : bool
114116
val unicode : bool
@@ -125,6 +127,7 @@ module Flambda2 : sig
125127
backend_cse_at_toplevel : bool;
126128
cse_depth : int;
127129
join_depth : int;
130+
join_algorithm : join_algorithm;
128131
function_result_types : function_result_types;
129132
enable_reaper : bool;
130133
unicode : bool;
@@ -141,6 +144,7 @@ module Flambda2 : sig
141144
val backend_cse_at_toplevel : bool or_default ref
142145
val cse_depth : int or_default ref
143146
val join_depth : int or_default ref
147+
val join_algorithm : join_algorithm or_default ref
144148
val enable_reaper : bool or_default ref
145149
val unicode : bool or_default ref
146150
val kind_checks : bool or_default ref

middle_end/flambda2/tests/meet_test.ml

Lines changed: 152 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,153 @@ let meet_variants_don't_lose_aliases () =
139139
Format.eprintf "@[<hov 2>meet:@ %a@]@.@[<hov 2>env:@ %a@]@." T.print
140140
tag_meet_ty TE.print tag_meet_env)
141141

142+
(* Exercises [T.cut_and_n_way_join] on two environments that refine the same
   variant-typed variable [y] differently. The left and right environments
   and the joined result are printed on the error formatter; this function
   makes no assertion about the joined result, apart from an [assert false]
   if the meet used to build the right environment yields [Bottom]. *)
let test_join_with_extensions () =
143+
let define ?(kind = K.value) env v =
144+
let v' = Bound_var.create v Name_mode.normal in
145+
TE.add_definition env (Bound_name.create_var v') kind
146+
in
147+
let env = create_env () in
148+
let y = Variable.create "y" in
149+
let x = Variable.create "x" in
150+
let a = Variable.create "a" in
151+
let b = Variable.create "b" in
152+
(* [y] and [x] use the default kind [K.value]; [a] and [b] are defined with
   kind [K.naked_immediate]. *)
let env = define env y in
153+
let env = define env x in
154+
let env = define ~kind:K.naked_immediate env a in
155+
let env = define ~kind:K.naked_immediate env b in
156+
let tag_0 = Tag.Scannable.zero in
157+
let tag_1 = Option.get (Tag.Scannable.of_tag (Tag.create_exn 1)) in
158+
(* [make ty] is a variant type with bottom constant constructors and two
   non-constant constructors: tag 0 with the single field [ty], and tag 1
   with no fields. *)
let make ty =
159+
T.variant
160+
~const_ctors:(T.bottom K.naked_immediate)
161+
~non_const_ctors:
162+
(Tag.Scannable.Map.of_list
163+
[ tag_0, (K.Block_shape.Scannable Value_only, [ty]);
164+
tag_1, (K.Block_shape.Scannable Value_only, []) ])
165+
Alloc_mode.For_types.heap
166+
in
167+
let env = TE.add_equation env (Name.var y) (make (T.unknown K.value)) in
168+
(* Record the scope before incrementing it; it is passed as [~cut_after] to
   the join below. Both branches start from [scoped_env]. *)
let scope = TE.current_scope env in
169+
let scoped_env = TE.increment_scope env in
170+
let left_env =
171+
TE.add_equation scoped_env (Name.var x)
172+
(T.tagged_immediate_alias_to ~naked_immediate:a)
173+
in
174+
let right_env =
175+
TE.add_equation scoped_env (Name.var x)
176+
(T.tagged_immediate_alias_to ~naked_immediate:b)
177+
in
178+
let ty_a = make (T.tagged_immediate_alias_to ~naked_immediate:a) in
179+
let ty_b = make (T.tagged_immediate_alias_to ~naked_immediate:b) in
180+
let left_env = TE.add_equation left_env (Name.var y) ty_a in
181+
(* On the right, [y] receives the meet of [ty_a] and [ty_b] computed in
   [right_env], and the environment returned by the meet is the one that is
   extended. *)
let right_env =
182+
match T.meet right_env ty_a ty_b with
183+
| Ok (ty, right_env) -> TE.add_equation right_env (Name.var y) ty
184+
| Bottom -> assert false
185+
in
186+
Format.eprintf "Left:@.%a@." TE.print left_env;
187+
Format.eprintf "Right:@.%a@." TE.print right_env;
188+
let joined_env =
189+
T.cut_and_n_way_join scoped_env
190+
[ left_env, Apply_cont_rewrite_id.create (), Inlinable;
191+
right_env, Apply_cont_rewrite_id.create (), Inlinable ]
192+
~params:Bound_parameters.empty ~cut_after:scope
193+
~extra_allowed_names:Name_occurrences.empty
194+
~extra_lifted_consts_in_use_envs:Symbol.Set.empty
195+
in
196+
Format.eprintf "Res:@.%a@." TE.print joined_env
197+
198+
(* Variant of the join test with a two-field block: exercises
   [T.cut_and_n_way_join] on two environments that give different equations
   to [x], [y], [w] and to the variant-typed variable [z]. The left and right
   environments and the joined result are printed on the error formatter;
   this function makes no assertion about the joined result, apart from an
   [assert false] if the meet used to build the right environment yields
   [Bottom]. *)
let test_join_with_complex_extensions () =
199+
let define ?(kind = K.value) env v =
200+
let v' = Bound_var.create v Name_mode.normal in
201+
TE.add_definition env (Bound_name.create_var v') kind
202+
in
203+
let env = create_env () in
204+
let y = Variable.create "y" in
205+
let x = Variable.create "x" in
206+
let w = Variable.create "w" in
207+
let z = Variable.create "z" in
208+
let a = Variable.create "a" in
209+
let b = Variable.create "b" in
210+
let c = Variable.create "c" in
211+
let d = Variable.create "d" in
212+
(* [z], [x], [y] and [w] use the default kind [K.value]; [a], [b], [c] and
   [d] are defined with kind [K.naked_immediate]. *)
let env = define env z in
213+
let env = define env x in
214+
let env = define env y in
215+
let env = define env w in
216+
let env = define ~kind:K.naked_immediate env a in
217+
let env = define ~kind:K.naked_immediate env b in
218+
let env = define ~kind:K.naked_immediate env c in
219+
let env = define ~kind:K.naked_immediate env d in
220+
let tag_0 = Tag.Scannable.zero in
221+
let tag_1 = Option.get (Tag.Scannable.of_tag (Tag.create_exn 1)) in
222+
(* [make tys] is a variant type with bottom constant constructors and two
   non-constant constructors: tag 0 whose fields are [tys], and tag 1 with
   no fields. *)
let make tys =
223+
T.variant
224+
~const_ctors:(T.bottom K.naked_immediate)
225+
~non_const_ctors:
226+
(Tag.Scannable.Map.of_list
227+
[ tag_0, (K.Block_shape.Scannable Value_only, tys);
228+
tag_1, (K.Block_shape.Scannable Value_only, []) ])
229+
Alloc_mode.For_types.heap
230+
in
231+
let env =
232+
TE.add_equation env (Name.var z)
233+
(make [T.unknown K.value; T.unknown K.value])
234+
in
235+
(* Record the scope before incrementing it; it is passed as [~cut_after] to
   the join below. Both branches start from [scoped_env]. *)
let scope = TE.current_scope env in
236+
let scoped_env = TE.increment_scope env in
237+
(* Left branch: [x], [y] and [w] are all tagged-immediate aliases of [a]. *)
let left_env =
238+
TE.add_equation scoped_env (Name.var x)
239+
(T.tagged_immediate_alias_to ~naked_immediate:a)
240+
in
241+
let left_env =
242+
TE.add_equation left_env (Name.var y)
243+
(T.tagged_immediate_alias_to ~naked_immediate:a)
244+
in
245+
let left_env =
246+
TE.add_equation left_env (Name.var w)
247+
(T.tagged_immediate_alias_to ~naked_immediate:a)
248+
in
249+
(* Right branch: [x], [y] and [w] alias [b], [c] and [d] respectively. *)
let right_env =
250+
TE.add_equation scoped_env (Name.var x)
251+
(T.tagged_immediate_alias_to ~naked_immediate:b)
252+
in
253+
let right_env =
254+
TE.add_equation right_env (Name.var y)
255+
(T.tagged_immediate_alias_to ~naked_immediate:c)
256+
in
257+
let right_env =
258+
TE.add_equation right_env (Name.var w)
259+
(T.tagged_immediate_alias_to ~naked_immediate:d)
260+
in
261+
let ty_a =
262+
make
263+
[ T.tagged_immediate_alias_to ~naked_immediate:b;
264+
T.tagged_immediate_alias_to ~naked_immediate:b ]
265+
in
266+
let ty_b =
267+
make
268+
[ T.tagged_immediate_alias_to ~naked_immediate:c;
269+
T.tagged_immediate_alias_to ~naked_immediate:d ]
270+
in
271+
let left_env = TE.add_equation left_env (Name.var z) ty_a in
272+
(* On the right, [z] receives the meet of [ty_a] and [ty_b] computed in
   [right_env], and the environment returned by the meet is the one that is
   extended. *)
let right_env =
273+
match T.meet right_env ty_a ty_b with
274+
| Ok (ty, right_env) -> TE.add_equation right_env (Name.var z) ty
275+
| Bottom -> assert false
276+
in
277+
Format.eprintf "Left:@.%a@." TE.print left_env;
278+
Format.eprintf "Right:@.%a@." TE.print right_env;
279+
let joined_env =
280+
T.cut_and_n_way_join scoped_env
281+
[ left_env, Apply_cont_rewrite_id.create (), Inlinable;
282+
right_env, Apply_cont_rewrite_id.create (), Inlinable ]
283+
~params:Bound_parameters.empty ~cut_after:scope
284+
~extra_allowed_names:Name_occurrences.empty
285+
~extra_lifted_consts_in_use_envs:Symbol.Set.empty
286+
in
287+
Format.eprintf "Res:@.%a@." TE.print joined_env
288+
142289
let test_meet_two_blocks () =
143290
let define env v =
144291
let v' = Bound_var.create v Name_mode.normal in
@@ -272,4 +419,8 @@ let () =
272419
Format.eprintf "@.MEET ALIAS TO RECOVER @\n@.";
273420
test_meet_recover_alias ();
274421
Format.eprintf "@.MEET BOTTOM AFTER ALIAS@\n@.";
275-
test_meet_bottom_after_alias ()
422+
test_meet_bottom_after_alias ();
423+
Format.eprintf "@.JOIN WITH EXTENSIONS@\n@.";
424+
test_join_with_extensions ();
425+
Format.eprintf "@.JOIN WITH COMPLEX EXTENSIONS@\n@.";
426+
test_join_with_complex_extensions ()

0 commit comments

Comments
 (0)