Skip to content

Commit

Permalink
Working registers and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
lfrenot committed Jan 17, 2025
1 parent f87c292 commit bdd48ad
Show file tree
Hide file tree
Showing 6 changed files with 153 additions and 13 deletions.
26 changes: 26 additions & 0 deletions src/lib/state.ml
Original file line number Diff line number Diff line change
Expand Up @@ -834,6 +834,32 @@ let register_refs_coq doc_id coq_record_update env registers =
in
separate hardline [generic_convs; refs; getters_setters]

let register_refs_lean doc_id doc_typ registers =
let generic_convs = separate_map hardline string [""; "variable [MonadReg]"; ""; "open MonadReg"; ""] in
let register_ref (typ, id, _) =
let idd = doc_id id in
let typp = doc_typ typ in
concat
[
string " set_";
idd;
colon;
space;
typp;
string " -> SailM Unit";
hardline;
string " get_";
idd;
colon;
space;
string "SailM (";
typp;
string ")";
]
in
let refs = separate_map hardline register_ref registers in
separate hardline [string "class MonadReg where"; refs; generic_convs]

let generate_regstate_defs ctx env ast =
let defs = ast.defs in
let registers = find_registers defs in
Expand Down
94 changes: 89 additions & 5 deletions src/sail_lean_backend/pretty_print_lean.ml
Original file line number Diff line number Diff line change
Expand Up @@ -257,14 +257,82 @@ let string_of_exp_con (E_aux (e, _)) =
| E_vector _ -> "E_vector"
| E_let _ -> "E_let"

let string_of_pat_con (P_aux (p, _)) =
match p with
| P_app _ -> "P_app"
| P_wild -> "P_wild"
| P_lit _ -> "P_lit"
| P_or _ -> "P_or"
| P_not _ -> "P_not"
| P_as _ -> "P_as"
| P_typ _ -> "P_typ"
| P_id _ -> "P_id"
| P_var _ -> "P_var"
| P_vector _ -> "P_vector"
| P_vector_concat _ -> "P_vector_concat"
| P_vector_subrange _ -> "P_vector_subrange"
| P_tuple _ -> "P_tuple"
| P_list _ -> "P_list"
| P_cons _ -> "P_cons"
| P_string_append _ -> "P_string_append"
| P_struct _ -> "P_struct"

let rec doc_pat ctxt apat_needed (P_aux (p, (l, annot)) as pat) =
let env = env_of_annot (l, annot) in
let typ = Env.expand_synonyms env (typ_of_annot (l, annot)) in
match p with
| P_typ (ptyp, p) ->
let doc_p = doc_pat ctxt true p in
doc_p
| P_id id -> doc_id_ctor id
| _ -> failwith ("Pattern " ^ string_of_pat_con pat ^ " " ^ string_of_pat pat ^ " not translatable yet.")

(* Copied from the Coq PP *)
let rebind_cast_pattern_vars pat typ exp =
let rec aux pat typ =
match (pat, typ) with
| P_aux (P_typ (target_typ, P_aux (P_id id, (l, ann))), _), source_typ when not (is_enum (env_of exp) id) ->
if Typ.compare target_typ source_typ == 0 then []
else (
let l = Parse_ast.Generated l in
let cast_annot = Type_check.replace_typ source_typ ann in
let e_annot = Type_check.mk_tannot (env_of exp) source_typ in
[LB_aux (LB_val (pat, E_aux (E_id id, (l, e_annot))), (l, ann))]
)
| P_aux (P_tuple pats, _), Typ_aux (Typ_tuple typs, _) -> List.concat (List.map2 aux pats typs)
| _ -> []
in
let add_lb (E_aux (_, ann) as exp) lb = E_aux (E_let (lb, exp), ann) in
(* Don't introduce new bindings at the top-level, we'd just go into a loop. *)
let lbs =
match (pat, typ) with
| P_aux (P_tuple pats, _), Typ_aux (Typ_tuple typs, _) -> List.concat (List.map2 aux pats typs)
| _ -> []
in
List.fold_left add_lb exp lbs

let rec doc_exp ctxt (E_aux (e, (l, annot)) as full_exp) =
let env = env_of_tannot annot in
match e with
| E_id id -> string (string_of_id id) (* TODO replace by a translating via a binding map *)
| E_lit l -> doc_lit l
| E_app (Id_aux (Id "internal_pick", _), _) ->
string "sorry" (* TODO replace by actual implementation of internal_pick *)
| E_internal_plet _ -> string "sorry" (* TODO replace by actual implementation of internal_plet *)
| E_internal_plet (pat, e1, e2) ->
(* doc_exp ctxt e1 ^^ hardline ^^ doc_exp ctxt e2 *)
let e0 = doc_pat ctxt false pat in
let e1_pp = doc_exp ctxt e1 in
let e2' = rebind_cast_pattern_vars pat (typ_of e1) e2 in
let e2_pp = doc_exp ctxt e2' in
(* infix 0 1 middle e1_pp e2_pp *)
let e0_pp =
begin
match pat with
| P_aux (P_typ (_, P_aux (P_wild, _)), _) -> string ""
| _ -> separate space [string "let"; e0; string ":="] ^^ space
end
in
e0_pp ^^ e1_pp ^^ hardline ^^ e2_pp
| E_app (f, args) ->
let d_id =
if Env.is_extern f env "lean" then string (Env.get_extern f env "lean")
Expand All @@ -273,7 +341,13 @@ let rec doc_exp ctxt (E_aux (e, (l, annot)) as full_exp) =
let d_args = List.map (doc_exp ctxt) args in
nest 2 (parens (flow (break 1) (d_id :: d_args)))
| E_vector vals -> failwith "vector found"
| E_typ (typ, e) -> parens (separate space [doc_exp ctxt e; colon; doc_typ ctxt typ])
| E_typ (typ, e) -> (
match e with
| E_aux (E_assign _, _) -> doc_exp ctxt e
| E_aux (E_app (Id_aux (Id "internal_pick", _), _), _) ->
string "return " ^^ nest 7 (parens (separate space [doc_exp ctxt e; colon; doc_typ ctxt typ]))
| _ -> parens (separate space [doc_exp ctxt e; colon; doc_typ ctxt typ])
)
| E_tuple es -> parens (separate_map (comma ^^ space) (doc_exp ctxt) es)
| E_let (LB_aux (LB_val (lpat, lexp), _), e) ->
let id =
Expand All @@ -290,6 +364,13 @@ let rec doc_exp ctxt (E_aux (e, (l, annot)) as full_exp) =
| E_struct_update (exp, fexps) ->
let args = List.map (doc_fexp ctxt) fexps in
braces (space ^^ doc_exp ctxt exp ^^ string " with " ^^ separate (comma ^^ space) args ^^ space)
| E_assign ((LE_aux (le_act, tannot) as le), e) -> (
match le_act with
| LE_id id | LE_typ (_, id) -> string "set_" ^^ doc_id_ctor id ^^ space ^^ doc_exp ctxt e
| LE_deref e -> string "sorry /- deref -/"
| _ -> failwith ("assign " ^ string_of_lexp le ^ "not implemented yet")
)
| E_internal_return e -> nest 2 (string "return" ^^ space ^^ nest 5 (doc_exp ctxt e))
| _ -> failwith ("Expression " ^ string_of_exp_con full_exp ^ " " ^ string_of_exp full_exp ^ " not translatable yet.")

and doc_fexp ctxt (FE_aux (FE_fexp (field, exp), _)) = doc_id_ctor field ^^ string " := " ^^ doc_exp ctxt exp
Expand Down Expand Up @@ -353,8 +434,7 @@ let doc_funcl_init (FCL_aux (FCL_funcl (id, pexp), annot)) =

let doc_funcl_body (FCL_aux (FCL_funcl (id, pexp), annot)) =
let _, _, exp, _ = destruct_pexp pexp in
let is_monadic = effectful (effect_of exp) in
if is_monadic then nest 2 (flow (break 1) [string "return"; doc_exp empty_context exp]) else doc_exp empty_context exp
doc_exp empty_context exp

let doc_funcl funcl =
let comment, signature = doc_funcl_init funcl in
Expand Down Expand Up @@ -415,6 +495,10 @@ let rec remove_imports (defs : (Libsail.Type_check.tannot, Libsail.Type_check.en

let pp_ast_lean ({ defs; _ } as ast : Libsail.Type_check.typed_ast) o =
let defs = remove_imports defs 0 in
let regs = State.find_registers defs in
let register_refs =
match regs with [] -> empty | _ -> State.register_refs_lean doc_id_ctor (doc_typ empty_context) regs ^^ hardline
in
let output : document = separate_map empty (doc_def empty_context) defs in
print o output;
print o (register_refs ^^ output);
()
20 changes: 13 additions & 7 deletions test/lean/bitfield.expected.lean
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ open Sail
def cr_type := (BitVec 8)

def undefined_cr_type (lit : Unit) : SailM (BitVec 8) :=
return ((undefined_bitvector 8) : (BitVec 8))
((undefined_bitvector 8) : (BitVec 8))

def Mk_cr_type (v : (BitVec 8)) : (BitVec 8) :=
v
Expand All @@ -17,7 +17,8 @@ def _update_cr_type_bits (v : (BitVec 8)) (x : (BitVec 8)) : (BitVec 8) :=
(Sail.BitVec.updateSubrange v (HSub.hSub 8 1) 0 x)

def _set_cr_type_bits (r_ref : RegisterRef Unit Unit (BitVec 8)) (v : (BitVec 8)) : SailM Unit :=
return sorry
let r := (reg_deref r_ref)
sorry /- deref -/

def _get_cr_type_CR0 (v : (BitVec 8)) : (BitVec 4) :=
(Sail.BitVec.extractLsb v 7 4)
Expand All @@ -26,7 +27,8 @@ def _update_cr_type_CR0 (v : (BitVec 8)) (x : (BitVec 4)) : (BitVec 8) :=
(Sail.BitVec.updateSubrange v 7 4 x)

def _set_cr_type_CR0 (r_ref : RegisterRef Unit Unit (BitVec 8)) (v : (BitVec 4)) : SailM Unit :=
return sorry
let r := (reg_deref r_ref)
sorry /- deref -/

def _get_cr_type_CR1 (v : (BitVec 8)) : (BitVec 2) :=
(Sail.BitVec.extractLsb v 3 2)
Expand All @@ -35,7 +37,8 @@ def _update_cr_type_CR1 (v : (BitVec 8)) (x : (BitVec 2)) : (BitVec 8) :=
(Sail.BitVec.updateSubrange v 3 2 x)

def _set_cr_type_CR1 (r_ref : RegisterRef Unit Unit (BitVec 8)) (v : (BitVec 2)) : SailM Unit :=
return sorry
let r := (reg_deref r_ref)
sorry /- deref -/

def _get_cr_type_CR3 (v : (BitVec 8)) : (BitVec 2) :=
(Sail.BitVec.extractLsb v 1 0)
Expand All @@ -44,7 +47,8 @@ def _update_cr_type_CR3 (v : (BitVec 8)) (x : (BitVec 2)) : (BitVec 8) :=
(Sail.BitVec.updateSubrange v 1 0 x)

def _set_cr_type_CR3 (r_ref : RegisterRef Unit Unit (BitVec 8)) (v : (BitVec 2)) : SailM Unit :=
return sorry
let r := (reg_deref r_ref)
sorry /- deref -/

def _get_cr_type_GT (v : (BitVec 8)) : (BitVec 1) :=
(Sail.BitVec.extractLsb v 6 6)
Expand All @@ -53,7 +57,8 @@ def _update_cr_type_GT (v : (BitVec 8)) (x : (BitVec 1)) : (BitVec 8) :=
(Sail.BitVec.updateSubrange v 6 6 x)

def _set_cr_type_GT (r_ref : RegisterRef Unit Unit (BitVec 8)) (v : (BitVec 1)) : SailM Unit :=
return sorry
let r := (reg_deref r_ref)
sorry /- deref -/

def _get_cr_type_LT (v : (BitVec 8)) : (BitVec 1) :=
(Sail.BitVec.extractLsb v 7 7)
Expand All @@ -62,7 +67,8 @@ def _update_cr_type_LT (v : (BitVec 8)) (x : (BitVec 1)) : (BitVec 8) :=
(Sail.BitVec.updateSubrange v 7 7 x)

def _set_cr_type_LT (r_ref : RegisterRef Unit Unit (BitVec 8)) (v : (BitVec 1)) : SailM Unit :=
return sorry
let r := (reg_deref r_ref)
sorry /- deref -/

def initialize_registers : Unit :=
()
Expand Down
16 changes: 16 additions & 0 deletions test/lean/reg.expected.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import Out.Sail.Sail

open Sail

class MonadReg where
set_R0: (BitVec 64) -> SailM Unit
get_R0: SailM ((BitVec 64))

variable [MonadReg]

open MonadReg

def initialize_registers : SailM Unit :=
let w__0 := (undefined_bitvector 64)
set_R0 w__0

5 changes: 5 additions & 0 deletions test/lean/reg.sail
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
default Order dec

$include <prelude.sail>

register R0 : bits(64)
5 changes: 4 additions & 1 deletion test/lean/struct.expected.lean
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@ structure My_struct where
field2 : (BitVec 1)

def undefined_My_struct (lit : Unit) : SailM My_struct :=
return sorry
let w__0 := (undefined_int ())
let w__1 := (undefined_bit ())
return { field1 := w__0
field2 := w__1 }

def struct_field2 (s : My_struct) : (BitVec 1) :=
s.field2
Expand Down

0 comments on commit bdd48ad

Please sign in to comment.