Skip to content

Commit

Permalink
Lean: Adding features to support structs and bitfields (#817)
Browse files Browse the repository at this point in the history
  • Loading branch information
lfrenot authored Jan 17, 2025
1 parent cf168b5 commit f87c292
Show file tree
Hide file tree
Showing 9 changed files with 164 additions and 8 deletions.
7 changes: 4 additions & 3 deletions lib/arith.sail
Original file line number Diff line number Diff line change
Expand Up @@ -51,16 +51,16 @@ $include <flow.sail>

// ***** Addition *****

val add_atom = pure {ocaml: "add_int", interpreter: "add_int", lem: "integerAdd", coq: "Z.add", _: "add_int"} : forall 'n 'm.
val add_atom = pure {ocaml: "add_int", interpreter: "add_int", lem: "integerAdd", coq: "Z.add", lean: "HAdd.hAdd", _: "add_int"} : forall 'n 'm.
(int('n), int('m)) -> int('n + 'm)

val add_int = pure {ocaml: "add_int", interpreter: "add_int", lem: "integerAdd", coq: "Z.add",lean: "Int.add", _: "add_int"} : (int, int) -> int
val add_int = pure {ocaml: "add_int", interpreter: "add_int", lem: "integerAdd", coq: "Z.add", lean: "Int.add", _: "add_int"} : (int, int) -> int

overload operator + = {add_atom, add_int}

// ***** Subtraction *****

val sub_atom = pure {ocaml: "sub_int", interpreter: "sub_int", lem: "integerMinus", coq: "Z.sub", _: "sub_int"} : forall 'n 'm.
val sub_atom = pure {ocaml: "sub_int", interpreter: "sub_int", lem: "integerMinus", coq: "Z.sub", lean: "HSub.hSub", _: "sub_int"} : forall 'n 'm.
(int('n), int('m)) -> int('n - 'm)

val sub_int = pure {ocaml: "sub_int", interpreter: "sub_int", lem: "integerMinus", coq: "Z.sub", lean: "Int.sub", _: "sub_int"} : (int, int) -> int
Expand All @@ -71,6 +71,7 @@ val sub_nat = pure {
ocaml: "(fun (x,y) -> let n = sub_int (x,y) in if Big_int.less_equal n Big_int.zero then Big_int.zero else n)",
lem: "integerMinus",
coq: "Z.sub",
lean: "HSub.hSub",
_: "sub_nat"
} : (nat, nat) -> nat

Expand Down
2 changes: 2 additions & 0 deletions lib/vector.sail
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,7 @@ val subrange_bits = pure {
interpreter: "subrange",
lem: "subrange_vec_dec",
coq: "subrange_vec_dec",
lean: "Sail.BitVec.extractLsb",
_: "vector_subrange"
} : forall ('n : Int) ('m : Int) ('o : Int), 0 <= 'o <= 'm < 'n.
(bits('n), int('m), int('o)) -> bits('m - 'o + 1)
Expand All @@ -321,6 +322,7 @@ val update_subrange_bits = pure {
interpreter: "update_subrange",
lem: "update_subrange_vec_dec",
coq: "update_subrange_vec_dec",
lean: "Sail.BitVec.updateSubrange",
_: "vector_update_subrange"
} : forall 'n 'm 'o, 0 <= 'o <= 'm < 'n. (bits('n), int('m), int('o), bits('m - ('o - 1))) -> bits('n)
$else
Expand Down
21 changes: 21 additions & 0 deletions src/sail_lean_backend/Sail/Sail.lean
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,26 @@ def truncate {w : Nat} (x : BitVec w) (w' : Nat) : BitVec w' :=
def truncateLsb {w : Nat} (x : BitVec w) (w' : Nat) : BitVec w' :=
x.extractLsb' (w - w') w'

def extractLsb {w : Nat} (x : BitVec w) (hi lo : Nat) : BitVec (hi - lo + 1) :=
x.extractLsb hi lo

def updateSubrange' {w : Nat} (x : BitVec w) (start len : Nat) (y : BitVec len) : BitVec w :=
let mask := ~~~(((BitVec.allOnes len).zeroExtend w) <<< start)
let y' := mask ||| ((y.zeroExtend w) <<< start)
x &&& y'

def updateSubrange {w : Nat} (x : BitVec w) (hi lo : Nat) (y : BitVec (hi - lo + 1)) : BitVec w :=
updateSubrange' x lo _ y

end BitVec
end Sail

structure RegisterRef (regstate regval a : Type) where
name : String
read_from : regstate -> a
write_to : a -> regstate -> regstate
of_regval : regval -> Option a
regval_of : a -> regval

def undefined_bitvector (w : Nat) : BitVec w :=
0
18 changes: 18 additions & 0 deletions src/sail_lean_backend/pretty_print_lean.ml
Original file line number Diff line number Diff line change
Expand Up @@ -137,12 +137,21 @@ let rec doc_typ ctxt (Typ_aux (t, _) as typ) =
parens (string "BitVec " ^^ doc_nexp ctxt m)
| Typ_app (Id_aux (Id "atom", _), [A_aux (A_nexp (Nexp_aux (Nexp_var ki, _)), _)]) ->
string "Int" (* TODO This probably has to be generalized *)
| Typ_app (Id_aux (Id "register", _), t_app) ->
string "RegisterRef Unit Unit "
(* TODO: Replace units with real types. *) ^^ separate_map comma (doc_typ_app ctxt) t_app
| Typ_app (Id_aux (Id "implicit", _), [A_aux (A_nexp (Nexp_aux (Nexp_var ki, _)), _)]) ->
underscore (* TODO check if the type of implicit arguments can really be always inferred *)
| Typ_tuple ts -> parens (separate_map (space ^^ string "×" ^^ space) (doc_typ ctxt) ts)
| Typ_id (Id_aux (Id id, _)) -> string id
| _ -> failwith ("Type " ^ string_of_typ_con typ ^ " " ^ string_of_typ typ ^ " not translatable yet.")

and doc_typ_app ctxt (A_aux (t, _) as typ) =
match t with
| A_typ t' -> doc_typ ctxt t'
| A_bool nc -> failwith ("Constraint " ^ string_of_n_constraint nc ^ "not translatable yet.")
| A_nexp m -> doc_nexp ctxt m

let rec captured_typ_var ((i, Typ_aux (t, _)) as typ) =
match t with
| Typ_app (Id_aux (Id "atom", _), [A_aux (A_nexp (Nexp_aux (Nexp_var ki, _)), _)])
Expand Down Expand Up @@ -274,8 +283,17 @@ let rec doc_exp ctxt (E_aux (e, (l, annot)) as full_exp) =
| _ -> failwith "Let pattern not translatable yet."
in
nest 2 (flow (break 1) [string "let"; string id; coloneq; doc_exp ctxt lexp]) ^^ hardline ^^ doc_exp ctxt e
| E_struct fexps ->
let args = List.map (doc_fexp ctxt) fexps in
braces (space ^^ nest 2 (separate hardline args) ^^ space)
| E_field (exp, id) -> doc_exp ctxt exp ^^ dot ^^ doc_id_ctor id
| E_struct_update (exp, fexps) ->
let args = List.map (doc_fexp ctxt) fexps in
braces (space ^^ doc_exp ctxt exp ^^ string " with " ^^ separate (comma ^^ space) args ^^ space)
| _ -> failwith ("Expression " ^ string_of_exp_con full_exp ^ " " ^ string_of_exp full_exp ^ " not translatable yet.")

and doc_fexp ctxt (FE_aux (FE_fexp (field, exp), _)) = doc_id_ctor field ^^ string " := " ^^ doc_exp ctxt exp

let doc_binder ctxt i t =
let paranthesizer =
match t with
Expand Down
69 changes: 69 additions & 0 deletions test/lean/bitfield.expected.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import Out.Sail.Sail

open Sail

def cr_type := (BitVec 8)

def undefined_cr_type (lit : Unit) : SailM (BitVec 8) :=
return ((undefined_bitvector 8) : (BitVec 8))

def Mk_cr_type (v : (BitVec 8)) : (BitVec 8) :=
v

def _get_cr_type_bits (v : (BitVec 8)) : (BitVec 8) :=
(Sail.BitVec.extractLsb v (HSub.hSub 8 1) 0)

def _update_cr_type_bits (v : (BitVec 8)) (x : (BitVec 8)) : (BitVec 8) :=
(Sail.BitVec.updateSubrange v (HSub.hSub 8 1) 0 x)

def _set_cr_type_bits (r_ref : RegisterRef Unit Unit (BitVec 8)) (v : (BitVec 8)) : SailM Unit :=
return sorry

def _get_cr_type_CR0 (v : (BitVec 8)) : (BitVec 4) :=
(Sail.BitVec.extractLsb v 7 4)

def _update_cr_type_CR0 (v : (BitVec 8)) (x : (BitVec 4)) : (BitVec 8) :=
(Sail.BitVec.updateSubrange v 7 4 x)

def _set_cr_type_CR0 (r_ref : RegisterRef Unit Unit (BitVec 8)) (v : (BitVec 4)) : SailM Unit :=
return sorry

def _get_cr_type_CR1 (v : (BitVec 8)) : (BitVec 2) :=
(Sail.BitVec.extractLsb v 3 2)

def _update_cr_type_CR1 (v : (BitVec 8)) (x : (BitVec 2)) : (BitVec 8) :=
(Sail.BitVec.updateSubrange v 3 2 x)

def _set_cr_type_CR1 (r_ref : RegisterRef Unit Unit (BitVec 8)) (v : (BitVec 2)) : SailM Unit :=
return sorry

def _get_cr_type_CR3 (v : (BitVec 8)) : (BitVec 2) :=
(Sail.BitVec.extractLsb v 1 0)

def _update_cr_type_CR3 (v : (BitVec 8)) (x : (BitVec 2)) : (BitVec 8) :=
(Sail.BitVec.updateSubrange v 1 0 x)

def _set_cr_type_CR3 (r_ref : RegisterRef Unit Unit (BitVec 8)) (v : (BitVec 2)) : SailM Unit :=
return sorry

def _get_cr_type_GT (v : (BitVec 8)) : (BitVec 1) :=
(Sail.BitVec.extractLsb v 6 6)

def _update_cr_type_GT (v : (BitVec 8)) (x : (BitVec 1)) : (BitVec 8) :=
(Sail.BitVec.updateSubrange v 6 6 x)

def _set_cr_type_GT (r_ref : RegisterRef Unit Unit (BitVec 8)) (v : (BitVec 1)) : SailM Unit :=
return sorry

def _get_cr_type_LT (v : (BitVec 8)) : (BitVec 1) :=
(Sail.BitVec.extractLsb v 7 7)

def _update_cr_type_LT (v : (BitVec 8)) (x : (BitVec 1)) : (BitVec 8) :=
(Sail.BitVec.updateSubrange v 7 7 x)

def _set_cr_type_LT (r_ref : RegisterRef Unit Unit (BitVec 8)) (v : (BitVec 1)) : SailM Unit :=
return sorry

def initialize_registers : Unit :=
()

11 changes: 11 additions & 0 deletions test/lean/bitfield.sail
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
default Order dec

$include <prelude.sail>

bitfield cr_type : bits(8) = {
CR0 : 7 .. 4,
LT : 7,
GT : 6,
CR1 : 3 .. 2,
CR3 : 1 .. 0
}
17 changes: 16 additions & 1 deletion test/lean/struct.expected.lean
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,26 @@ open Sail

structure My_struct where
field1 : Int
field2 : Int
field2 : (BitVec 1)

def undefined_My_struct (lit : Unit) : SailM My_struct :=
return sorry

def struct_field2 (s : My_struct) : (BitVec 1) :=
s.field2

def struct_update_field2 (s : My_struct) (b : (BitVec 1)) : My_struct :=
{ s with field2 := b }

/-- Type quantifiers: i : Int -/
def struct_update_both_fields (s : My_struct) (i : Int) (b : (BitVec 1)) : My_struct :=
{ s with field1 := i, field2 := b }

/-- Type quantifiers: i : Int -/
def mk_struct (i : Int) (b : (BitVec 1)) : My_struct :=
{ field1 := i
field2 := b }

def initialize_registers : Unit :=
()

3 changes: 0 additions & 3 deletions test/lean/struct.lean.expected

This file was deleted.

24 changes: 23 additions & 1 deletion test/lean/struct.sail
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,28 @@ $include <prelude.sail>

struct My_struct = {
field1 : int,
field2 : int,
field2 : bit,
}

val struct_field2 : My_struct -> bit
function struct_field2(s) = {
s.field2
}

val struct_update_field2 : (My_struct, bit) -> My_struct
function struct_update_field2(s, b) = {
{ s with field2 = b }
}

val struct_update_both_fields : (My_struct, int, bit) -> My_struct
function struct_update_both_fields(s, i, b) = {
{ s with field1 = i, field2 = b }
}

val mk_struct : (int, bit) -> My_struct
function mk_struct(i, b) = {
struct {
field1 = i,
field2 = b
}
}

0 comments on commit f87c292

Please sign in to comment.