Skip to content

Commit

Permalink
Lean: add support for vectors and register vectors (#911)
Browse files Browse the repository at this point in the history
This also adds a test file for ite, and an additional test in register_vector
  • Loading branch information
lfrenot authored Jan 28, 2025
1 parent f703aac commit dde655a
Show file tree
Hide file tree
Showing 12 changed files with 431 additions and 28 deletions.
4 changes: 2 additions & 2 deletions lib/flow.sail
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,11 @@ val and_bool_no_flow = pure {coq: "andb", lean: "Bool.and", _: "and_bool"} : (bo

val or_bool = pure {coq: "orb", lean: "Bool.or", _: "or_bool"} : forall ('p : Bool) ('q : Bool). (bool('p), bool('q)) -> bool('p | 'q)

val eq_int = pure {ocaml: "eq_int", interpreter: "eq_int", lem: "eq", coq: "Z.eqb", _: "eq_int"} : forall 'n 'm. (int('n), int('m)) -> bool('n == 'm)
val eq_int = pure {ocaml: "eq_int", interpreter: "eq_int", lem: "eq", coq: "Z.eqb", lean: "Eq", _: "eq_int"} : forall 'n 'm. (int('n), int('m)) -> bool('n == 'm)

val eq_bool = pure {ocaml: "eq_bool", interpreter: "eq_bool", lem: "eq", coq: "Bool.eqb", lean: "Eq", _: "eq_bool"} : (bool, bool) -> bool

val neq_int = pure {lem: "neq"} : forall 'n 'm. (int('n), int('m)) -> bool('n != 'm)
val neq_int = pure {lem: "neq", lean: "Ne"} : forall 'n 'm. (int('n), int('m)) -> bool('n != 'm)
function neq_int (x, y) = not_bool(eq_int(x, y))

val neq_bool : (bool, bool) -> bool
Expand Down
1 change: 1 addition & 0 deletions lib/vector.sail
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ val plain_vector_access = pure {
interpreter: "access",
lem: "access_list_dec",
coq: "vec_access_dec",
lean: "vectorAccess",
_: "vector_access"
} : forall ('n : Int) ('m : Int) ('a : Type), 0 <= 'm < 'n. (vector('n, dec, 'a), int('m)) -> 'a

Expand Down
2 changes: 2 additions & 0 deletions src/sail_lean_backend/Sail/Sail.lean
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ def writeRegRef (reg_ref : @RegisterRef Register RegisterType α) (a : α) :

def reg_deref (reg_ref : @RegisterRef Register RegisterType α) := readRegRef reg_ref

def vectorAccess [Inhabited α] (v : Vector α m) (n : Nat) := v[n]!

end Regs

/- TODO: Remove when #911 is merged -/
Expand Down
90 changes: 69 additions & 21 deletions src/sail_lean_backend/pretty_print_lean.ml
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,9 @@ let provably_nneg ctx x = Type_check.prove __POS__ ctx.env (nc_gteq x (nint 0))

let rec doc_typ ctx (Typ_aux (t, _) as typ) =
match t with
| Typ_app (Id_aux (Id "vector", _), [A_aux (A_nexp m, _); A_aux (A_typ elem_typ, _)]) ->
(* TODO: remove duplication with exists, below *)
string "Vector" ^^ space ^^ parens (doc_typ ctx elem_typ) ^^ space ^^ doc_nexp ctx m
| Typ_id (Id_aux (Id "unit", _)) -> string "Unit"
| Typ_id (Id_aux (Id "int", _)) -> string "Int"
| Typ_app (Id_aux (Id "atom_bool", _), _) | Typ_id (Id_aux (Id "bool", _)) -> string "Bool"
Expand Down Expand Up @@ -352,9 +355,10 @@ and doc_exp (as_monadic : bool) ctx (E_aux (e, (l, annot)) as full_exp) =
let field_monadic = effectful (effect_of e) in
doc_fexp field_monadic ctx fexp
in
(* string (" /- " ^ string_of_exp_con full_exp ^ " -/ ") ^^ *)
match e with
| E_id id ->
if Env.is_register id env then string "readReg " ^^ doc_id_ctor id
if Env.is_register id env then wrap_with_left_arrow (not as_monadic) (string "readReg " ^^ doc_id_ctor id)
else wrap_with_pure as_monadic (string (string_of_id id))
| E_lit l -> wrap_with_pure as_monadic (doc_lit l)
| E_app (Id_aux (Id "undefined_int", _), _) (* TODO remove when we handle imports *)
Expand All @@ -369,7 +373,7 @@ and doc_exp (as_monadic : bool) ctx (E_aux (e, (l, annot)) as full_exp) =
let e0 = doc_pat pat in
let e1_pp = doc_exp false ctx e1 in
let e2' = rebind_cast_pattern_vars pat (typ_of e1) e2 in
let e2_pp = doc_exp false ctx e2' in
let e2_pp = doc_exp as_monadic ctx e2' in
let e0_pp =
begin
match pat with
Expand All @@ -386,7 +390,8 @@ and doc_exp (as_monadic : bool) ctx (E_aux (e, (l, annot)) as full_exp) =
let d_args = List.map d_of_arg args in
let fn_monadic = not (Effects.function_is_pure f ctx.global.effect_info) in
nest 2 (wrap_with_pure (as_monadic && fn_monadic) (parens (flow (break 1) (d_id :: d_args))))
| E_vector vals -> failwith "vector found"
| E_vector vals ->
string "#v" ^^ wrap_with_pure as_monadic (brackets (nest 2 (flow (comma ^^ break 1) (List.map d_of_arg vals))))
| E_typ (typ, e) ->
if effectful (effect_of e) then
parens (separate space [doc_exp false ctx e; colon; string "SailM"; doc_typ ctx typ])
Expand Down Expand Up @@ -425,6 +430,14 @@ and doc_exp (as_monadic : bool) ctx (E_aux (e, (l, annot)) as full_exp) =
| LE_deref e' -> string "writeRegRef " ^^ doc_exp false ctx e' ^^ space ^^ doc_exp false ctx e
| _ -> failwith ("assign " ^ string_of_lexp le ^ "not implemented yet")
)
| E_if (i, t, e) ->
let statements_monadic = as_monadic || effectful (effect_of t) || effectful (effect_of e) in
nest 2 (string "if" ^^ space ^^ nest 1 (doc_exp false ctx i))
^^ hardline
^^ nest 2 (string "then" ^^ space ^^ nest 3 (doc_exp statements_monadic ctx t))
^^ hardline
^^ nest 2 (string "else" ^^ space ^^ nest 3 (doc_exp statements_monadic ctx e))
| E_ref id -> string "Reg " ^^ doc_id_ctor id
| _ -> failwith ("Expression " ^ string_of_exp_con full_exp ^ " " ^ string_of_exp full_exp ^ " not translatable yet.")

and doc_fexp with_arrow ctx (FE_aux (FE_fexp (field, e), _)) =
Expand Down Expand Up @@ -538,21 +551,47 @@ let doc_typdef ctx (TD_aux (td, tannot) as full_typdef) =
(* TODO don't ignore type quantifiers *)
nest 2 (flow (break 1) [string "structure"; string id; string "where"] ^^ hardline ^^ enums_doc)
| TD_abbrev (Id_aux (Id id, _), tq, A_aux (A_typ t, _)) ->
nest 2 (flow (break 1) [string "def"; string id; coloneq; doc_typ ctx t])
nest 2 (flow (break 1) [string "abbrev"; string id; coloneq; doc_typ ctx t])
| TD_abbrev (Id_aux (Id id, _), tq, A_aux (A_nexp ne, _)) ->
nest 2 (flow (break 1) [string "def"; string id; colon; string "Int"; coloneq; doc_nexp ctx ne])
nest 2 (flow (break 1) [string "abbrev"; string id; colon; string "Int"; coloneq; doc_nexp ctx ne])
| _ -> failwith ("Type definition " ^ string_of_type_def_con full_typdef ^ " not translatable yet.")

let rec doc_defs_aux ctx defs types fundefs =
(* Copied from the Coq PP *)
let doc_val ctx pat exp =
let id, pat_typ =
match pat with
| P_aux (P_typ (typ, P_aux (P_id id, _)), _) -> (id, Some typ)
| P_aux (P_id id, _) -> (id, None)
| P_aux (P_var (P_aux (P_id id, _), TP_aux (TP_var kid, _)), _) when Id.compare id (id_of_kid kid) == 0 -> (id, None)
| P_aux (P_typ (typ, P_aux (P_var (P_aux (P_id id, _), TP_aux (TP_var kid, _)), _)), _)
when Id.compare id (id_of_kid kid) == 0 ->
(id, Some typ)
| P_aux (P_var (P_aux (P_id id, _), TP_aux (TP_app (app_id, [TP_aux (TP_var kid, _)]), _)), _)
when Id.compare app_id (mk_id "atom") == 0 && Id.compare id (id_of_kid kid) == 0 ->
(id, None)
| P_aux
(P_typ (typ, P_aux (P_var (P_aux (P_id id, _), TP_aux (TP_app (app_id, [TP_aux (TP_var kid, _)]), _)), _)), _)
when Id.compare app_id (mk_id "atom") == 0 && Id.compare id (id_of_kid kid) == 0 ->
(id, Some typ)
| _ -> failwith ("Pattern " ^ string_of_pat_con pat ^ " " ^ string_of_pat pat ^ " not translatable yet.")
in
let typpp = match pat_typ with None -> empty | Some typ -> space ^^ colon ^^ space ^^ doc_typ ctx typ in
let idpp = doc_id_ctor id in
let base_pp = doc_exp false ctx exp in
nest 2 (group (string "def" ^^ space ^^ idpp ^^ typpp ^^ space ^^ coloneq ^/^ base_pp))

let rec doc_defs_rec ctx defs types docdefs =
match defs with
| [] -> (types, fundefs)
| [] -> (types, docdefs)
| DEF_aux (DEF_fundef fdef, _) :: defs' ->
doc_defs_aux ctx defs' types (fundefs ^^ group (doc_fundef ctx fdef) ^/^ hardline)
doc_defs_rec ctx defs' types (docdefs ^^ group (doc_fundef ctx fdef) ^/^ hardline)
| DEF_aux (DEF_type tdef, _) :: defs' ->
doc_defs_aux ctx defs' (types ^^ group (doc_typdef ctx tdef) ^/^ hardline) fundefs
| _ :: defs' -> doc_defs_aux ctx defs' types fundefs
doc_defs_rec ctx defs' (types ^^ group (doc_typdef ctx tdef) ^/^ hardline) docdefs
| DEF_aux (DEF_let (LB_aux (LB_val (pat, exp), _)), _) :: defs' ->
doc_defs_rec ctx defs' types (docdefs ^^ group (doc_val ctx pat exp) ^/^ hardline)
| _ :: defs' -> doc_defs_rec ctx defs' types docdefs

let doc_defs ctx defs = doc_defs_aux ctx defs empty empty
let doc_defs ctx defs = doc_defs_rec ctx defs empty empty

(* Remove all imports for now, they will be printed in other files. Probably just for testing. *)
let rec remove_imports (defs : (Libsail.Type_check.tannot, Libsail.Type_check.env) def list) depth =
Expand All @@ -562,15 +601,9 @@ let rec remove_imports (defs : (Libsail.Type_check.tannot, Libsail.Type_check.en
| DEF_aux (DEF_pragma ("include_end", _, _), _) :: ds -> remove_imports ds (depth - 1)
| d :: ds -> if depth > 0 then remove_imports ds depth else d :: remove_imports ds depth

let opt_cons v = function None -> Some [v] | Some t -> Some (v :: t)

let reg_type_name typ_id = prepend_id "register_" typ_id
let reg_case_name typ_id = prepend_id "R_" typ_id
let state_field_name typ_id = append_id typ_id "_s"
let ref_name reg = append_id reg "_ref"
let add_reg_typ env (typ_map, regs_map) (typ, id, has_init) =
let add_reg_typ typ_map (typ, id, _) =
let typ_id = State.id_of_regtyp IdSet.empty typ in
(Bindings.add typ_id typ typ_map, Bindings.update typ_id (opt_cons id) regs_map)
Bindings.add typ_id (id, typ) typ_map

let register_enums registers =
separate hardline
Expand All @@ -592,14 +625,29 @@ let type_enum ctx registers =
empty;
]

let inhabit_enum ctx typ_map =
separate_map hardline
(fun (_, (id, typ)) ->
string "instance : Inhabited (RegisterRef RegisterType "
^^ doc_typ ctx typ ^^ string ") where" ^^ hardline ^^ string " default := .Reg " ^^ doc_id_ctor id
)
typ_map

let doc_reg_info env registers =
let bare_ctx = initial_context env in
let ctx = initial_context env in

let type_map = List.fold_left add_reg_typ Bindings.empty registers in
let type_map = Bindings.bindings type_map in

separate hardline
[
register_enums registers;
type_enum bare_ctx registers;
type_enum ctx registers;
string "abbrev SailM := PreSailM RegisterType";
empty;
string "open RegisterRef";
inhabit_enum ctx type_map;
empty;
empty;
]

Expand Down
2 changes: 1 addition & 1 deletion src/sail_lean_backend/sail_plugin_lean.ml
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ let opt_lean_output_dir : string option ref = ref None

let opt_lean_force_output : bool ref = ref false

let lean_version : string = "lean4:nightly-2024-09-25"
let lean_version : string = "lean4:nightly-2025-01-22"

let lean_options =
[
Expand Down
6 changes: 5 additions & 1 deletion test/lean/bitfield.expected.lean
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import Out.Sail.Sail

open Sail

def cr_type := (BitVec 8)
abbrev cr_type := (BitVec 8)

inductive Register : Type where
| R
Expand All @@ -14,6 +14,10 @@ abbrev RegisterType : Register → Type

abbrev SailM := PreSailM RegisterType

open RegisterRef
instance : Inhabited (RegisterRef RegisterType (BitVec 8)) where
default := .Reg R

def undefined_cr_type (lit : Unit) : SailM (BitVec 8) := do
sorry

Expand Down
49 changes: 49 additions & 0 deletions test/lean/ite.expected.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import Out.Sail.Sail

open Sail

inductive Register : Type where
| B
| R
deriving DecidableEq, Hashable
open Register

abbrev RegisterType : Register → Type
| .B => Bool
| .R => Nat

abbrev SailM := PreSailM RegisterType

open RegisterRef
instance : Inhabited (RegisterRef RegisterType Bool) where
default := .Reg B
instance : Inhabited (RegisterRef RegisterType Nat) where
default := .Reg R

/-- Type quantifiers: n : Int, 0 ≤ n -/
def elif (n : Nat) : (BitVec 1) :=
if (Eq n 0)
then 1#1
else if (Eq n 1)
then 1#1
else 0#1

/-- Type quantifiers: n : Int, 0 ≤ n -/
def monadic_in_out (n : Nat) : SailM Nat := do
if (← readReg B)
then writeReg R n
else (pure ())
readReg R

/-- Type quantifiers: n : Int, 0 ≤ n -/
def monadic_lines (n : Nat) : SailM Unit := do
let b := (Eq n 0)
if b
then writeReg R n
writeReg B b
else writeReg B b

def initialize_registers : SailM Unit := do
writeReg R sorry
writeReg B sorry

33 changes: 33 additions & 0 deletions test/lean/ite.sail
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
default Order dec

$include <prelude.sail>

register R : nat
register B : bool

function elif(n : nat) -> bit = {
if n == 0 then
bitone
else if n == 1 then
bitone
else
bitzero
}

function monadic_in_out(n : nat) -> nat = {
if B then
R = n
else
();
R
}

function monadic_lines(n : nat) -> unit = {
let b = n == 0;
if b then {
R = n;
B = b
}
else
B = b
}
Loading

0 comments on commit dde655a

Please sign in to comment.