Skip to content

Commit

Permalink
Lean: add support for range types
Browse files Browse the repository at this point in the history
  • Loading branch information
ineol committed Jan 17, 2025
1 parent cf168b5 commit e42ef7e
Show file tree
Hide file tree
Showing 4 changed files with 132 additions and 70 deletions.
141 changes: 74 additions & 67 deletions src/sail_lean_backend/pretty_print_lean.ml
Original file line number Diff line number Diff line change
Expand Up @@ -10,31 +10,32 @@ open PPrint
open Pretty_print_common

type context = {
env : Type_check.env;
kid_id_renames : id option KBindings.t; (* tyvar -> argument renames *)
kid_id_renames_rev : kid Bindings.t; (* reverse of kid_id_renames *)
}

let empty_context = { kid_id_renames = KBindings.empty; kid_id_renames_rev = Bindings.empty }
let initial_context env = { env; kid_id_renames = KBindings.empty; kid_id_renames_rev = Bindings.empty }

let add_single_kid_id_rename ctxt id kid =
let add_single_kid_id_rename ctx id kid =
let kir =
match Bindings.find_opt id ctxt.kid_id_renames_rev with
| Some kid -> KBindings.add kid None ctxt.kid_id_renames
| None -> ctxt.kid_id_renames
match Bindings.find_opt id ctx.kid_id_renames_rev with
| Some kid -> KBindings.add kid None ctx.kid_id_renames
| None -> ctx.kid_id_renames
in
{
(* ctxt with *)
ctx with
kid_id_renames = KBindings.add kid (Some id) kir;
kid_id_renames_rev = Bindings.add id kid ctxt.kid_id_renames_rev;
kid_id_renames_rev = Bindings.add id kid ctx.kid_id_renames_rev;
}

let implicit_parens x = enclose (string "{") (string "}") x

let doc_id_ctor (Id_aux (i, _)) =
match i with Id i -> string i | Operator x -> string (Util.zencode_string ("op " ^ x))

let doc_kid ctxt (Kid_aux (Var x, _) as ki) =
match KBindings.find_opt ki ctxt.kid_id_renames with
let doc_kid ctx (Kid_aux (Var x, _) as ki) =
match KBindings.find_opt ki ctx.kid_id_renames with
| Some (Some i) -> string (string_of_id i)
| _ -> string ("k_" ^ String.sub x 1 (String.length x - 1))

Expand Down Expand Up @@ -108,10 +109,10 @@ let string_of_nexp_con (Nexp_aux (n, l)) =
| Nexp_neg _ -> "Nexp_neg"
| Nexp_exp _ -> "Nexp_exp"

let doc_nexp ctxt (Nexp_aux (n, l) as nexp) =
let doc_nexp ctx (Nexp_aux (n, l) as nexp) =
match n with
| Nexp_constant i -> string (Big_int.to_string i)
| Nexp_var ki -> doc_kid ctxt ki
| Nexp_var ki -> doc_kid ctx ki
| _ -> failwith ("NExp " ^ string_of_nexp_con nexp ^ " " ^ string_of_nexp nexp ^ " not translatable yet.")

let string_of_typ_con (Typ_aux (t, _)) =
Expand All @@ -125,7 +126,9 @@ let string_of_typ_con (Typ_aux (t, _)) =
| Typ_internal_unknown -> "Typ_internal_unknown"
| Typ_id _ -> "Typ_id"

let rec doc_typ ctxt (Typ_aux (t, _) as typ) =
let provably_nneg ctx x = Type_check.prove __POS__ ctx.env (nc_gteq x (nint 0))

let rec doc_typ ctx (Typ_aux (t, _) as typ) =
match t with
| Typ_id (Id_aux (Id "unit", _)) -> string "Unit"
| Typ_id (Id_aux (Id "int", _)) -> string "Int"
Expand All @@ -134,13 +137,14 @@ let rec doc_typ ctxt (Typ_aux (t, _) as typ) =
| Typ_id (Id_aux (Id "nat", _)) -> string "Nat"
| Typ_app (Id_aux (Id "bitvector", _), [A_aux (A_nexp m, _)]) | Typ_app (Id_aux (Id "bits", _), [A_aux (A_nexp m, _)])
->
parens (string "BitVec " ^^ doc_nexp ctxt m)
| Typ_app (Id_aux (Id "atom", _), [A_aux (A_nexp (Nexp_aux (Nexp_var ki, _)), _)]) ->
string "Int" (* TODO This probably has to be generalized *)
parens (string "BitVec " ^^ doc_nexp ctx m)
| Typ_app (Id_aux (Id "atom", _), [A_aux (A_nexp x, _)]) -> if provably_nneg ctx x then string "Nat" else string "Int"
| Typ_app (Id_aux (Id "implicit", _), [A_aux (A_nexp (Nexp_aux (Nexp_var ki, _)), _)]) ->
underscore (* TODO check if the type of implicit arguments can really be always inferred *)
| Typ_tuple ts -> parens (separate_map (space ^^ string "×" ^^ space) (doc_typ ctxt) ts)
| Typ_tuple ts -> parens (separate_map (space ^^ string "×" ^^ space) (doc_typ ctx) ts)
| Typ_id (Id_aux (Id id, _)) -> string id
| Typ_app (Id_aux (Id "range", _), [A_aux (A_nexp low, _); A_aux (A_nexp high, _)]) ->
if provably_nneg ctx low then string "Nat" else string "Int"
| _ -> failwith ("Type " ^ string_of_typ_con typ ^ " " ^ string_of_typ typ ^ " not translatable yet.")

let rec captured_typ_var ((i, Typ_aux (t, _)) as typ) =
Expand All @@ -150,45 +154,45 @@ let rec captured_typ_var ((i, Typ_aux (t, _)) as typ) =
Some (i, ki)
| _ -> None

let doc_typ_id ctxt (typ, fid) = flow (break 1) [doc_id_ctor fid; colon; doc_typ ctxt typ]
let doc_typ_id ctx (typ, fid) = flow (break 1) [doc_id_ctor fid; colon; doc_typ ctx typ]

let doc_kind (K_aux (k, _)) =
match k with
| K_int -> string "Int"
| K_bool -> string "Bool"
| _ -> failwith ("Kind " ^ string_of_kind_aux k ^ " not translatable yet.")

let doc_typ_arg ctxt ta = string "foo" (* TODO implement *)
let doc_typ_arg ctx ta = string "foo" (* TODO implement *)

let rec doc_nconstraint ctxt (NC_aux (nc, _)) =
let rec doc_nconstraint ctx (NC_aux (nc, _)) =
match nc with
| NC_and (n1, n2) -> flow (break 1) [doc_nconstraint ctxt n1; string ""; doc_nconstraint ctxt n2]
| NC_or (n1, n2) -> flow (break 1) [doc_nconstraint ctxt n1; string ""; doc_nconstraint ctxt n2]
| NC_equal (a1, a2) -> flow (break 1) [doc_typ_arg ctxt a1; string "="; doc_typ_arg ctxt a2]
| NC_not_equal (a1, a2) -> flow (break 1) [doc_typ_arg ctxt a1; string ""; doc_typ_arg ctxt a2]
| NC_app (f, args) -> string (string_of_id f) ^^ parens (separate_map comma_sp (doc_typ_arg ctxt) args)
| NC_and (n1, n2) -> flow (break 1) [doc_nconstraint ctx n1; string ""; doc_nconstraint ctx n2]
| NC_or (n1, n2) -> flow (break 1) [doc_nconstraint ctx n1; string ""; doc_nconstraint ctx n2]
| NC_equal (a1, a2) -> flow (break 1) [doc_typ_arg ctx a1; string "="; doc_typ_arg ctx a2]
| NC_not_equal (a1, a2) -> flow (break 1) [doc_typ_arg ctx a1; string ""; doc_typ_arg ctx a2]
| NC_app (f, args) -> string (string_of_id f) ^^ parens (separate_map comma_sp (doc_typ_arg ctx) args)
| NC_false -> string "false"
| NC_true -> string "true"
| NC_ge (n1, n2) -> flow (break 1) [doc_nexp ctxt n1; string ""; doc_nexp ctxt n2]
| NC_le (n1, n2) -> flow (break 1) [doc_nexp ctxt n1; string ""; doc_nexp ctxt n2]
| NC_gt (n1, n2) -> flow (break 1) [doc_nexp ctxt n1; string ">"; doc_nexp ctxt n2]
| NC_lt (n1, n2) -> flow (break 1) [doc_nexp ctxt n1; string "<"; doc_nexp ctxt n2]
| NC_ge (n1, n2) -> flow (break 1) [doc_nexp ctx n1; string ""; doc_nexp ctx n2]
| NC_le (n1, n2) -> flow (break 1) [doc_nexp ctx n1; string ""; doc_nexp ctx n2]
| NC_gt (n1, n2) -> flow (break 1) [doc_nexp ctx n1; string ">"; doc_nexp ctx n2]
| NC_lt (n1, n2) -> flow (break 1) [doc_nexp ctx n1; string "<"; doc_nexp ctx n2]
| NC_id i -> string (string_of_id i)
| NC_set (n, vs) ->
flow (break 1)
[
doc_nexp ctxt n;
doc_nexp ctx n;
string "";
implicit_parens (separate_map comma_sp (fun x -> string (Nat_big_num.to_string x)) vs);
]
| NC_var ki -> doc_kid ctxt ki
| NC_var ki -> doc_kid ctx ki

let doc_quant_item ctxt (QI_aux (qi, _)) =
let doc_quant_item ctx (QI_aux (qi, _)) =
match qi with
| QI_id (KOpt_aux (KOpt_kind (k, ki), _)) -> flow (break 1) [doc_kid ctxt ki; colon; doc_kind k]
| QI_constraint c -> doc_nconstraint ctxt c
| QI_id (KOpt_aux (KOpt_kind (k, ki), _)) -> flow (break 1) [doc_kid ctx ki; colon; doc_kind k]
| QI_constraint c -> doc_nconstraint ctx c

let doc_typ_quant ctxt tq = match tq with TypQ_tq qs -> List.map (doc_quant_item ctxt) qs | TypQ_no_forall -> []
let doc_typ_quant ctx tq = match tq with TypQ_tq qs -> List.map (doc_quant_item ctx) qs | TypQ_no_forall -> []

let lean_escape_string s = Str.global_replace (Str.regexp "\"") "\"\"" s

Expand Down Expand Up @@ -248,7 +252,7 @@ let string_of_exp_con (E_aux (e, _)) =
| E_vector _ -> "E_vector"
| E_let _ -> "E_let"

let rec doc_exp ctxt (E_aux (e, (l, annot)) as full_exp) =
let rec doc_exp ctx (E_aux (e, (l, annot)) as full_exp) =
let env = env_of_tannot annot in
match e with
| E_id id -> string (string_of_id id) (* TODO replace by a translating via a binding map *)
Expand All @@ -259,33 +263,33 @@ let rec doc_exp ctxt (E_aux (e, (l, annot)) as full_exp) =
| E_app (f, args) ->
let d_id =
if Env.is_extern f env "lean" then string (Env.get_extern f env "lean")
else doc_exp ctxt (E_aux (E_id f, (l, annot)))
else doc_exp ctx (E_aux (E_id f, (l, annot)))
in
let d_args = List.map (doc_exp ctxt) args in
let d_args = List.map (doc_exp ctx) args in
nest 2 (parens (flow (break 1) (d_id :: d_args)))
| E_vector vals -> failwith "vector found"
| E_typ (typ, e) -> parens (separate space [doc_exp ctxt e; colon; doc_typ ctxt typ])
| E_tuple es -> parens (separate_map (comma ^^ space) (doc_exp ctxt) es)
| E_typ (typ, e) -> parens (separate space [doc_exp ctx e; colon; doc_typ ctx typ])
| E_tuple es -> parens (separate_map (comma ^^ space) (doc_exp ctx) es)
| E_let (LB_aux (LB_val (lpat, lexp), _), e) ->
let id =
match pat_is_plain_binder env lpat with
| Some (Some (Id_aux (Id id, _))) -> id
| Some None -> "x" (* TODO fresh name or wildcard instead of x *)
| _ -> failwith "Let pattern not translatable yet."
in
nest 2 (flow (break 1) [string "let"; string id; coloneq; doc_exp ctxt lexp]) ^^ hardline ^^ doc_exp ctxt e
nest 2 (flow (break 1) [string "let"; string id; coloneq; doc_exp ctx lexp]) ^^ hardline ^^ doc_exp ctx e
| _ -> failwith ("Expression " ^ string_of_exp_con full_exp ^ " " ^ string_of_exp full_exp ^ " not translatable yet.")

let doc_binder ctxt i t =
let doc_binder ctx i t =
let paranthesizer =
match t with
| Typ_aux (Typ_app (Id_aux (Id "implicit", _), [A_aux (A_nexp (Nexp_aux (Nexp_var ki, _)), _)]), _) ->
implicit_parens
| _ -> parens
in
(* Overwrite the id if it's captured *)
let ctxt = match captured_typ_var (i, t) with Some (i, ki) -> add_single_kid_id_rename ctxt i ki | _ -> ctxt in
(ctxt, separate space [string (string_of_id i); colon; doc_typ ctxt t] |> paranthesizer)
let ctx = match captured_typ_var (i, t) with Some (i, ki) -> add_single_kid_id_rename ctx i ki | _ -> ctx in
(ctx, separate space [string (string_of_id i); colon; doc_typ ctx t] |> paranthesizer)

let doc_funcl_init (FCL_aux (FCL_funcl (id, pexp), annot)) =
let env = env_of_tannot (snd annot) in
Expand All @@ -306,46 +310,49 @@ let doc_funcl_init (FCL_aux (FCL_funcl (id, pexp), annot)) =
| _ -> failwith "Argument pattern not translatable yet."
)
in
let ctxt = empty_context in
let ctxt, binders =
let ctx = initial_context env in
let ctx, binders =
List.fold_left
(fun (ctxt, bs) (i, t) ->
let ctxt, d = doc_binder ctxt i t in
(ctxt, bs @ [d])
(fun (ctx, bs) (i, t) ->
let ctx, d = doc_binder ctx i t in
(ctx, bs @ [d])
)
(ctxt, []) binders
(ctx, []) binders
in
let typ_quants = doc_typ_quant ctxt tq in
let typ_quants = doc_typ_quant ctx tq in
let typ_quant_comment =
if List.length typ_quants > 0 then
string "/-- Type quantifiers: " ^^ nest 2 (flow comma_sp typ_quants) ^^ string " -/" ^^ hardline
else empty
in
(* Use auto-implicits for type quanitifiers for now and see if this works *)
let doc_ret_typ = doc_typ ctxt ret_typ in
let doc_ret_typ = doc_typ ctx ret_typ in
let is_monadic = effectful (effect_of exp) in
(* Add monad for stateful functions *)
let doc_ret_typ = if is_monadic then string "SailM " ^^ doc_ret_typ else doc_ret_typ in
let decl_val = [doc_ret_typ; coloneq] in
(* Add do block for stateful functions *)
let decl_val = if is_monadic then decl_val @ [string "do"] else decl_val in
( typ_quant_comment,
separate space ([string "def"; string (string_of_id id)] @ binders @ [colon; doc_ret_typ; coloneq])
separate space ([string "def"; string (string_of_id id)] @ binders @ [colon; doc_ret_typ; coloneq]),
env
)

let doc_funcl_body (FCL_aux (FCL_funcl (id, pexp), annot)) =
let env = env_of_tannot (snd annot) in
let ctx = initial_context env in
let _, _, exp, _ = destruct_pexp pexp in
let is_monadic = effectful (effect_of exp) in
if is_monadic then nest 2 (flow (break 1) [string "return"; doc_exp empty_context exp]) else doc_exp empty_context exp
if is_monadic then nest 2 (flow (break 1) [string "return"; doc_exp ctx exp]) else doc_exp ctx exp

let doc_funcl funcl =
let comment, signature = doc_funcl_init funcl in
let doc_funcl ctx funcl =
let comment, signature, env = doc_funcl_init funcl in
comment ^^ nest 2 (signature ^^ hardline ^^ doc_funcl_body funcl)

let doc_fundef (FD_aux (FD_function (r, typa, fcls), fannot)) =
let doc_fundef ctx (FD_aux (FD_function (r, typa, fcls), fannot)) =
match fcls with
| [] -> failwith "FD_function with empty function list"
| [funcl] -> doc_funcl funcl
| [funcl] -> doc_funcl ctx funcl
| _ -> failwith "FD_function with more than one clause"

let string_of_type_def_con (TD_aux (td, _)) =
Expand All @@ -357,7 +364,7 @@ let string_of_type_def_con (TD_aux (td, _)) =
| TD_bitfield _ -> "TD_bitfield"
| TD_enum _ -> "TD_enum"

let doc_typdef ctxt (TD_aux (td, tannot) as full_typdef) =
let doc_typdef ctx (TD_aux (td, tannot) as full_typdef) =
match td with
| TD_enum (Id_aux (Id id, _), fields, _) ->
let derivers = if List.length fields > 0 then [string "Inhabited"] else [] in
Expand All @@ -370,21 +377,21 @@ let doc_typdef ctxt (TD_aux (td, tannot) as full_typdef) =
^^ separate (comma ^^ space) derivers
)
| TD_record (Id_aux (Id id, _), TypQ_aux (tq, _), fields, _) ->
let fields = List.map (doc_typ_id ctxt) fields in
let fields = List.map (doc_typ_id ctx) fields in
let enums_doc = separate hardline fields in
let rectyp = doc_typ_quant ctxt tq in
let rectyp = doc_typ_quant ctx tq in
(* TODO don't ignore type quantifiers *)
nest 2 (flow (break 1) [string "structure"; string id; string "where"] ^^ hardline ^^ enums_doc)
| TD_abbrev (Id_aux (Id id, _), tq, A_aux (A_typ t, _)) ->
nest 2 (flow (break 1) [string "def"; string id; coloneq; doc_typ ctxt t])
nest 2 (flow (break 1) [string "def"; string id; coloneq; doc_typ ctx t])
| TD_abbrev (Id_aux (Id id, _), tq, A_aux (A_nexp ne, _)) ->
nest 2 (flow (break 1) [string "def"; string id; colon; string "Int"; coloneq; doc_nexp ctxt ne])
nest 2 (flow (break 1) [string "def"; string id; colon; string "Int"; coloneq; doc_nexp ctx ne])
| _ -> failwith ("Type definition " ^ string_of_type_def_con full_typdef ^ " not translatable yet.")

let doc_def ctxt (DEF_aux (aux, def_annot) as def) =
let doc_def ctx (DEF_aux (aux, def_annot) as def) =
match aux with
| DEF_fundef fdef -> group (doc_fundef fdef) ^/^ hardline
| DEF_type tdef -> group (doc_typdef ctxt tdef) ^/^ hardline
| DEF_fundef fdef -> group (doc_fundef ctx fdef) ^/^ hardline
| DEF_type tdef -> group (doc_typdef ctx tdef) ^/^ hardline
| _ -> empty

(* Remove all imports for now, they will be printed in other files. Probably just for testing. *)
Expand All @@ -395,8 +402,8 @@ let rec remove_imports (defs : (Libsail.Type_check.tannot, Libsail.Type_check.en
| DEF_aux (DEF_pragma ("include_end", _, _), _) :: ds -> remove_imports ds (depth - 1)
| d :: ds -> if depth > 0 then remove_imports ds depth else d :: remove_imports ds depth

let pp_ast_lean ({ defs; _ } as ast : Libsail.Type_check.typed_ast) o =
let pp_ast_lean (env : Type_check.env) ({ defs; _ } as ast : Libsail.Type_check.typed_ast) o =
let defs = remove_imports defs 0 in
let output : document = separate_map empty (doc_def empty_context) defs in
let output : document = separate_map empty (doc_def (initial_context env)) defs in
print o output;
()
6 changes: 3 additions & 3 deletions src/sail_lean_backend/sail_plugin_lean.ml
Original file line number Diff line number Diff line change
Expand Up @@ -190,15 +190,15 @@ let create_lake_project (out_name : string) default_sail_dir =
output_string project_main "open Sail\n\n";
project_main

let output (out_name : string) ast default_sail_dir =
let output (out_name : string) env ast default_sail_dir =
let project_main = create_lake_project out_name default_sail_dir in
(* Uncomment for debug output of the Sail code after the rewrite passes *)
(* Pretty_print_sail.output_ast stdout (Type_check.strip_ast ast); *)
Pretty_print_lean.pp_ast_lean ast project_main;
Pretty_print_lean.pp_ast_lean env ast project_main;
close_out project_main

let lean_target out_name { default_sail_dir; ctx; ast; effect_info; env; _ } =
let out_name = match out_name with Some f -> f | None -> "out" in
output out_name ast default_sail_dir
output out_name env ast default_sail_dir

let _ = Target.register ~name:"lean" ~options:lean_options ~rewrites:lean_rewrites ~asserts_termination:true lean_target
27 changes: 27 additions & 0 deletions test/lean/range.expected.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import Out.Sail.Sail

open Sail

/-- Type quantifiers: x : Int, 0 ≤ x ∧ x ≤ 31 -/
def f_int (x : Nat) : Int :=
0

/-- Type quantifiers: x : Int, 0 ≤ x ∧ x ≤ 31 -/
def f_nat (x : Nat) : Nat :=
0

/-- Type quantifiers: x : Int, k_n : Int, 0 ≤ x ∧ x ≤ k_n -/
def f_negvar (x : Nat) : Int :=
x

/-- Type quantifiers: x : Int, k_n : Int, 0 ≤ x ∧ x ≤ k_n -/
def f_nnegvar (x : Nat) : Nat :=
x

/-- Type quantifiers: x : Int, k_n : Int, k_m : Int, k_n ≤ x ∧ x ≤ k_m -/
def f_unkn (x : Int) : Int :=
x

def initialize_registers : Unit :=
()

28 changes: 28 additions & 0 deletions test/lean/range.sail
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
default Order dec
$include <prelude.sail>


val f_int : range(0, 31) -> range(-2, 2)
function f_int(x) = {
0
}

val f_nat : range(0, 31) -> range(0, 2)
function f_nat(x) = {
0
}

val f_negvar : forall 'n. range(0, 'n) -> range(- 'n, 'n)
function f_negvar(x) = {
x
}

val f_nnegvar : forall 'n. range(0, 'n) -> range(0, 'n)
function f_nnegvar(x) = {
x
}

val f_unkn : forall 'n 'm. range('n, 'm) -> range('n, 'm)
function f_unkn(x) = {
x
}

0 comments on commit e42ef7e

Please sign in to comment.