diff --git a/src/lib/state.ml b/src/lib/state.ml index 8ef47a010..f4c65558e 100644 --- a/src/lib/state.ml +++ b/src/lib/state.ml @@ -834,6 +834,33 @@ let register_refs_coq doc_id coq_record_update env registers = in separate hardline [generic_convs; refs; getters_setters] +let register_refs_lean doc_id doc_typ registers = + let generic_convs = separate_map hardline string [""; "variable [MonadReg]"; ""; "open MonadReg"; ""] in + let register_ref (typ, id, _) = + let idd = doc_id id in + let typp = doc_typ typ in + (* let field = if prefix_recordtype then string "regstate_" ^^ idd else idd in *) + concat + [ + string " set_"; + idd; + colon; + space; + typp; + string " -> SailM Unit"; + hardline; + string " get_"; + idd; + colon; + space; + string "SailM ("; + typp; + string ")"; + ] + in + let refs = separate_map hardline register_ref registers in + separate hardline [string "class MonadReg where"; refs; generic_convs] + let generate_regstate_defs ctx env ast = let defs = ast.defs in let registers = find_registers defs in diff --git a/src/sail_lean_backend/Sail/Sail.lean b/src/sail_lean_backend/Sail/Sail.lean index dc3dc9f92..3af997734 100644 --- a/src/sail_lean_backend/Sail/Sail.lean +++ b/src/sail_lean_backend/Sail/Sail.lean @@ -42,3 +42,9 @@ structure RegisterRef (regstate regval a : Type) where def undefined_bitvector (w : Nat) : BitVec w := 0 + +def undefined_int (_ : Unit) : Int := + 0 + +def undefined_bit (_ : Unit) : BitVec 1 := + undefined_bitvector 1 diff --git a/src/sail_lean_backend/pretty_print_lean.ml b/src/sail_lean_backend/pretty_print_lean.ml index 624ba7fab..e7ead06ef 100644 --- a/src/sail_lean_backend/pretty_print_lean.ml +++ b/src/sail_lean_backend/pretty_print_lean.ml @@ -9,12 +9,24 @@ open Rewriter open PPrint open Pretty_print_common +let prefix_recordtype = true + +type global_context = { effect_info : Effects.side_effect_info } + type context = { + global : global_context; kid_id_renames : id option KBindings.t; (* tyvar -> argument renames *) kid_id_renames_rev : kid Bindings.t; (* reverse of kid_id_renames *) + is_monadic : bool; } -let empty_context = { kid_id_renames = KBindings.empty; kid_id_renames_rev = Bindings.empty } +let empty_context = + { + global = { effect_info = Effects.empty_side_effect_info }; + kid_id_renames = KBindings.empty; + kid_id_renames_rev = Bindings.empty; + is_monadic = false; + } let add_single_kid_id_rename ctxt id kid = let kir = @@ -23,16 +35,22 @@ let add_single_kid_id_rename ctxt id kid = | None -> ctxt.kid_id_renames in { - (* ctxt with *) + ctxt with kid_id_renames = KBindings.add kid (Some id) kir; kid_id_renames_rev = Bindings.add id kid ctxt.kid_id_renames_rev; } let implicit_parens x = enclose (string "{") (string "}") x - -let doc_id_ctor (Id_aux (i, _)) = +let doc_id_ctor (ctxt : context) (Id_aux (i, _)) = match i with Id i -> string i | Operator x -> string (Util.zencode_string ("op " ^ x)) +let doc_id = doc_id_ctor + +let doc_field_name ctxt typ_id field_id = + if prefix_recordtype && string_of_id typ_id <> "regstate" then + doc_id ctxt typ_id ^^ string "_" ^^ doc_id ctxt field_id + else doc_id ctxt field_id + let doc_kid ctxt (Kid_aux (Var x, _) as ki) = match KBindings.find_opt ki ctxt.kid_id_renames with | Some (Some i) -> string (string_of_id i) @@ -159,7 +177,7 @@ let rec captured_typ_var ((i, Typ_aux (t, _)) as typ) = Some (i, ki) | _ -> None -let doc_typ_id ctxt (typ, fid) = flow (break 1) [doc_id_ctor fid; colon; doc_typ ctxt typ] +let doc_typ_id ctxt (typ, fid) = flow (break 1) [doc_id_ctor ctxt fid; colon; doc_typ ctxt typ] let doc_kind (K_aux (k, _)) = match k with @@ -217,6 +235,103 @@ let doc_lit (L_aux (lit, l)) = | L_string s -> utf8string ("\"" ^ lean_escape_string s ^ "\"") | L_real s -> utf8string s (* TODO test if this is really working *) +let string_of_pat_con (P_aux (p, _)) = + match p with + | P_app _ -> "P_app" + | P_wild -> "P_wild" + | P_lit _ -> "P_lit" + | P_or _ -> "P_or" + | P_not _ -> "P_not" + | P_as _ -> "P_as" + | P_typ _ -> "P_typ" + | P_id _ -> "P_id" + | P_var _ -> "P_var" + | P_vector _ -> "P_vector" + | P_vector_concat _ -> "P_vector_concat" + | P_vector_subrange _ -> "P_vector_subrange" + | P_tuple _ -> "P_tuple" + | P_list _ -> "P_list" + | P_cons _ -> "P_cons" + | P_string_append _ -> "P_string_append" + | P_struct _ -> "P_struct" + +let rec doc_pat ctxt apat_needed (P_aux (p, (l, annot))) = + let env = env_of_annot (l, annot) in + let typ = Env.expand_synonyms env (typ_of_annot (l, annot)) in + match p with + (* Special case translation of the None constructor to remove the unit arg *) + | P_app (id, _) when string_of_id id = "None" -> string "None" + | P_app (id, (_ :: _ as pats)) -> begin + let pats_pp = separate_map comma (doc_pat ctxt true) pats in + let pats_pp = match pats with [_] -> pats_pp | _ -> parens pats_pp in + let ppp = doc_unop (doc_id_ctor ctxt id) pats_pp in + if apat_needed then parens ppp else ppp + end + | P_app (id, []) -> doc_id_ctor ctxt id + | P_lit lit -> doc_lit lit + | P_wild -> underscore + | P_id id -> doc_id ctxt id + | P_var (p, _) -> doc_pat ctxt true p + | P_as (p, id) -> parens (separate space [doc_pat ctxt true p; string "as"; doc_id ctxt id]) + | P_typ (ptyp, p) -> + let doc_p = doc_pat ctxt true p in + doc_p + (* Type annotations aren't allowed everywhere in patterns in Coq *) + (*parens (doc_op colon doc_p (doc_typ typ))*) + | P_vector pats -> + let ppp = brackets (separate_map semi (fun p -> doc_pat ctxt true p) pats) in + if apat_needed then parens ppp else ppp + | P_vector_concat pats -> + raise + (Reporting.err_unreachable l __POS__ + "vector concatenation patterns should have been removed before pretty-printing" + ) + | P_vector_subrange _ -> unreachable l __POS__ "Must have been rewritten before Coq backend" + | P_tuple pats -> ( + match pats with [p] -> doc_pat ctxt apat_needed p | _ -> parens (separate_map comma_sp (doc_pat ctxt false) pats) + ) + | P_list pats -> brackets (separate_map semi (doc_pat ctxt false) pats) + | P_cons (p, p') -> + let ppp = doc_op (string "::") (doc_pat ctxt true p) (doc_pat ctxt true p') in + if apat_needed then parens ppp else ppp + | P_string_append _ -> unreachable l __POS__ "string append pattern found in Coq backend, should have been rewritten" + | P_struct (fpats, _) -> + let type_id = + match typ with + | (Typ_aux (Typ_id tid, _) | Typ_aux (Typ_app (tid, _), _)) when Env.is_record tid env -> tid + | _ -> Reporting.unreachable l __POS__ "P_struct pattern with no record type" + in + string "{|" ^^ space + ^^ separate_map (semi ^^ space) + (fun (field, pat) -> separate space [doc_field_name ctxt type_id field; coloneq; doc_pat ctxt false pat]) + fpats + ^^ space ^^ string "|}" + | P_not _ -> unreachable l __POS__ "Coq backend doesn't support not patterns" + | P_or _ -> unreachable l __POS__ "Coq backend doesn't support or patterns yet" + +let rebind_cast_pattern_vars pat typ exp = + let rec aux pat typ = + match (pat, typ) with + | P_aux (P_typ (target_typ, P_aux (P_id id, (l, ann))), _), source_typ when not (is_enum (env_of exp) id) -> + if Typ.compare target_typ source_typ == 0 then [] + else ( + let l = Parse_ast.Generated l in + let cast_annot = Type_check.replace_typ source_typ ann in + let e_annot = Type_check.mk_tannot (env_of exp) source_typ in + [LB_aux (LB_val (pat, E_aux (E_id id, (l, e_annot))), (l, ann))] + ) + | P_aux (P_tuple pats, _), Typ_aux (Typ_tuple typs, _) -> List.concat (List.map2 aux pats typs) + | _ -> [] + in + let add_lb (E_aux (_, ann) as exp) lb = E_aux (E_let (lb, exp), ann) in + (* Don't introduce new bindings at the top-level, we'd just go into a loop. *) + let lbs = + match (pat, typ) with + | P_aux (P_tuple pats, _), Typ_aux (Typ_tuple typs, _) -> List.concat (List.map2 aux pats typs) + | _ -> [] + in + List.fold_left add_lb exp lbs + let string_of_exp_con (E_aux (e, _)) = match e with | E_block _ -> "E_block" @@ -257,14 +372,28 @@ let string_of_exp_con (E_aux (e, _)) = | E_vector _ -> "E_vector" | E_let _ -> "E_let" -let rec doc_exp ctxt (E_aux (e, (l, annot)) as full_exp) = +let rec doc_exp (ctxt : context) (E_aux (e, (l, annot)) as full_exp) = let env = env_of_tannot annot in match e with | E_id id -> string (string_of_id id) (* TODO replace by a translating via a binding map *) | E_lit l -> doc_lit l | E_app (Id_aux (Id "internal_pick", _), _) -> string "sorry" (* TODO replace by actual implementation of internal_pick *) - | E_internal_plet _ -> string "sorry" (* TODO replace by actual implementation of internal_plet *) + | E_internal_plet (pat, e1, e2) -> + (* doc_exp ctxt e1 ^^ hardline ^^ doc_exp ctxt e2 *) + let e0 = doc_pat ctxt false pat in + let e1_pp = doc_exp ctxt e1 in + let e2' = rebind_cast_pattern_vars pat (typ_of e1) e2 in + let e2_pp = doc_exp ctxt e2' in + (* infix 0 1 middle e1_pp e2_pp *) + let e0_pp = + begin + match pat with + | P_aux (P_typ (_, P_aux (P_wild, _)), _) -> string "" + | _ -> separate space [string "let"; e0; string ":="] ^^ space + end + in + e0_pp ^^ e1_pp ^^ hardline ^^ e2_pp | E_app (f, args) -> let d_id = if Env.is_extern f env "lean" then string (Env.get_extern f env "lean") @@ -273,7 +402,13 @@ let rec doc_exp ctxt (E_aux (e, (l, annot)) as full_exp) = let d_args = List.map (doc_exp ctxt) args in nest 2 (parens (flow (break 1) (d_id :: d_args))) | E_vector vals -> failwith "vector found" - | E_typ (typ, e) -> parens (separate space [doc_exp ctxt e; colon; doc_typ ctxt typ]) + | E_typ (typ, e) -> begin + match e with + | E_aux (E_assign _, _) -> doc_exp ctxt e + | E_aux (E_app (Id_aux (Id "internal_pick", _), _), _) -> + string "return " ^^ nest 7 (parens (separate space [doc_exp ctxt e; colon; doc_typ ctxt typ])) + | _ -> parens (separate space [doc_exp ctxt e; colon; doc_typ ctxt typ]) + end | E_tuple es -> parens (separate_map (comma ^^ space) (doc_exp ctxt) es) | E_let (LB_aux (LB_val (lpat, lexp), _), e) -> let id = @@ -286,14 +421,26 @@ let rec doc_exp ctxt (E_aux (e, (l, annot)) as full_exp) = | E_struct fexps -> let args = List.map (doc_fexp ctxt) fexps in braces (space ^^ nest 2 (separate hardline args) ^^ space) - | E_field (exp, id) -> doc_exp ctxt exp ^^ dot ^^ doc_id_ctor id + | E_field (exp, id) -> doc_exp ctxt exp ^^ dot ^^ doc_id_ctor ctxt id | E_struct_update (exp, fexps) -> let args = List.map (doc_fexp ctxt) fexps in braces (space ^^ doc_exp ctxt exp ^^ string " with " ^^ separate (comma ^^ space) args ^^ space) + | E_assign ((LE_aux (le_act, tannot) as le), e) -> ( + match le_act with + | LE_id id | LE_typ (_, id) -> string "set_" ^^ doc_id ctxt id ^^ space ^^ doc_exp ctxt e + | LE_deref e -> string "sorry /- deref -/" + | _ -> failwith ("assign " ^ string_of_lexp le ^ "not implemented yet") + ) + | E_internal_return e -> nest 2 (string "return" ^^ space ^^ nest 5 (doc_exp ctxt e)) | _ -> failwith ("Expression " ^ string_of_exp_con full_exp ^ " " ^ string_of_exp full_exp ^ " not translatable yet.") -and doc_fexp ctxt (FE_aux (FE_fexp (field, exp), _)) = doc_id_ctor field ^^ string " := " ^^ doc_exp ctxt exp +and doc_fexp ctxt (FE_aux (FE_fexp (field, exp), _)) = doc_id_ctor ctxt field ^^ string " := " ^^ doc_exp ctxt exp +and doc_lexp_deref ctxt (LE_aux (lexp, (l, annot))) = + match lexp with + | LE_id id -> doc_id ctxt id + | LE_typ (typ, id) -> doc_id ctxt id + | _ -> raise (Reporting.err_unreachable l __POS__ "doc_lexp_deref: Unsupported lexp") let doc_binder ctxt i t = let paranthesizer = match t with @@ -305,7 +452,7 @@ let doc_binder ctxt i t = let ctxt = match captured_typ_var (i, t) with Some (i, ki) -> add_single_kid_id_rename ctxt i ki | _ -> ctxt in (ctxt, separate space [string (string_of_id i); colon; doc_typ ctxt t] |> paranthesizer) -let doc_funcl_init (FCL_aux (FCL_funcl (id, pexp), annot)) = +let doc_funcl_init ctxt (FCL_aux (FCL_funcl (id, pexp), annot)) = let env = env_of_tannot (snd annot) in let TypQ_aux (tq, l), typ = Env.get_val_spec_orig id env in let arg_typs, ret_typ, _ = @@ -324,7 +471,6 @@ let doc_funcl_init (FCL_aux (FCL_funcl (id, pexp), annot)) = | _ -> failwith "Argument pattern not translatable yet." ) in - let ctxt = empty_context in let ctxt, binders = List.fold_left (fun (ctxt, bs) (i, t) -> @@ -351,19 +497,19 @@ let doc_funcl_init (FCL_aux (FCL_funcl (id, pexp), annot)) = separate space ([string "def"; string (string_of_id id)] @ binders @ [colon; doc_ret_typ; coloneq]) ) -let doc_funcl_body (FCL_aux (FCL_funcl (id, pexp), annot)) = +let doc_funcl_body ctxt (FCL_aux (FCL_funcl (id, pexp), annot)) = let _, _, exp, _ = destruct_pexp pexp in let is_monadic = effectful (effect_of exp) in if is_monadic then nest 2 (flow (break 1) [string "return"; doc_exp empty_context exp]) else doc_exp empty_context exp -let doc_funcl funcl = - let comment, signature = doc_funcl_init funcl in - comment ^^ nest 2 (signature ^^ hardline ^^ doc_funcl_body funcl) +let doc_funcl ctxt funcl = + let comment, signature = doc_funcl_init ctxt funcl in + comment ^^ nest 2 (signature ^^ hardline ^^ doc_funcl_body ctxt funcl) -let doc_fundef (FD_aux (FD_function (r, typa, fcls), fannot)) = +let doc_fundef ctxt (FD_aux (FD_function (r, typa, fcls), fannot)) = match fcls with | [] -> failwith "FD_function with empty function list" - | [funcl] -> doc_funcl funcl + | [funcl] -> doc_funcl ctxt funcl | _ -> failwith "FD_function with more than one clause" let string_of_type_def_con (TD_aux (td, _)) = @@ -379,7 +525,7 @@ let doc_typdef ctxt (TD_aux (td, tannot) as full_typdef) = match td with | TD_enum (Id_aux (Id id, _), fields, _) -> let derivers = if List.length fields > 0 then [string "Inhabited"] else [] in - let fields = List.map doc_id_ctor fields in + let fields = List.map (doc_id_ctor ctxt) fields in let fields = List.map (fun i -> space ^^ pipe ^^ space ^^ i) fields in let enums_doc = concat fields in nest 2 @@ -401,7 +547,7 @@ let doc_typdef ctxt (TD_aux (td, tannot) as full_typdef) = let doc_def ctxt (DEF_aux (aux, def_annot) as def) = match aux with - | DEF_fundef fdef -> group (doc_fundef fdef) ^/^ hardline + | DEF_fundef fdef -> group (doc_fundef ctxt fdef) ^/^ hardline | DEF_type tdef -> group (doc_typdef ctxt tdef) ^/^ hardline | _ -> empty @@ -413,8 +559,14 @@ let rec remove_imports (defs : (Libsail.Type_check.tannot, Libsail.Type_check.en | DEF_aux (DEF_pragma ("include_end", _, _), _) :: ds -> remove_imports ds (depth - 1) | d :: ds -> if depth > 0 then remove_imports ds depth else d :: remove_imports ds depth -let pp_ast_lean ({ defs; _ } as ast : Libsail.Type_check.typed_ast) o = +let pp_ast_lean ({ defs; _ } as ast : Libsail.Type_check.typed_ast) o effect_info = let defs = remove_imports defs 0 in - let output : document = separate_map empty (doc_def empty_context) defs in - print o output; + let global = { effect_info } in + let ctxt = { empty_context with global } in + let regs = State.find_registers defs in + let register_refs = + match regs with [] -> empty | _ -> State.register_refs_lean (doc_id ctxt) (doc_typ ctxt) regs ^^ hardline + in + let output : document = separate_map empty (doc_def ctxt) defs in + print o (register_refs ^^ output); () diff --git a/src/sail_lean_backend/sail_plugin_lean.ml b/src/sail_lean_backend/sail_plugin_lean.ml index 45240cc0b..91e200d89 100644 --- a/src/sail_lean_backend/sail_plugin_lean.ml +++ b/src/sail_lean_backend/sail_plugin_lean.ml @@ -190,15 +190,15 @@ let create_lake_project (out_name : string) default_sail_dir = output_string project_main "open Sail\n\n"; project_main -let output (out_name : string) ast default_sail_dir = +let output (out_name : string) ast default_sail_dir effect_info = let project_main = create_lake_project out_name default_sail_dir in (* Uncomment for debug output of the Sail code after the rewrite passes *) (* Pretty_print_sail.output_ast stdout (Type_check.strip_ast ast); *) - Pretty_print_lean.pp_ast_lean ast project_main; + Pretty_print_lean.pp_ast_lean ast project_main effect_info; close_out project_main let lean_target out_name { default_sail_dir; ctx; ast; effect_info; env; _ } = let out_name = match out_name with Some f -> f | None -> "out" in - output out_name ast default_sail_dir + output out_name ast default_sail_dir effect_info let _ = Target.register ~name:"lean" ~options:lean_options ~rewrites:lean_rewrites ~asserts_termination:true lean_target diff --git a/test/lean/reg.expected.lean b/test/lean/reg.expected.lean new file mode 100644 index 000000000..00e105040 --- /dev/null +++ b/test/lean/reg.expected.lean @@ -0,0 +1,16 @@ +import Out.Sail.Sail + +open Sail + +class MonadReg where + set_R0: (BitVec 64) -> SailM Unit + get_R0: SailM ((BitVec 64)) + +variable [MonadReg] + +open MonadReg + +def initialize_registers : SailM Unit := + let w__0 := (undefined_bitvector 64) + set_R0 w__0 + diff --git a/test/lean/reg.sail b/test/lean/reg.sail new file mode 100644 index 000000000..0b2a367ab --- /dev/null +++ b/test/lean/reg.sail @@ -0,0 +1,5 @@ +default Order dec + +$include + +register R0 : bits(64) \ No newline at end of file diff --git a/test/lean/struct.expected.lean b/test/lean/struct.expected.lean index 60424c7cc..eedc24d7e 100644 --- a/test/lean/struct.expected.lean +++ b/test/lean/struct.expected.lean @@ -7,7 +7,25 @@ structure My_struct where field2 : (BitVec 1) def undefined_My_struct (lit : Unit) : SailM My_struct := - return sorry + let w__0 := (undefined_int ()) + let w__1 := (undefined_bit ()) + return { field1 := w__0 + field2 := w__1 } + +def struct_field2 (s : My_struct) : (BitVec 1) := + s.field2 + +def struct_update_field2 (s : My_struct) (b : (BitVec 1)) : My_struct := + { s with field2 := b } + +/-- Type quantifiers: i : Int -/ +def struct_update_both_fields (s : My_struct) (i : Int) (b : (BitVec 1)) : My_struct := + { s with field1 := i, field2 := b } + +/-- Type quantifiers: i : Int -/ +def mk_struct (i : Int) (b : (BitVec 1)) : My_struct := + { field1 := i + field2 := b } def struct_field2 (s : My_struct) : (BitVec 1) := s.field2