Skip to content

Commit

Permalink
More complete tests on module calls and a bunch of fixes (#525)
Browse files Browse the repository at this point in the history
  • Loading branch information
AltGr authored Oct 17, 2023
2 parents 7141734 + bd90555 commit 73df41e
Show file tree
Hide file tree
Showing 11 changed files with 415 additions and 188 deletions.
4 changes: 4 additions & 0 deletions .git-blame-ignore-revs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
# Use `git config --global blame.ignoreRevsFile .git-blame-ignore-revs` to use it
# Add new reformatting commits at the top

2708fa53b23bde545e7378a660cdb99e8671f1de
a79acd1fa8b701a5688c7fa985c7064cd6d81acf
4bce4e6322ede5cddb7384511f2fc0d05416f97f

4910158aeadad66fd9e542b736bf81fab66cd26d
8e33355eadabe2a95478c419884fab899244766b
72882f82dfc75888470a9415a5b51a7ab38e140e
Expand Down
216 changes: 99 additions & 117 deletions compiler/dcalc/from_scopelang.ml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ type scope_input_var_ctx = {
scope_input_name : StructField.t;
scope_input_io : Runtime.io_input Mark.pos;
scope_input_typ : naked_typ;
scope_input_thunked : bool;
(* For reentrant variables: if true, the type t of the field has been
changed to (unit -> t). Otherwise, the type was already a function and
wasn't changed so no additional wrapping will be needed *)
}

type 'm scope_ref =
Expand Down Expand Up @@ -193,19 +197,30 @@ let collapse_similar_outcomes (type m) (excepts : m Scopelang.Ast.expr list) :
in
excepts

let thunk_scope_arg ~is_func io_in e =
let input_var_needs_thunking typ io_in =
(* For "context" (or reentrant) variables, we thunk them as [(fun () -> e)] so
that we can put them in default terms at the initialisation of the function
body, allowing an empty error to recover the default value. *)
let silent_var = Var.make "_" in
let pos = Mark.get io_in in
match Mark.remove io_in with
| Runtime.NoInput -> invalid_arg "thunk_scope_arg"
| Runtime.OnlyInput -> Expr.eerroronempty e (Mark.get e)
| Runtime.Reentrant ->
(* we don't need to thunk expressions that are already functions *)
if is_func then e
else Expr.make_abs [| silent_var |] e [TLit TUnit, pos] pos
match Mark.remove io_in.Desugared.Ast.io_input, typ with
| Runtime.Reentrant, TArrow _ ->
false (* we don't need to thunk expressions that are already functions *)
| Runtime.Reentrant, _ -> true
| _ -> false

let input_var_typ typ io_in =
let pos = Mark.get io_in.Desugared.Ast.io_input in
if input_var_needs_thunking typ io_in then
TArrow ([TLit TUnit, pos], (typ, pos)), pos
else typ, pos

let thunk_scope_arg var_ctx e =
match var_ctx.scope_input_io, var_ctx.scope_input_thunked with
| (Runtime.NoInput, _), _ -> invalid_arg "thunk_scope_arg"
| (Runtime.OnlyInput, _), false -> Expr.eerroronempty e (Mark.get e)
| (Runtime.Reentrant, _), false -> e
| (Runtime.Reentrant, pos), true ->
Expr.make_abs [| Var.make "_" |] e [TLit TUnit, pos] pos
| _ -> assert false

let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) :
'm Ast.expr boxed =
Expand Down Expand Up @@ -246,23 +261,27 @@ let rec translate_expr (ctx : 'm ctx) (e : 'm Scopelang.Ast.expr) :
let in_var_map =
ScopeVar.Map.merge
(fun var_name (str_field : scope_input_var_ctx option) expr ->
let expr =
match str_field, expr with
| Some { scope_input_io = Reentrant, _; _ }, None ->
Some (Expr.unbox (Expr.eemptyerror (mark_tany m pos)))
| _ -> expr
in
match str_field, expr with
| None, None -> None
| None, None -> assert false
| Some ({ scope_input_io = Reentrant, iopos; _ } as var_ctx), None ->
let ty0 =
match var_ctx.scope_input_typ with
| TArrow ([_], ty) -> ty
| _ -> assert false
(* reentrant field must be thunked with correct function type at
this point *)
in
Some
( var_ctx.scope_input_name,
Expr.make_abs
[| Var.make "_" |]
(Expr.eemptyerror (Expr.with_ty m ty0))
[TAny, iopos]
pos )
| Some var_ctx, Some e ->
Some
( var_ctx.scope_input_name,
thunk_scope_arg
~is_func:
(match var_ctx.scope_input_typ with
| TArrow _ -> true
| _ -> false)
var_ctx.scope_input_io (translate_expr ctx e) )
thunk_scope_arg var_ctx (translate_expr ctx e) )
| Some var_ctx, None ->
Message.raise_multispanned_error
[
Expand Down Expand Up @@ -641,10 +660,8 @@ let translate_rule
ctx.scope_vars;
} )
| Definition
( (SubScopeVar { alias = subs_index; var = subs_var; _ }, var_def_pos),
tau,
a_io,
e ) ->
((SubScopeVar { alias = subs_index; var = subs_var; _ }, _), tau, a_io, e)
->
let a_name =
Mark.map
(fun str ->
Expand All @@ -662,9 +679,14 @@ let translate_rule
})
[sigma_name, pos_sigma; a_name]
in
let is_func = match Mark.remove tau with TArrow _ -> true | _ -> false in
let thunked_or_nonempty_new_e =
thunk_scope_arg ~is_func a_io.Desugared.Ast.io_input new_e
match a_io.Desugared.Ast.io_input with
| Runtime.NoInput, _ -> assert false
| Runtime.OnlyInput, _ -> Expr.eerroronempty new_e (Mark.get new_e)
| Runtime.Reentrant, pos -> (
match Mark.remove tau with
| TArrow _ -> new_e
| _ -> Expr.thunk_term new_e (Expr.with_pos pos (Mark.get new_e)))
in
( (fun next ->
Bindlib.box_apply2
Expand All @@ -673,13 +695,7 @@ let translate_rule
{
scope_let_next = next;
scope_let_pos = Mark.get a_name;
scope_let_typ =
(match Mark.remove a_io.io_input with
| NoInput -> failwith "should not happen"
| OnlyInput -> tau
| Reentrant ->
if is_func then tau
else TArrow ([TLit TUnit, var_def_pos], tau), var_def_pos);
scope_let_typ = input_var_typ (Mark.remove tau) a_io;
scope_let_expr = thunked_or_nonempty_new_e;
scope_let_kind = SubScopeVarDefinition;
})
Expand Down Expand Up @@ -927,8 +943,7 @@ let translate_rules
let translate_scope_decl
(ctx : 'm ctx)
(scope_name : ScopeName.t)
(sigma : 'm Scopelang.Ast.scope_decl) :
'm Ast.expr scope_body Bindlib.box * struct_ctx =
(sigma : 'm Scopelang.Ast.scope_decl) =
let sigma_info = ScopeName.get_info sigma.scope_decl_name in
let scope_sig =
ScopeName.Map.find sigma.scope_decl_name ctx.scopes_parameters.scope_sigs
Expand Down Expand Up @@ -1007,17 +1022,6 @@ let translate_scope_decl
| _ -> true)
scope_variables
in
let input_var_typ (var_ctx : scope_var_ctx) =
match Mark.remove var_ctx.scope_var_io.io_input with
| OnlyInput -> var_ctx.scope_var_typ, pos_sigma
| Reentrant -> (
match var_ctx.scope_var_typ with
| TArrow _ -> var_ctx.scope_var_typ, pos_sigma
| _ ->
( TArrow ([TLit TUnit, pos_sigma], (var_ctx.scope_var_typ, pos_sigma)),
pos_sigma ))
| NoInput -> failwith "should not happen"
in
let input_destructurings next =
List.fold_right
(fun (var_ctx, v) next ->
Expand All @@ -1033,7 +1037,8 @@ let translate_scope_decl
scope_let_kind = DestructuringInputStruct;
scope_let_next = next;
scope_let_pos = pos_sigma;
scope_let_typ = input_var_typ var_ctx;
scope_let_typ =
input_var_typ var_ctx.scope_var_typ var_ctx.scope_var_io;
scope_let_expr =
( EStructAccess
{ name = scope_input_struct_name; e = r; field },
Expand All @@ -1044,31 +1049,15 @@ let translate_scope_decl
(Expr.make_var scope_input_var (mark_tany scope_mark pos_sigma))))
scope_input_variables next
in
let scope_body =
Bindlib.box_apply
(fun scope_body_expr ->
{
scope_body_expr;
scope_body_input_struct = scope_input_struct_name;
scope_body_output_struct = scope_return_struct_name;
})
(Bindlib.bind_var scope_input_var
(input_destructurings rules_with_return_expr))
in
let field_map =
List.fold_left
(fun acc (var_ctx, _) ->
let var = var_ctx.scope_var_name in
let field =
(ScopeVar.Map.find var scope_sig.scope_sig_in_fields).scope_input_name
in
StructField.Map.add field (input_var_typ var_ctx) acc)
StructField.Map.empty scope_input_variables
in
let new_struct_ctx =
StructName.Map.singleton scope_input_struct_name field_map
in
scope_body, new_struct_ctx
Bindlib.box_apply
(fun scope_body_expr ->
{
scope_body_expr;
scope_body_input_struct = scope_input_struct_name;
scope_body_output_struct = scope_return_struct_name;
})
(Bindlib.bind_var scope_input_var
(input_destructurings rules_with_return_expr))

let translate_program (prgm : 'm Scopelang.Ast.program) : 'm Ast.program =
let defs_dependencies = Scopelang.Dependency.build_program_dep_graph prgm in
Expand Down Expand Up @@ -1114,7 +1103,10 @@ let translate_program (prgm : 'm Scopelang.Ast.program) : 'm Ast.program =
{
scope_input_name = StructField.fresh (s, Mark.get info);
scope_input_io = vis.Desugared.Ast.io_input;
scope_input_typ = Mark.remove typ;
scope_input_typ =
Mark.remove (input_var_typ (Mark.remove typ) vis);
scope_input_thunked =
input_var_needs_thunking (Mark.remove typ) vis;
})
scope.Scopelang.Ast.scope_sig
in
Expand Down Expand Up @@ -1155,33 +1147,35 @@ let translate_program (prgm : 'm Scopelang.Ast.program) : 'm Ast.program =
ModuleName.Map.map process_modules prgm.Scopelang.Ast.program_modules;
}
in
let add_scope_in_structs scope_sigs structs =
ScopeName.Map.fold
(fun _ scope_sig_ctx acc ->
let fields =
ScopeVar.Map.fold
(fun _ sivc acc ->
let pos = Mark.get (StructField.get_info sivc.scope_input_name) in
StructField.Map.add sivc.scope_input_name
(sivc.scope_input_typ, pos)
acc)
scope_sig_ctx.scope_sig_in_fields StructField.Map.empty
in
StructName.Map.add scope_sig_ctx.scope_sig_input_struct fields acc)
scope_sigs.scope_sigs structs
in
let rec gather_module_in_structs acc sctx =
(* Expose all added in_structs from submodules at toplevel *)
ModuleName.Map.fold
(fun _ scope_sigs acc ->
let acc = gather_module_in_structs acc scope_sigs.scope_sigs_modules in
ScopeName.Map.fold
(fun _ scope_sig_ctx acc ->
let fields =
ScopeVar.Map.fold
(fun _ sivc acc ->
let pos =
Mark.get (StructField.get_info sivc.scope_input_name)
in
StructField.Map.add sivc.scope_input_name
(sivc.scope_input_typ, pos)
acc)
scope_sig_ctx.scope_sig_in_fields StructField.Map.empty
in
StructName.Map.add scope_sig_ctx.scope_sig_input_struct fields acc)
scope_sigs.scope_sigs acc)
add_scope_in_structs scope_sigs
(gather_module_in_structs acc scope_sigs.scope_sigs_modules))
sctx acc
in
let decl_ctx =
{
decl_ctx with
ctx_structs =
gather_module_in_structs decl_ctx.ctx_structs sctx.scope_sigs_modules;
add_scope_in_structs sctx
(gather_module_in_structs decl_ctx.ctx_structs sctx.scope_sigs_modules);
}
in
let top_ctx =
Expand All @@ -1205,21 +1199,20 @@ let translate_program (prgm : 'm Scopelang.Ast.program) : 'm Ast.program =
ending with the top-level scope. The decl_ctx is filled in left-to-right
order, then the chained scopes aggregated from the right. *)
let rec translate_defs ctx = function
| [] -> Bindlib.box Nil, ctx
| [] -> Bindlib.box Nil
| def :: next ->
let ctx, dvar, def =
let dvar, def =
match def with
| Scopelang.Dependency.Topdef gname ->
let expr, ty = TopdefName.Map.find gname prgm.program_topdefs in
let expr = translate_expr ctx expr in
( ctx,
fst (TopdefName.Map.find gname ctx.toplevel_vars),
( fst (TopdefName.Map.find gname ctx.toplevel_vars),
Bindlib.box_apply
(fun e -> Topdef (gname, ty, e))
(Expr.Box.lift expr) )
| Scopelang.Dependency.Scope scope_name ->
let scope = ScopeName.Map.find scope_name prgm.program_scopes in
let scope_body, scope_in_struct =
let scope_body =
translate_scope_decl ctx scope_name (Mark.remove scope)
in
let scope_var =
Expand All @@ -1230,33 +1223,22 @@ let translate_program (prgm : 'm Scopelang.Ast.program) : 'm Ast.program =
| Local_scope_ref v -> v
| External_scope_ref _ -> assert false
in
( {
ctx with
decl_ctx =
{
ctx.decl_ctx with
ctx_structs =
StructName.Map.union
(fun _ _ -> assert false)
ctx.decl_ctx.ctx_structs scope_in_struct;
};
},
scope_var,
( scope_var,
Bindlib.box_apply
(fun body -> ScopeDef (scope_name, body))
scope_body )
in
let scope_next, ctx = translate_defs ctx next in
let scope_next = translate_defs ctx next in
let next_bind = Bindlib.bind_var dvar scope_next in
( Bindlib.box_apply2
(fun item next_bind -> Cons (item, next_bind))
def next_bind,
ctx )
Bindlib.box_apply2
(fun item next_bind -> Cons (item, next_bind))
def next_bind
in
let items, ctx = translate_defs top_ctx defs_ordering in
let items = translate_defs top_ctx defs_ordering in
Expr.Box.assert_closed items;
{
code_items = Bindlib.unbox items;
decl_ctx = ctx.decl_ctx;
decl_ctx;
module_name = prgm.Scopelang.Ast.program_module_name;
lang = prgm.program_lang;
}
2 changes: 1 addition & 1 deletion compiler/scopelang/dependency.ml
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ let rec expr_used_defs e =
if TopdefName.path v <> [] then VMap.empty
else VMap.singleton (Topdef v) pos
| (EScopeCall { scope; _ }, m) as e ->
if ScopeName.path scope <> [] then VMap.empty
if ScopeName.path scope <> [] then recurse_subterms e
else VMap.add (Scope scope) (Expr.mark_pos m) (recurse_subterms e)
| EAbs { binder; _ }, _ ->
let _, body = Bindlib.unmbind binder in
Expand Down
Loading

0 comments on commit 73df41e

Please sign in to comment.