Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Lean: add support for vectors and register vectors #911

Merged
merged 8 commits into from
Jan 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions lib/flow.sail
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,11 @@ val and_bool_no_flow = pure {coq: "andb", lean: "Bool.and", _: "and_bool"} : (bo

val or_bool = pure {coq: "orb", lean: "Bool.or", _: "or_bool"} : forall ('p : Bool) ('q : Bool). (bool('p), bool('q)) -> bool('p | 'q)

val eq_int = pure {ocaml: "eq_int", interpreter: "eq_int", lem: "eq", coq: "Z.eqb", _: "eq_int"} : forall 'n 'm. (int('n), int('m)) -> bool('n == 'm)
val eq_int = pure {ocaml: "eq_int", interpreter: "eq_int", lem: "eq", coq: "Z.eqb", lean: "Eq", _: "eq_int"} : forall 'n 'm. (int('n), int('m)) -> bool('n == 'm)

val eq_bool = pure {ocaml: "eq_bool", interpreter: "eq_bool", lem: "eq", coq: "Bool.eqb", lean: "Eq", _: "eq_bool"} : (bool, bool) -> bool

val neq_int = pure {lem: "neq"} : forall 'n 'm. (int('n), int('m)) -> bool('n != 'm)
val neq_int = pure {lem: "neq", lean: "Ne"} : forall 'n 'm. (int('n), int('m)) -> bool('n != 'm)
function neq_int (x, y) = not_bool(eq_int(x, y))

val neq_bool : (bool, bool) -> bool
Expand Down
1 change: 1 addition & 0 deletions lib/vector.sail
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ val plain_vector_access = pure {
interpreter: "access",
lem: "access_list_dec",
coq: "vec_access_dec",
lean: "vectorAccess",
_: "vector_access"
} : forall ('n : Int) ('m : Int) ('a : Type), 0 <= 'm < 'n. (vector('n, dec, 'a), int('m)) -> 'a

Expand Down
2 changes: 2 additions & 0 deletions src/sail_lean_backend/Sail/Sail.lean
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ def writeRegRef (reg_ref : @RegisterRef Register RegisterType α) (a : α) :

def reg_deref (reg_ref : @RegisterRef Register RegisterType α) := readRegRef reg_ref

def vectorAccess [Inhabited α] (v : Vector α m) (n : Nat) := v[n]!

end Regs

namespace BitVec
Expand Down
90 changes: 69 additions & 21 deletions src/sail_lean_backend/pretty_print_lean.ml
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,9 @@ let provably_nneg ctx x = Type_check.prove __POS__ ctx.env (nc_gteq x (nint 0))

let rec doc_typ ctx (Typ_aux (t, _) as typ) =
match t with
| Typ_app (Id_aux (Id "vector", _), [A_aux (A_nexp m, _); A_aux (A_typ elem_typ, _)]) ->
(* TODO: remove duplication with exists, below *)
string "Vector" ^^ space ^^ parens (doc_typ ctx elem_typ) ^^ space ^^ doc_nexp ctx m
| Typ_id (Id_aux (Id "unit", _)) -> string "Unit"
| Typ_id (Id_aux (Id "int", _)) -> string "Int"
| Typ_id (Id_aux (Id "bool", _)) -> string "Bool"
Expand Down Expand Up @@ -343,9 +346,10 @@ let rec doc_exp (as_monadic : bool) ctx (E_aux (e, (l, annot)) as full_exp) =
let field_monadic = effectful (effect_of e) in
doc_fexp field_monadic ctx fexp
in
(* string (" /- " ^ string_of_exp_con full_exp ^ " -/ ") ^^ *)
match e with
| E_id id ->
if Env.is_register id env then string "readReg " ^^ doc_id_ctor id
if Env.is_register id env then wrap_with_left_arrow (not as_monadic) (string "readReg " ^^ doc_id_ctor id)
else wrap_with_pure as_monadic (string (string_of_id id))
| E_lit l -> wrap_with_pure as_monadic (doc_lit l)
| E_app (Id_aux (Id "undefined_int", _), _) (* TODO remove when we handle imports *)
Expand All @@ -360,7 +364,7 @@ let rec doc_exp (as_monadic : bool) ctx (E_aux (e, (l, annot)) as full_exp) =
let e0 = doc_pat ctx false pat in
let e1_pp = doc_exp false ctx e1 in
let e2' = rebind_cast_pattern_vars pat (typ_of e1) e2 in
let e2_pp = doc_exp false ctx e2' in
let e2_pp = doc_exp as_monadic ctx e2' in
let e0_pp =
begin
match pat with
Expand All @@ -377,7 +381,8 @@ let rec doc_exp (as_monadic : bool) ctx (E_aux (e, (l, annot)) as full_exp) =
let d_args = List.map d_of_arg args in
let fn_monadic = not (Effects.function_is_pure f ctx.global.effect_info) in
nest 2 (wrap_with_pure (as_monadic && fn_monadic) (parens (flow (break 1) (d_id :: d_args))))
| E_vector vals -> failwith "vector found"
| E_vector vals ->
string "#v" ^^ wrap_with_pure as_monadic (brackets (nest 2 (flow (comma ^^ break 1) (List.map d_of_arg vals))))
| E_typ (typ, e) ->
if effectful (effect_of e) then
parens (separate space [doc_exp false ctx e; colon; string "SailM"; doc_typ ctx typ])
Expand Down Expand Up @@ -413,6 +418,14 @@ let rec doc_exp (as_monadic : bool) ctx (E_aux (e, (l, annot)) as full_exp) =
| LE_deref e' -> string "writeRegRef " ^^ doc_exp false ctx e' ^^ space ^^ doc_exp false ctx e
| _ -> failwith ("assign " ^ string_of_lexp le ^ "not implemented yet")
)
| E_if (i, t, e) ->
let statements_monadic = as_monadic || effectful (effect_of t) || effectful (effect_of e) in
nest 2 (string "if" ^^ space ^^ nest 1 (doc_exp false ctx i))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this still ignores the monadicity of the condition?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But yes, let's ignore that for now and fix it when we need to fix it.

Copy link
Collaborator Author

@lfrenot lfrenot Jan 28, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It does because the condition should never be monadic (it should be a Bool and not a SailM Bool), instead the responsibility is on doc_exp, see line 352

The responsibility needs to be moved like that, because we want to be able to handle both if (Eq (← (rX r)) 0) and if (← readReg B)

^^ hardline
^^ nest 2 (string "then" ^^ space ^^ nest 3 (doc_exp statements_monadic ctx t))
^^ hardline
^^ nest 2 (string "else" ^^ space ^^ nest 3 (doc_exp statements_monadic ctx e))
| E_ref id -> string "Reg " ^^ doc_id_ctor id
| _ -> failwith ("Expression " ^ string_of_exp_con full_exp ^ " " ^ string_of_exp full_exp ^ " not translatable yet.")

and doc_fexp with_arrow ctx (FE_aux (FE_fexp (field, e), _)) =
Expand Down Expand Up @@ -518,21 +531,47 @@ let doc_typdef ctx (TD_aux (td, tannot) as full_typdef) =
(* TODO don't ignore type quantifiers *)
nest 2 (flow (break 1) [string "structure"; string id; string "where"] ^^ hardline ^^ enums_doc)
| TD_abbrev (Id_aux (Id id, _), tq, A_aux (A_typ t, _)) ->
nest 2 (flow (break 1) [string "def"; string id; coloneq; doc_typ ctx t])
nest 2 (flow (break 1) [string "abbrev"; string id; coloneq; doc_typ ctx t])
| TD_abbrev (Id_aux (Id id, _), tq, A_aux (A_nexp ne, _)) ->
nest 2 (flow (break 1) [string "def"; string id; colon; string "Int"; coloneq; doc_nexp ctx ne])
nest 2 (flow (break 1) [string "abbrev"; string id; colon; string "Int"; coloneq; doc_nexp ctx ne])
| _ -> failwith ("Type definition " ^ string_of_type_def_con full_typdef ^ " not translatable yet.")

let rec doc_defs_aux ctx defs types fundefs =
(* Copied from the Coq PP *)
let doc_val ctx pat exp =
let id, pat_typ =
match pat with
| P_aux (P_typ (typ, P_aux (P_id id, _)), _) -> (id, Some typ)
| P_aux (P_id id, _) -> (id, None)
| P_aux (P_var (P_aux (P_id id, _), TP_aux (TP_var kid, _)), _) when Id.compare id (id_of_kid kid) == 0 -> (id, None)
| P_aux (P_typ (typ, P_aux (P_var (P_aux (P_id id, _), TP_aux (TP_var kid, _)), _)), _)
when Id.compare id (id_of_kid kid) == 0 ->
(id, Some typ)
| P_aux (P_var (P_aux (P_id id, _), TP_aux (TP_app (app_id, [TP_aux (TP_var kid, _)]), _)), _)
when Id.compare app_id (mk_id "atom") == 0 && Id.compare id (id_of_kid kid) == 0 ->
(id, None)
| P_aux
(P_typ (typ, P_aux (P_var (P_aux (P_id id, _), TP_aux (TP_app (app_id, [TP_aux (TP_var kid, _)]), _)), _)), _)
when Id.compare app_id (mk_id "atom") == 0 && Id.compare id (id_of_kid kid) == 0 ->
(id, Some typ)
| _ -> failwith ("Pattern " ^ string_of_pat_con pat ^ " " ^ string_of_pat pat ^ " not translatable yet.")
in
let typpp = match pat_typ with None -> empty | Some typ -> space ^^ colon ^^ space ^^ doc_typ ctx typ in
let idpp = doc_id_ctor id in
let base_pp = doc_exp false ctx exp in
nest 2 (group (string "def" ^^ space ^^ idpp ^^ typpp ^^ space ^^ coloneq ^/^ base_pp))

let rec doc_defs_rec ctx defs types docdefs =
match defs with
| [] -> (types, fundefs)
| [] -> (types, docdefs)
| DEF_aux (DEF_fundef fdef, _) :: defs' ->
doc_defs_aux ctx defs' types (fundefs ^^ group (doc_fundef ctx fdef) ^/^ hardline)
doc_defs_rec ctx defs' types (docdefs ^^ group (doc_fundef ctx fdef) ^/^ hardline)
| DEF_aux (DEF_type tdef, _) :: defs' ->
doc_defs_aux ctx defs' (types ^^ group (doc_typdef ctx tdef) ^/^ hardline) fundefs
| _ :: defs' -> doc_defs_aux ctx defs' types fundefs
doc_defs_rec ctx defs' (types ^^ group (doc_typdef ctx tdef) ^/^ hardline) docdefs
| DEF_aux (DEF_let (LB_aux (LB_val (pat, exp), _)), _) :: defs' ->
doc_defs_rec ctx defs' types (docdefs ^^ group (doc_val ctx pat exp) ^/^ hardline)
| _ :: defs' -> doc_defs_rec ctx defs' types docdefs

let doc_defs ctx defs = doc_defs_aux ctx defs empty empty
let doc_defs ctx defs = doc_defs_rec ctx defs empty empty

(* Remove all imports for now, they will be printed in other files. Probably just for testing. *)
let rec remove_imports (defs : (Libsail.Type_check.tannot, Libsail.Type_check.env) def list) depth =
Expand All @@ -542,15 +581,9 @@ let rec remove_imports (defs : (Libsail.Type_check.tannot, Libsail.Type_check.en
| DEF_aux (DEF_pragma ("include_end", _, _), _) :: ds -> remove_imports ds (depth - 1)
| d :: ds -> if depth > 0 then remove_imports ds depth else d :: remove_imports ds depth

let opt_cons v = function None -> Some [v] | Some t -> Some (v :: t)

let reg_type_name typ_id = prepend_id "register_" typ_id
let reg_case_name typ_id = prepend_id "R_" typ_id
let state_field_name typ_id = append_id typ_id "_s"
let ref_name reg = append_id reg "_ref"
let add_reg_typ env (typ_map, regs_map) (typ, id, has_init) =
let add_reg_typ typ_map (typ, id, _) =
let typ_id = State.id_of_regtyp IdSet.empty typ in
(Bindings.add typ_id typ typ_map, Bindings.update typ_id (opt_cons id) regs_map)
Bindings.add typ_id (id, typ) typ_map

let register_enums registers =
separate hardline
Expand All @@ -572,14 +605,29 @@ let type_enum ctx registers =
empty;
]

let inhabit_enum ctx typ_map =
separate_map hardline
(fun (_, (id, typ)) ->
string "instance : Inhabited (RegisterRef RegisterType "
^^ doc_typ ctx typ ^^ string ") where" ^^ hardline ^^ string " default := .Reg " ^^ doc_id_ctor id
)
typ_map

let doc_reg_info env registers =
let bare_ctx = initial_context env in
let ctx = initial_context env in

let type_map = List.fold_left add_reg_typ Bindings.empty registers in
let type_map = Bindings.bindings type_map in

separate hardline
[
register_enums registers;
type_enum bare_ctx registers;
type_enum ctx registers;
string "abbrev SailM := PreSailM RegisterType";
empty;
string "open RegisterRef";
inhabit_enum ctx type_map;
empty;
empty;
]

Expand Down
2 changes: 1 addition & 1 deletion src/sail_lean_backend/sail_plugin_lean.ml
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ let opt_lean_output_dir : string option ref = ref None

let opt_lean_force_output : bool ref = ref false

let lean_version : string = "lean4:nightly-2024-09-25"
let lean_version : string = "lean4:nightly-2025-01-22"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why was this bumped?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It was bumped because nightly-2024-09-25 used a different implementation of Vectors


let lean_options =
[
Expand Down
6 changes: 5 additions & 1 deletion test/lean/bitfield.expected.lean
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import Out.Sail.Sail

open Sail

def cr_type := (BitVec 8)
abbrev cr_type := (BitVec 8)

inductive Register : Type where
| R
Expand All @@ -14,6 +14,10 @@ abbrev RegisterType : Register → Type

abbrev SailM := PreSailM RegisterType

open RegisterRef
instance : Inhabited (RegisterRef RegisterType (BitVec 8)) where
default := .Reg R

def undefined_cr_type (lit : Unit) : SailM (BitVec 8) := do
sorry

Expand Down
49 changes: 49 additions & 0 deletions test/lean/ite.expected.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import Out.Sail.Sail

open Sail

inductive Register : Type where
| B
| R
deriving DecidableEq, Hashable
open Register

abbrev RegisterType : Register → Type
| .B => Bool
| .R => Nat

abbrev SailM := PreSailM RegisterType

open RegisterRef
instance : Inhabited (RegisterRef RegisterType Bool) where
default := .Reg B
instance : Inhabited (RegisterRef RegisterType Nat) where
default := .Reg R

/-- Type quantifiers: n : Int, 0 ≤ n -/
def elif (n : Nat) : (BitVec 1) :=
if (Eq n 0)
then 1#1
else if (Eq n 1)
then 1#1
else 0#1

/-- Type quantifiers: n : Int, 0 ≤ n -/
def monadic_in_out (n : Nat) : SailM Nat := do
if (← readReg B)
then writeReg R n
else (pure ())
readReg R

/-- Type quantifiers: n : Int, 0 ≤ n -/
def monadic_lines (n : Nat) : SailM Unit := do
let b := (Eq n 0)
if b
then writeReg R n
writeReg B b
else writeReg B b

def initialize_registers : SailM Unit := do
writeReg R sorry
writeReg B sorry

33 changes: 33 additions & 0 deletions test/lean/ite.sail
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
default Order dec

$include <prelude.sail>

register R : nat
register B : bool

function elif(n : nat) -> bit = {
if n == 0 then
bitone
else if n == 1 then
bitone
else
bitzero
}

function monadic_in_out(n : nat) -> nat = {
if B then
R = n
else
();
R
}

function monadic_lines(n : nat) -> unit = {
let b = n == 0;
if b then {
R = n;
B = b
}
else
B = b
}
Loading
Loading