Skip to content

Commit

Permalink
Lean: add placeholder for state monad to stateful functions (#870)
Browse files Browse the repository at this point in the history
  • Loading branch information
javra authored Jan 16, 2025
1 parent bf921f9 commit a1b339a
Show file tree
Hide file tree
Showing 13 changed files with 74 additions and 42 deletions.
4 changes: 4 additions & 0 deletions src/sail_lean_backend/Sail/Sail.lean
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
namespace Sail

/- Placeholder for a future implementation of the state monad some Sail functions use. -/
abbrev SailM := StateM Unit

namespace BitVec

def length {w : Nat} (_ : BitVec w) : Nat := w
Expand Down
20 changes: 14 additions & 6 deletions src/sail_lean_backend/pretty_print_lean.ml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ let add_single_kid_id_rename ctxt id kid =
}

let implicit_parens x = enclose (string "{") (string "}") x

let doc_id_ctor (Id_aux (i, _)) =
match i with Id i -> string i | Operator x -> string (Util.zencode_string ("op " ^ x))

Expand Down Expand Up @@ -133,7 +134,7 @@ let rec doc_typ ctxt (Typ_aux (t, _) as typ) =
| Typ_id (Id_aux (Id "nat", _)) -> string "Nat"
| Typ_app (Id_aux (Id "bitvector", _), [A_aux (A_nexp m, _)]) | Typ_app (Id_aux (Id "bits", _), [A_aux (A_nexp m, _)])
->
string "BitVec " ^^ doc_nexp ctxt m
parens (string "BitVec " ^^ doc_nexp ctxt m)
| Typ_app (Id_aux (Id "atom", _), [A_aux (A_nexp (Nexp_aux (Nexp_var ki, _)), _)]) ->
string "Int" (* TODO This probably has to be generalized *)
| Typ_app (Id_aux (Id "implicit", _), [A_aux (A_nexp (Nexp_aux (Nexp_var ki, _)), _)]) ->
Expand All @@ -149,15 +150,15 @@ let rec captured_typ_var ((i, Typ_aux (t, _)) as typ) =
Some (i, ki)
| _ -> None

let doc_typ_id ctxt (typ, fid) = concat [doc_id_ctor fid; space; colon; space; doc_typ ctxt typ; hardline]
let doc_typ_id ctxt (typ, fid) = flow (break 1) [doc_id_ctor fid; colon; doc_typ ctxt typ]

let doc_kind (K_aux (k, _)) =
match k with
| K_int -> string "Int"
| K_bool -> string "Bool"
| _ -> failwith ("Kind " ^ string_of_kind_aux k ^ " not translatable yet.")

let doc_typ_arg ctxt ta = string "foo"
let doc_typ_arg ctxt ta = string "foo" (* TODO implement *)

let rec doc_nconstraint ctxt (NC_aux (nc, _)) =
match nc with
Expand Down Expand Up @@ -294,7 +295,7 @@ let doc_funcl_init (FCL_aux (FCL_funcl (id, pexp), annot)) =
| Typ_aux (Typ_fn (arg_typs, ret_typ), _) -> (arg_typs, ret_typ, no_effect)
| _ -> failwith ("Function " ^ string_of_id id ^ " does not have function type")
in
let pat, _, _, _ = destruct_pexp pexp in
let pat, _, exp, _ = destruct_pexp pexp in
let pats, _ = untuple_args_pat arg_typs pat in
let binders : (id * typ) list =
pats
Expand Down Expand Up @@ -322,13 +323,20 @@ let doc_funcl_init (FCL_aux (FCL_funcl (id, pexp), annot)) =
in
(* Use auto-implicits for type quanitifiers for now and see if this works *)
let doc_ret_typ = doc_typ ctxt ret_typ in
let is_monadic = effectful (effect_of exp) in
(* Add monad for stateful functions *)
let doc_ret_typ = if is_monadic then string "SailM " ^^ doc_ret_typ else doc_ret_typ in
let decl_val = [doc_ret_typ; coloneq] in
(* Add do block for stateful functions *)
let decl_val = if is_monadic then decl_val @ [string "do"] else decl_val in
( typ_quant_comment,
separate space ([string "def"; string (string_of_id id)] @ binders @ [colon; doc_ret_typ; coloneq])
)

let doc_funcl_body (FCL_aux (FCL_funcl (id, pexp), annot)) =
let _, _, exp, _ = destruct_pexp pexp in
doc_exp empty_context exp
let is_monadic = effectful (effect_of exp) in
if is_monadic then nest 2 (flow (break 1) [string "return"; doc_exp empty_context exp]) else doc_exp empty_context exp

let doc_funcl funcl =
let comment, signature = doc_funcl_init funcl in
Expand Down Expand Up @@ -363,7 +371,7 @@ let doc_typdef ctxt (TD_aux (td, tannot) as full_typdef) =
)
| TD_record (Id_aux (Id id, _), TypQ_aux (tq, _), fields, _) ->
let fields = List.map (doc_typ_id ctxt) fields in
let enums_doc = concat fields in
let enums_doc = separate hardline fields in
let rectyp = doc_typ_quant ctxt tq in
(* TODO don't ignore type quantifiers *)
nest 2 (flow (break 1) [string "structure"; string id; string "where"] ^^ hardline ^^ enums_doc)
Expand Down
1 change: 1 addition & 0 deletions src/sail_lean_backend/sail_plugin_lean.ml
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ let create_lake_project (out_name : string) default_sail_dir =
in
let project_main = open_out (Filename.concat project_dir (out_name_camel ^ ".lean")) in
output_string project_main ("import " ^ out_name_camel ^ ".Sail.Sail\n\n");
output_string project_main "open Sail\n\n";
project_main

let output (out_name : string) ast default_sail_dir =
Expand Down
34 changes: 18 additions & 16 deletions test/lean/bitvec_operation.expected.lean
Original file line number Diff line number Diff line change
@@ -1,51 +1,53 @@
import Out.Sail.Sail

def bitvector_eq (x : BitVec 16) (y : BitVec 16) : Bool :=
open Sail

def bitvector_eq (x : (BitVec 16)) (y : (BitVec 16)) : Bool :=
(Eq x y)

def bitvector_neq (x : BitVec 16) (y : BitVec 16) : Bool :=
def bitvector_neq (x : (BitVec 16)) (y : (BitVec 16)) : Bool :=
(Ne x y)

def bitvector_len (x : BitVec 16) : Nat :=
def bitvector_len (x : (BitVec 16)) : Nat :=
(Sail.BitVec.length x)

def bitvector_sign_extend (x : BitVec 16) : BitVec 32 :=
def bitvector_sign_extend (x : (BitVec 16)) : (BitVec 32) :=
(Sail.BitVec.signExtend x 32)

def bitvector_zero_extend (x : BitVec 16) : BitVec 32 :=
def bitvector_zero_extend (x : (BitVec 16)) : (BitVec 32) :=
(Sail.BitVec.zeroExtend x 32)

def bitvector_truncate (x : BitVec 32) : BitVec 16 :=
def bitvector_truncate (x : (BitVec 32)) : (BitVec 16) :=
(Sail.BitVec.truncate x 16)

def bitvector_truncateLSB (x : BitVec 32) : BitVec 16 :=
def bitvector_truncateLSB (x : (BitVec 32)) : (BitVec 16) :=
(Sail.BitVec.truncateLsb x 16)

def bitvector_append (x : BitVec 16) (y : BitVec 16) : BitVec 32 :=
def bitvector_append (x : (BitVec 16)) (y : (BitVec 16)) : (BitVec 32) :=
(BitVec.append x y)

def bitvector_add (x : BitVec 16) (y : BitVec 16) : BitVec 16 :=
def bitvector_add (x : (BitVec 16)) (y : (BitVec 16)) : (BitVec 16) :=
(HAdd.hAdd x y)

def bitvector_sub (x : BitVec 16) (y : BitVec 16) : BitVec 16 :=
def bitvector_sub (x : (BitVec 16)) (y : (BitVec 16)) : (BitVec 16) :=
(HSub.hSub x y)

def bitvector_not (x : BitVec 16) : BitVec 16 :=
def bitvector_not (x : (BitVec 16)) : (BitVec 16) :=
(Complement.complement x)

def bitvector_and (x : BitVec 16) (y : BitVec 16) : BitVec 16 :=
def bitvector_and (x : (BitVec 16)) (y : (BitVec 16)) : (BitVec 16) :=
(HAnd.hAnd x y)

def bitvector_or (x : BitVec 16) (y : BitVec 16) : BitVec 16 :=
def bitvector_or (x : (BitVec 16)) (y : (BitVec 16)) : (BitVec 16) :=
(HOr.hOr x y)

def bitvector_xor (x : BitVec 16) (y : BitVec 16) : BitVec 16 :=
def bitvector_xor (x : (BitVec 16)) (y : (BitVec 16)) : (BitVec 16) :=
(HXor.hXor x y)

def bitvector_unsigned (x : BitVec 16) : Nat :=
def bitvector_unsigned (x : (BitVec 16)) : Nat :=
(BitVec.toNat x)

def bitvector_signed (x : BitVec 16) : Int :=
def bitvector_signed (x : (BitVec 16)) : Int :=
(BitVec.toInt x)

def initialize_registers : Unit :=
Expand Down
6 changes: 4 additions & 2 deletions test/lean/enum.expected.lean
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import Out.Sail.Sail

open Sail

inductive E where | A | B | C
deriving Inhabited

def undefined_E : E :=
(sorry : E)
def undefined_E : SailM E :=
return (sorry : E)

def initialize_registers : Unit :=
()
Expand Down
2 changes: 2 additions & 0 deletions test/lean/extern.expected.lean
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import Out.Sail.Sail

open Sail

def extern_add : Int :=
(Int.add 5 4)

Expand Down
10 changes: 6 additions & 4 deletions test/lean/extern_bitvec.expected.lean
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import Out.Sail.Sail

def extern_const : BitVec 64 :=
(0xFFFF000012340000 : BitVec 64)
open Sail

def extern_add : BitVec 16 :=
(HAdd.hAdd (0xFFFF : BitVec 16) (0x1234 : BitVec 16))
def extern_const : (BitVec 64) :=
(0xFFFF000012340000 : (BitVec 64))

def extern_add : (BitVec 16) :=
(HAdd.hAdd (0xFFFF : (BitVec 16)) (0x1234 : (BitVec 16)))

def initialize_registers : Unit :=
()
Expand Down
8 changes: 5 additions & 3 deletions test/lean/let.expected.lean
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import Out.Sail.Sail

def foo : BitVec 16 :=
let z := (HOr.hOr (0xFFFF : BitVec 16) (0xABCD : BitVec 16))
(HAnd.hAnd (0x0000 : BitVec 16) z)
open Sail

def foo : (BitVec 16) :=
let z := (HOr.hOr (0xFFFF : (BitVec 16)) (0xABCD : (BitVec 16)))
(HAnd.hAnd (0x0000 : (BitVec 16)) z)

def initialize_registers : Unit :=
()
Expand Down
7 changes: 4 additions & 3 deletions test/lean/struct.expected.lean
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import Out.Sail.Sail

open Sail

structure My_struct where
field1 : Int
field2 : Int


def undefined_My_struct (lit : Unit) : My_struct :=
sorry
def undefined_My_struct (lit : Unit) : SailM My_struct :=
return sorry

def initialize_registers : Unit :=
()
Expand Down
2 changes: 2 additions & 0 deletions test/lean/trivial.expected.lean
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import Out.Sail.Sail

open Sail

def foo (y : Unit) : Unit :=
y

Expand Down
6 changes: 4 additions & 2 deletions test/lean/tuples.expected.lean
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import Out.Sail.Sail

def tuple1 : (Int × Int × (BitVec 2 × Unit)) :=
(3, 5, ((0b10 : BitVec 2), ()))
open Sail

def tuple1 : (Int × Int × ((BitVec 2) × Unit)) :=
(3, 5, ((0b10 : (BitVec 2)), ()))

def initialize_registers : Unit :=
()
Expand Down
8 changes: 5 additions & 3 deletions test/lean/typedef.expected.lean
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
import Out.Sail.Sail

open Sail

def xlen : Int := 64

def xlen_bytes : Int := 8

def xlenbits := BitVec 64
def xlenbits := (BitVec 64)

/-- Type quantifiers: k_n : Int, m : Int, m ≥ k_n -/
def EXTZ {m : _} (v : BitVec k_n) : BitVec m :=
def EXTZ {m : _} (v : (BitVec k_n)) : (BitVec m) :=
(Sail.BitVec.zeroExtend v m)

/-- Type quantifiers: k_n : Int, m : Int, m ≥ k_n -/
def EXTS {m : _} (v : BitVec k_n) : BitVec m :=
def EXTS {m : _} (v : (BitVec k_n)) : (BitVec m) :=
(Sail.BitVec.signExtend v m)

def initialize_registers : Unit :=
Expand Down
8 changes: 5 additions & 3 deletions test/lean/typquant.expected.lean
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import Out.Sail.Sail

open Sail

/-- Type quantifiers: n : Int -/
def foo (n : Int) : BitVec 4 :=
(0xF : BitVec 4)
def foo (n : Int) : (BitVec 4) :=
(0xF : (BitVec 4))

/-- Type quantifiers: k_n : Int -/
def bar (x : BitVec k_n) : BitVec k_n :=
def bar (x : (BitVec k_n)) : (BitVec k_n) :=
x

def initialize_registers : Unit :=
Expand Down

0 comments on commit a1b339a

Please sign in to comment.