Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Lean: add support for vectors and register vectors #911

Merged
merged 8 commits into from
Jan 28, 2025
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lib/flow.sail
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ val eq_int = pure {ocaml: "eq_int", interpreter: "eq_int", lem: "eq", coq: "Z.eq

val eq_bool = pure {ocaml: "eq_bool", interpreter: "eq_bool", lem: "eq", coq: "Bool.eqb", lean: "Eq", _: "eq_bool"} : (bool, bool) -> bool

val neq_int = pure {lem: "neq"} : forall 'n 'm. (int('n), int('m)) -> bool('n != 'm)
val neq_int = pure {lem: "neq", lean: "Ne"} : forall 'n 'm. (int('n), int('m)) -> bool('n != 'm)
function neq_int (x, y) = not_bool(eq_int(x, y))

val neq_bool : (bool, bool) -> bool
Expand Down
1 change: 1 addition & 0 deletions lib/vector.sail
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ val plain_vector_access = pure {
interpreter: "access",
lem: "access_list_dec",
coq: "vec_access_dec",
lean: "vectorAccess",
_: "vector_access"
} : forall ('n : Int) ('m : Int) ('a : Type), 0 <= 'm < 'n. (vector('n, dec, 'a), int('m)) -> 'a

Expand Down
2 changes: 2 additions & 0 deletions src/sail_lean_backend/Sail/Sail.lean
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ def writeRegRef (reg_ref : @RegisterRef Register RegisterType α) (a : α) :

def reg_deref (reg_ref : @RegisterRef Register RegisterType α) := readRegRef reg_ref

def vectorAccess [Inhabited α] (v : Vector α m) (n : Nat) := v[n]!

end Regs

namespace BitVec
Expand Down
91 changes: 74 additions & 17 deletions src/sail_lean_backend/pretty_print_lean.ml
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,9 @@ let provably_nneg ctx x = Type_check.prove __POS__ ctx.env (nc_gteq x (nint 0))

let rec doc_typ ctx (Typ_aux (t, _) as typ) =
match t with
| Typ_app (Id_aux (Id "vector", _), [A_aux (A_nexp m, _); A_aux (A_typ elem_typ, _)]) ->
(* TODO: remove duplication with exists, below *)
string "Vector" ^^ space ^^ parens (doc_typ ctx elem_typ) ^^ space ^^ doc_nexp ctx m
| Typ_id (Id_aux (Id "unit", _)) -> string "Unit"
| Typ_id (Id_aux (Id "int", _)) -> string "Int"
| Typ_id (Id_aux (Id "bool", _)) -> string "Bool"
Expand Down Expand Up @@ -377,7 +380,8 @@ let rec doc_exp (as_monadic : bool) ctx (E_aux (e, (l, annot)) as full_exp) =
let d_args = List.map d_of_arg args in
let fn_monadic = not (Effects.function_is_pure f ctx.global.effect_info) in
nest 2 (wrap_with_pure (as_monadic && fn_monadic) (parens (flow (break 1) (d_id :: d_args))))
| E_vector vals -> failwith "vector found"
| E_vector vals ->
string "#v" ^^ wrap_with_pure as_monadic (brackets (nest 2 (flow (comma ^^ break 1) (List.map d_of_arg vals))))
| E_typ (typ, e) ->
if effectful (effect_of e) then
parens (separate space [doc_exp false ctx e; colon; string "SailM"; doc_typ ctx typ])
Expand Down Expand Up @@ -413,11 +417,30 @@ let rec doc_exp (as_monadic : bool) ctx (E_aux (e, (l, annot)) as full_exp) =
| LE_deref e' -> string "writeRegRef " ^^ doc_exp false ctx e' ^^ space ^^ doc_exp false ctx e
| _ -> failwith ("assign " ^ string_of_lexp le ^ "not implemented yet")
)
| E_if (c, t, e) -> if_exp ctx (env_of full_exp) (typ_of full_exp) false as_monadic c t e
| E_ref id -> string "Reg " ^^ doc_id_ctor id
| _ -> failwith ("Expression " ^ string_of_exp_con full_exp ^ " " ^ string_of_exp full_exp ^ " not translatable yet.")

and doc_fexp with_arrow ctx (FE_aux (FE_fexp (field, e), _)) =
doc_id_ctor field ^^ string " := " ^^ wrap_with_left_arrow with_arrow (doc_exp false ctx e)

and if_exp (ctxt : context) (full_env : env) (full_typ : typ) (elseif : bool) (as_monadic : bool) c t e =
let if_pp = string (if elseif then "else if" else "if") in
let c_pp = doc_exp as_monadic ctxt c in
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think that's right. You need to check whether c is monadic and use that as an argument to both the call of doc_exp and then add a call to wrap_with_left_arrow to pull the monadic value out in this case.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is what I have locally as a minimal implementation, not thoroughly tested though:

| E_if (i, t, e) ->
    let condition_monadic = effectful (effect_of i) in
    nest 2 (string "if" ^^ space ^^ wrap_with_left_arrow condition_monadic (doc_exp condition_monadic ctx i)) ^^ hardline ^^
    nest 2 (string "then" ^^ space ^^ doc_exp as_monadic ctx t) ^^ hardline ^^
    nest 2 (string "else" ^^ space ^^ doc_exp as_monadic ctx e)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes you're correct, I've changed it to something closer to your implementation

let t_pp = doc_exp as_monadic ctxt t in
let else_pp =
match e with
| E_aux (E_if (c', t', e'), _) | E_aux (E_typ (_, E_aux (E_if (c', t', e'), _)), _) ->
if_exp ctxt full_env full_typ true as_monadic c' t' e'
(* Special case to prevent current arm decoder becoming a staircase *)
(* TODO: replace with smarter pretty printing *)
| E_aux (E_internal_plet (pat, exp1, E_aux (E_typ (typ, (E_aux (E_if (_, _, _), _) as exp2)), _)), ann)
when Typ.compare typ unit_typ == 0 ->
string "else" ^/^ doc_exp as_monadic ctxt (E_aux (E_internal_plet (pat, exp1, exp2), ann))
| _ -> prefix 2 1 (string "else") (doc_exp as_monadic ctxt e)
in
prefix 2 1 (soft_surround 2 1 if_pp c_pp (string "then")) t_pp ^^ break 1 ^^ else_pp
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this copied over from the Coq printer? If so maybe mark it as such so we can improve it later.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes it was, it's not here anymore, but I have marked the other part that was copied from the Coq printer


let doc_binder ctx i t =
let paranthesizer =
match t with
Expand Down Expand Up @@ -523,16 +546,41 @@ let doc_typdef ctx (TD_aux (td, tannot) as full_typdef) =
nest 2 (flow (break 1) [string "def"; string id; colon; string "Int"; coloneq; doc_nexp ctx ne])
| _ -> failwith ("Type definition " ^ string_of_type_def_con full_typdef ^ " not translatable yet.")

let rec doc_defs_aux ctx defs types fundefs =
let doc_val ctx pat exp =
let id, pat_typ =
match pat with
| P_aux (P_typ (typ, P_aux (P_id id, _)), _) -> (id, Some typ)
| P_aux (P_id id, _) -> (id, None)
| P_aux (P_var (P_aux (P_id id, _), TP_aux (TP_var kid, _)), _) when Id.compare id (id_of_kid kid) == 0 -> (id, None)
| P_aux (P_typ (typ, P_aux (P_var (P_aux (P_id id, _), TP_aux (TP_var kid, _)), _)), _)
when Id.compare id (id_of_kid kid) == 0 ->
(id, Some typ)
| P_aux (P_var (P_aux (P_id id, _), TP_aux (TP_app (app_id, [TP_aux (TP_var kid, _)]), _)), _)
when Id.compare app_id (mk_id "atom") == 0 && Id.compare id (id_of_kid kid) == 0 ->
(id, None)
| P_aux
(P_typ (typ, P_aux (P_var (P_aux (P_id id, _), TP_aux (TP_app (app_id, [TP_aux (TP_var kid, _)]), _)), _)), _)
when Id.compare app_id (mk_id "atom") == 0 && Id.compare id (id_of_kid kid) == 0 ->
(id, Some typ)
| _ -> failwith ("Pattern " ^ string_of_pat_con pat ^ " " ^ string_of_pat pat ^ " not translatable yet.")
in
let typpp = match pat_typ with None -> empty | Some typ -> space ^^ colon ^^ space ^^ doc_typ ctx typ in
let idpp = doc_id_ctor id in
let base_pp = doc_exp false ctx exp in
group (string "def" ^^ space ^^ idpp ^^ typpp ^^ space ^^ coloneq ^/^ base_pp)

let rec doc_defs_rec ctx defs types docdefs =
match defs with
| [] -> (types, fundefs)
| [] -> (types, docdefs)
| DEF_aux (DEF_fundef fdef, _) :: defs' ->
doc_defs_aux ctx defs' types (fundefs ^^ group (doc_fundef ctx fdef) ^/^ hardline)
doc_defs_rec ctx defs' types (docdefs ^^ group (doc_fundef ctx fdef) ^/^ hardline)
| DEF_aux (DEF_type tdef, _) :: defs' ->
doc_defs_aux ctx defs' (types ^^ group (doc_typdef ctx tdef) ^/^ hardline) fundefs
| _ :: defs' -> doc_defs_aux ctx defs' types fundefs
doc_defs_rec ctx defs' (types ^^ group (doc_typdef ctx tdef) ^/^ hardline) docdefs
| DEF_aux (DEF_let (LB_aux (LB_val (pat, exp), _)), _) :: defs' ->
doc_defs_rec ctx defs' types (docdefs ^^ group (doc_val ctx pat exp) ^/^ hardline)
| _ :: defs' -> doc_defs_rec ctx defs' types docdefs

let doc_defs ctx defs = doc_defs_aux ctx defs empty empty
let doc_defs ctx defs = doc_defs_rec ctx defs empty empty

(* Remove all imports for now, they will be printed in other files. Probably just for testing. *)
let rec remove_imports (defs : (Libsail.Type_check.tannot, Libsail.Type_check.env) def list) depth =
Expand All @@ -542,15 +590,9 @@ let rec remove_imports (defs : (Libsail.Type_check.tannot, Libsail.Type_check.en
| DEF_aux (DEF_pragma ("include_end", _, _), _) :: ds -> remove_imports ds (depth - 1)
| d :: ds -> if depth > 0 then remove_imports ds depth else d :: remove_imports ds depth

let opt_cons v = function None -> Some [v] | Some t -> Some (v :: t)

let reg_type_name typ_id = prepend_id "register_" typ_id
let reg_case_name typ_id = prepend_id "R_" typ_id
let state_field_name typ_id = append_id typ_id "_s"
let ref_name reg = append_id reg "_ref"
let add_reg_typ env (typ_map, regs_map) (typ, id, has_init) =
let add_reg_typ typ_map (typ, id, _) =
let typ_id = State.id_of_regtyp IdSet.empty typ in
(Bindings.add typ_id typ typ_map, Bindings.update typ_id (opt_cons id) regs_map)
Bindings.add typ_id (id, typ) typ_map

let register_enums registers =
separate hardline
Expand All @@ -572,14 +614,29 @@ let type_enum ctx registers =
empty;
]

let inhabit_enum ctx typ_map =
separate_map hardline
(fun (_, (id, typ)) ->
string "instance : Inhabited (RegisterRef RegisterType "
^^ doc_typ ctx typ ^^ string ") where" ^^ hardline ^^ string " default := .Reg " ^^ doc_id_ctor id
)
typ_map

let doc_reg_info env registers =
let bare_ctx = initial_context env in
let ctx = initial_context env in

let type_map = List.fold_left add_reg_typ Bindings.empty registers in
let type_map = Bindings.bindings type_map in

separate hardline
[
register_enums registers;
type_enum bare_ctx registers;
type_enum ctx registers;
string "abbrev SailM := PreSailM RegisterType";
empty;
string "open RegisterRef";
inhabit_enum ctx type_map;
empty;
empty;
]

Expand Down
2 changes: 1 addition & 1 deletion src/sail_lean_backend/sail_plugin_lean.ml
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ let opt_lean_output_dir : string option ref = ref None

let opt_lean_force_output : bool ref = ref false

let lean_version : string = "lean4:nightly-2024-09-25"
let lean_version : string = "lean4:nightly-2025-01-22"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why was this bumped?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It was bumped because nightly-2024-09-25 used a different implementation of Vectors


let lean_options =
[
Expand Down
4 changes: 4 additions & 0 deletions test/lean/bitfield.expected.lean
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ abbrev RegisterType : Register → Type

abbrev SailM := PreSailM RegisterType

open RegisterRef
instance : Inhabited (RegisterRef RegisterType (BitVec 8)) where
default := .Reg R

def undefined_cr_type (lit : Unit) : SailM (BitVec 8) := do
sorry

Expand Down
137 changes: 137 additions & 0 deletions test/lean/register_vector.expected.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
import Out.Sail.Sail

open Sail

def reg_index := Nat
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd rather have abbrev RegIndex := Nat here. The def instead of abbrev will make the new type have no usable instances on it.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a bit out of the scope of this PR, but I have changed it


inductive Register : Type where
| R0
| R1
| R2
| R3
| R4
| R5
| R6
| R7
| R8
| R9
| R10
| R11
| R12
| R13
| R14
| R15
| R16
| R17
| R18
| R19
| R20
| R21
| R22
| R23
| R24
| R25
| R26
| R27
| R28
| R29
| R30
| _PC
deriving DecidableEq, Hashable
open Register

abbrev RegisterType : Register → Type
| .R0 => (BitVec 64)
| .R1 => (BitVec 64)
| .R2 => (BitVec 64)
| .R3 => (BitVec 64)
| .R4 => (BitVec 64)
| .R5 => (BitVec 64)
| .R6 => (BitVec 64)
| .R7 => (BitVec 64)
| .R8 => (BitVec 64)
| .R9 => (BitVec 64)
| .R10 => (BitVec 64)
| .R11 => (BitVec 64)
| .R12 => (BitVec 64)
| .R13 => (BitVec 64)
| .R14 => (BitVec 64)
| .R15 => (BitVec 64)
| .R16 => (BitVec 64)
| .R17 => (BitVec 64)
| .R18 => (BitVec 64)
| .R19 => (BitVec 64)
| .R20 => (BitVec 64)
| .R21 => (BitVec 64)
| .R22 => (BitVec 64)
| .R23 => (BitVec 64)
| .R24 => (BitVec 64)
| .R25 => (BitVec 64)
| .R26 => (BitVec 64)
| .R27 => (BitVec 64)
| .R28 => (BitVec 64)
| .R29 => (BitVec 64)
| .R30 => (BitVec 64)
| ._PC => (BitVec 64)

abbrev SailM := PreSailM RegisterType

open RegisterRef
instance : Inhabited (RegisterRef RegisterType (BitVec 64)) where
default := .Reg _PC

def GPRs : Vector (RegisterRef RegisterType (BitVec 64)) 31 :=
#v[Reg R30, Reg R29, Reg R28, Reg R27, Reg R26, Reg R25, Reg R24, Reg R23, Reg R22, Reg R21, Reg R20,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indentation

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed

Reg R19, Reg R18, Reg R17, Reg R16, Reg R15, Reg R14, Reg R13, Reg R12, Reg R11, Reg R10, Reg R9,
Reg R8, Reg R7, Reg R6, Reg R5, Reg R4, Reg R3, Reg R2, Reg R1, Reg R0]

/-- Type quantifiers: n : Int, 0 ≤ n ∧ n ≤ 31 -/
def wX (n : Nat) (value : (BitVec 64)) : SailM Unit := do
if (Ne n 31) then writeRegRef (vectorAccess GPRs n) value
else (pure ())

/-- Type quantifiers: n : Int, 0 ≤ n ∧ n ≤ 31 -/
def rX (n : Nat) : SailM (BitVec 64) := do
if (Ne n 31) then (reg_deref (vectorAccess GPRs n))
else (pure (0x0000000000000000 : (BitVec 64)))

def rPC : SailM (BitVec 64) := do
readReg _PC

def wPC (pc : (BitVec 64)) : SailM Unit := do
writeReg _PC pc

def initialize_registers : SailM Unit := do
writeReg _PC sorry
writeReg R30 sorry
writeReg R29 sorry
writeReg R28 sorry
writeReg R27 sorry
writeReg R26 sorry
writeReg R25 sorry
writeReg R24 sorry
writeReg R23 sorry
writeReg R22 sorry
writeReg R21 sorry
writeReg R20 sorry
writeReg R19 sorry
writeReg R18 sorry
writeReg R17 sorry
writeReg R16 sorry
writeReg R15 sorry
writeReg R14 sorry
writeReg R13 sorry
writeReg R12 sorry
writeReg R11 sorry
writeReg R10 sorry
writeReg R9 sorry
writeReg R8 sorry
writeReg R7 sorry
writeReg R6 sorry
writeReg R5 sorry
writeReg R4 sorry
writeReg R3 sorry
writeReg R2 sorry
writeReg R1 sorry
writeReg R0 sorry

Loading
Loading