Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rewrite termination measures sooner #930

Merged
merged 4 commits into from
Feb 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions src/lib/callgraph.ml
Original file line number Diff line number Diff line change
Expand Up @@ -582,3 +582,27 @@ let slice_instantiation_types sail_dir ast =
let g = G.prune roots NodeSet.empty g in
let ast = filter_ast_extra NodeSet.empty g ast false in
filter_library_files sail_dir ast

module FCG = Graph.Make (Id)

let function_call_graph ast =
let module G = Graph.Make (Id) in
let scan_funcl graph (FCL_aux (FCL_funcl (id, pexp), _)) =
let callees =
fold_pexp
{
(pure_exp_alg [] ( @ )) with
e_app = (fun (id', args) -> id' :: List.concat args);
e_app_infix = (fun (arg1, id', arg2) -> (id' :: arg1) @ arg2);
}
pexp
in
FCG.add_edges id callees graph
in
let scan_function graph (FD_aux (FD_function (_, _, funcls), _)) = List.fold_left scan_funcl graph funcls in
let scan_def graph = function
| DEF_aux (DEF_fundef fd, _) -> scan_function graph fd
| DEF_aux (DEF_internal_mutrec fds, _) -> List.fold_left scan_function graph fds
| _ -> graph
in
List.fold_left scan_def FCG.empty ast.defs
8 changes: 8 additions & 0 deletions src/lib/callgraph.mli
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,11 @@ val filter_ast_extra : Set.Make(Node).t -> callgraph -> ('a, 'b) ast -> bool ->
val top_sort_defs : Type_check.typed_ast -> Type_check.typed_ast

val slice_instantiation_types : string -> Type_check.typed_ast -> Type_check.typed_ast

(** Callgraph consisting *only* of calls, not other dependencies. Doesn't rely on types. *)

module FCG : sig
include Graph.S with type node = id and type node_set = IdSet.t and type graph = Graph.Make(Id).graph
end

val function_call_graph : ('a, 'b) Ast_defs.ast -> FCG.graph
75 changes: 63 additions & 12 deletions src/lib/rewrites.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1754,22 +1754,29 @@ let swaptyp typ (l, tannot) =
| Some (env, typ') -> (l, mk_tannot env typ)
| _ -> raise (Reporting.err_unreachable l __POS__ "swaptyp called with empty type annotation")

let is_funcl_rec (FCL_aux (FCL_funcl (id, pexp), _)) =
let recursive_fn_map ast =
let cg = Callgraph.function_call_graph ast in
let components = Callgraph.FCG.scc cg in
List.fold_left (fun m ids -> List.fold_left (fun m id -> Bindings.add id ids m) m ids) Bindings.empty components

let is_funcl_rec rec_fns (FCL_aux (FCL_funcl (id, pexp), _)) =
let ids = Bindings.find id rec_fns in
fold_pexp
{
(pure_exp_alg false ( || )) with
e_app = (fun (id', args) -> Id.compare id id' == 0 || List.exists (fun x -> x) args);
e_app_infix = (fun (arg1, id', arg2) -> arg1 || arg2 || Id.compare id id' == 0);
e_app = (fun (id', args) -> List.exists (fun id -> Id.compare id id' == 0) ids || List.exists (fun x -> x) args);
e_app_infix = (fun (arg1, id', arg2) -> arg1 || arg2 || List.exists (fun id -> Id.compare id id' == 0) ids);
}
pexp

(* Sail code isn't required to declare recursive functions as
recursive, so if a backend needs them then this rewrite updates
them. (Also see minimise_recursive_functions.) *)
let rewrite_add_unspecified_rec env ast =
let rec_fn_map = recursive_fn_map ast in
let rewrite_function (FD_aux (FD_function (recopt, topt, funcls), ann) as fd) =
match recopt with
| Rec_aux (Rec_nonrec, l) when List.exists is_funcl_rec funcls ->
| Rec_aux (Rec_nonrec, l) when List.exists (is_funcl_rec rec_fn_map) funcls ->
FD_aux (FD_function (Rec_aux (Rec_rec, Generated l), topt, funcls), ann)
| _ -> fd
in
Expand Down Expand Up @@ -3822,14 +3829,15 @@ end

(* Splitting a function (e.g., an execute function on an AST) can produce
new functions that appear to be recursive but are not. This checks to
see if the flag can be turned off. Doesn't handle mutual recursion
for now. *)
see if the flag can be turned off. *)

let minimise_recursive_functions env ast =
let rec_fn_map = recursive_fn_map ast in
let rewrite_function (FD_aux (FD_function (recopt, topt, funcls), ann) as fd) =
match recopt with
| Rec_aux (Rec_nonrec, _) -> fd
| Rec_aux ((Rec_rec | Rec_measure _), l) ->
if List.exists is_funcl_rec funcls then fd
if List.exists (is_funcl_rec rec_fn_map) funcls then fd
else FD_aux (FD_function (Rec_aux (Rec_nonrec, Generated l), topt, funcls), ann)
in
let rewrite_def = function
Expand Down Expand Up @@ -3897,30 +3905,73 @@ let move_loop_measures ast =
in
{ ast with defs = List.rev rev_defs }

let called_fns_in_exp exp =
fold_exp
{
(pure_exp_alg [] ( @ )) with
e_app = (fun (id', args) -> id' :: List.concat args);
e_app_infix = (fun (arg1, id', arg2) -> (id' :: arg1) @ arg2);
}
exp

(* Move recursive function termination measures into the function definitions. *)
let move_termination_measures env ast =
let measures =
(* To ensure that the result will type check, we need to move any valspecs for functions
directly used in the measure forward. The definitions themselves will be rearranged
later by the sorting rewrite. Note that the type checker will ensure that a valspec
always exists. *)
let measures, called_fns =
List.fold_left
(fun m def ->
(fun (m, called) def ->
match def with
| DEF_aux (DEF_measure (id, pat, exp), ann) ->
if Bindings.mem id m then
raise (Reporting.err_general ann.loc ("Second termination measure given for " ^ string_of_id id))
else Bindings.add id (pat, exp) m
else (
let called_fns = called_fns_in_exp exp in
(Bindings.add id (pat, exp, called_fns) m, List.fold_left (fun s id -> IdSet.add id s) called called_fns)
)
| _ -> (m, called)
)
(Bindings.empty, IdSet.empty) ast.defs
in
let specs_of_called =
List.fold_left
(fun m def ->
match def with
| DEF_aux (DEF_val (VS_aux (VS_val_spec (_, id, _), _)), _) as def when IdSet.mem id called_fns ->
Bindings.add id def m
| _ -> m
)
Bindings.empty ast.defs
in
let called_output = ref IdSet.empty in
let rec aux acc = function
| [] -> List.rev acc
| (DEF_aux (DEF_fundef (FD_aux (FD_function (r, ty, fs), (l, f_ann))), def_annot) as d) :: t -> begin
let id = match fs with [] -> assert false (* TODO *) | FCL_aux (FCL_funcl (id, _), _) :: _ -> id in
match Bindings.find_opt id measures with
| None -> aux (d :: acc) t
| Some (pat, exp) ->
| Some (pat, exp, called_fns) ->
let r = Rec_aux (Rec_measure (pat, exp), Generated l) in
aux (DEF_aux (DEF_fundef (FD_aux (FD_function (r, ty, fs), (l, f_ann))), def_annot) :: acc) t
let new_def = DEF_aux (DEF_fundef (FD_aux (FD_function (r, ty, fs), (l, f_ann))), def_annot) in
let moved_val_specs =
List.fold_left
(fun moved id ->
if not (IdSet.mem id !called_output) then (
called_output := IdSet.add id !called_output;
Bindings.find id specs_of_called :: moved
)
else moved
)
[] called_fns
in
aux ((new_def :: moved_val_specs) @ acc) t
end
| DEF_aux (DEF_val (VS_aux (VS_val_spec (_, id, _), _)), _) :: t when IdSet.mem id !called_output -> aux acc t
| (DEF_aux (DEF_val (VS_aux (VS_val_spec (_, id, _), _)), _) as d) :: t ->
called_output := IdSet.add id !called_output;
aux (d :: acc) t
| DEF_aux (DEF_measure _, _) :: t -> aux acc t
| h :: t -> aux (h :: acc) t
in
Expand Down
2 changes: 1 addition & 1 deletion src/sail_coq_backend/sail_plugin_coq.ml
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ let coq_options =
let coq_rewrites =
let open Rewrites in
[
("move_termination_measures", []);
("instantiate_outcomes", [String_arg "coq"]);
("realize_mappings", []);
("remove_vector_subrange_pats", []);
Expand Down Expand Up @@ -151,7 +152,6 @@ let coq_rewrites =
like this again, this is where it would go.
("prover_regstate", [Bool_arg true]);*)
(* ("remove_assert", rewrite_ast_remove_assert); *)
("move_termination_measures", []);
("top_sort_defs", []);
("const_prop_mutrec", [String_arg "coq"]);
("exp_lift_assign", []);
Expand Down
2 changes: 1 addition & 1 deletion src/sail_lean_backend/sail_plugin_lean.ml
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ let lean_options =
let lean_rewrites =
let open Rewrites in
[
("move_termination_measures", []);
("instantiate_outcomes", [String_arg "coq"]);
("realize_mappings", []);
("remove_vector_subrange_pats", []);
Expand Down Expand Up @@ -121,7 +122,6 @@ let lean_rewrites =
which has to be followed by type checking *)
(* ("prover_regstate", [Bool_arg false]); *)
(* ("remove_assert", rewrite_ast_remove_assert); *)
("move_termination_measures", []);
("top_sort_defs", []);
("const_prop_mutrec", [String_arg "coq"]);
("exp_lift_assign", []);
Expand Down
2 changes: 1 addition & 1 deletion src/sail_lem_backend/sail_plugin_lem.ml
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ let lem_options =
let lem_rewrites =
let open Rewrites in
[
("move_termination_measures", []);
("instantiate_outcomes", [String_arg "lem"]);
("realize_mappings", []);
("remove_vector_subrange_pats", []);
Expand Down Expand Up @@ -127,7 +128,6 @@ let lem_rewrites =
(* Put prover regstate generation after removing bitfield records,
which has to be followed by type checking *)
("prover_regstate", [Flag_arg Monomorphise.opt_mwords]);
("move_termination_measures", []);
("top_sort_defs", []);
("const_prop_mutrec", [String_arg "lem"]);
("vector_string_pats_to_bit_list", []);
Expand Down
24 changes: 24 additions & 0 deletions test/coq/pass/move_measure.sail
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
default Order dec
$include <prelude.sail>

union foo = {
A : int,
B : nat,
C : bool,
}

function f(x : foo) -> int =
match x {
A(i) => i,
B(n) => f(A(n)),
C(b) => if b then 1 else 0,
}

function f_measure(x : foo) -> int =
match x {
A(_) => 1,
B(_) => 2,
C(_) => 1,
}

termination_measure f(x) = f_measure(x)
Loading