Skip to content

Commit

Permalink
global security and errors for ppx
Browse files Browse the repository at this point in the history
  • Loading branch information
maxtori committed Oct 7, 2024
1 parent 07ce9a3 commit 5bb6911
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 49 deletions.
137 changes: 89 additions & 48 deletions src/ppx/ppx_common.ml
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,6 @@
open Ppxlib
open Ast_builder.Default

let str_of_expr e = Pprintast.string_of_expression e
let str_of_pat p =
Pprintast.pattern Format.str_formatter p;
Format.flush_str_formatter ()
let str_of_structure e = Pprintast.string_of_structure e

let llid ~loc s = {txt=Longident.parse s; loc}
let esome e =
let loc = e.pexp_loc in
pexp_construct ~loc (llid ~loc "Some") (Some e)

(** service *)

type options = {
Expand All @@ -45,6 +34,46 @@ type options = {
service : expression option;
}

let loc = !Ast_helper.default_loc
let global_errors = ref [%expr None]
let global_error_type = ref [%type: exn]
let global_security = ref [%expr None]
let global_security_type = ref [%type: EzAPI.no_security]

let remove_poly c = match c.ptyp_desc with
| Ptyp_poly ([], c) -> c
| _ -> c

let remove_constraint e = match e.pexp_desc with
| Pexp_constraint (e, _) -> e
| _ -> e

let extract_list_type = function
| None -> [%type: _]
| Some t ->
let t = remove_poly t in
match t.ptyp_desc with
| Ptyp_constr ({txt=(Lident "list" | Ldot (Lident "List", "t")); _}, [ c ]) -> c
| _ -> t

let set_global_errors ?typ e =
let loc = e.pexp_loc in
global_errors := [%expr Some [%e remove_constraint e]];
global_error_type := extract_list_type typ

let set_global_security ?typ e =
let loc = e.pexp_loc in
global_security := [%expr Some [%e remove_constraint e]];
global_security_type := extract_list_type typ

let set_globals l =
List.iter (fun ({txt; _}, e) ->
let name = Longident.name txt in
match name with
| "errors" -> set_global_errors e
| "security" -> set_global_security e
| _ -> ()) l

let raw e =
let loc = e.pexp_loc in
[%expr EzAPI.Raw (List.filter_map EzAPI.Mime.parse [%e e])]
Expand All @@ -53,11 +82,11 @@ let options loc = {
path = [%expr EzAPI.Path.root];
input = [%expr EzAPI.Empty];
output = [%expr EzAPI.Empty];
errors = [%expr None]; params = [%expr None];
errors = !global_errors; params = [%expr None];
section = [%expr None]; name=[%expr None]; descr = [%expr None];
security = [%expr None]; register=[%expr true]; input_example = [%expr None];
hide = [%expr None]; output_example = [%expr None]; error_type = [%type: exn];
security_type = [%type: EzAPI.no_security];
security = !global_security; register=[%expr true]; input_example = [%expr None];
hide = [%expr None]; output_example = [%expr None]; error_type = !global_error_type;
security_type = !global_security_type;
debug = false; directory = None; service = None
}

Expand All @@ -75,7 +104,7 @@ let parse_arg ~loc s = match String.index_opt s ':' with
Location.raise_errorf ~loc "argument type not understood: %S" typ

let parse_path ~loc s =
let path ~loc s = pexp_ident ~loc (llid ~loc ("EzAPI.Path." ^ s)) in
let path ~loc s = pexp_ident ~loc {txt=Longident.parse ("EzAPI.Path." ^ s); loc} in
let l = String.split_on_char '/' s in
let l = List.filter (fun s -> s <> "") l in
List.fold_left (fun acc s ->
Expand Down Expand Up @@ -107,28 +136,28 @@ let get_options ~loc ?(options=options loc) ?name p =
| "raw_input" -> name, { acc with input = raw e }
| "output" -> name, { acc with output = [%expr EzAPI.Json [%e e]] }
| "raw_output" -> name, { acc with output = raw e }
| "params" -> name, { acc with params = esome e }
| "errors" -> name, { acc with errors = esome e; error_type = ptyp_any ~loc }
| "section" -> name, { acc with section = esome e }
| "params" -> name, { acc with params = [%expr Some [%e e]] }
| "errors" -> name, { acc with errors = [%expr Some [%e e]]; error_type = [%type: _] }
| "section" -> name, { acc with section = [%expr Some [%e e]] }
| "name" ->
begin match e.pexp_desc with
| Pexp_constant cst ->
begin match name, string_literal cst with
| None, Some s -> Some s, { acc with name = esome e }
| Some n, _ -> Some n, { acc with name = esome e }
| None, Some s -> Some s, { acc with name = [%expr Some [%e e]] }
| Some n, _ -> Some n, { acc with name = [%expr Some [%e e]] }
| _ -> Format.eprintf "name should be a string literal"; name, acc
end
| _ ->
match name with
| Some n -> Some n, { acc with name = [%expr Some [%e estring ~loc n]] }
| _ -> name, acc
end
| "descr" -> name, { acc with descr = esome e }
| "security" -> name, { acc with security = esome e; security_type = ptyp_any ~loc }
| "descr" -> name, { acc with descr = [%expr Some [%e e]] }
| "security" -> name, { acc with security = [%expr Some [%e e]]; security_type = [%type: _] }
| "register" -> name, { acc with register = e }
| "hide" -> name, { acc with hide = e }
| "input_example" -> name, { acc with input_example = esome e }
| "output_example" -> name, { acc with output_example = esome e }
| "input_example" -> name, { acc with input_example = [%expr Some [%e e]] }
| "output_example" -> name, { acc with output_example = [%expr Some [%e e]] }
| "debug" -> name, { acc with debug = true }
| "dir" -> begin match e.pexp_desc with
| Pexp_constant cst ->
Expand All @@ -139,10 +168,10 @@ let get_options ~loc ?(options=options loc) ?name p =
| _ -> Format.eprintf "directory should be a literal"; name, acc
end
| "service" ->
name, { acc with service = Some e; error_type = ptyp_any ~loc; security_type = ptyp_any ~loc }
name, { acc with service = Some e; error_type = [%type: _]; security_type = [%type: _] }
| _ -> name, acc) (name, options) l
| PStr [ {pstr_desc=Pstr_eval ({pexp_desc=Pexp_ident _; _} as e, _); _} ] ->
name, { options with service = Some e; error_type = ptyp_any ~loc; security_type = ptyp_any ~loc }
name, { options with service = Some e; error_type = [%type: _]; security_type = [%type: _] }
| PStr [ {pstr_desc=Pstr_eval ({pexp_desc=Pexp_constant Pconst_string (s, loc, _); _}, _); _} ] ->
name, { options with path = parse_path ~loc s }
| PStr s ->
Expand Down Expand Up @@ -173,11 +202,9 @@ let service_value ?name ?options ~meth ~loc p =
Optional "output_example", options.output_example;
Nolabel, options.path ] in
let pat = ppat_constraint ~loc (pvar ~loc name) @@
ptyp_constr ~loc (llid ~loc "EzAPI.service") [
ptyp_any ~loc; ptyp_any ~loc; ptyp_any ~loc; options.error_type;
options.security_type ] in
[%type: (_, _, _, [%t options.error_type], [%t options.security_type]) EzAPI.service] in
let str = pstr_value ~loc Nonrecursive [ value_binding ~loc ~pat ~expr ] in
if options.debug then Format.printf "%s@." @@ str_of_structure [ str ];
if options.debug then Format.printf "%a@." Pprintast.structure_item str;
str, name, options

(** register service/handler *)
Expand Down Expand Up @@ -205,7 +232,7 @@ let register name a =
~expr:(eapply ~loc (evar ~loc "EzAPIServerUtils.register") [
e; evar ~loc name; evar ~loc ppx_dir_name ]) in
let str = ppx_dir @ [ pstr_value ~loc Nonrecursive [ register ] ] in
if options.debug then Format.printf "%s@." @@ str_of_structure str;
if options.debug then Format.printf "%a@." Pprintast.structure str;
str

let register_ws ~onclose react_name bg_name a =
Expand All @@ -215,7 +242,7 @@ let register_ws ~onclose react_name bg_name a =
let ppx_dir_name = match options.directory with None -> "ppx_dir" | Some s -> s in
let onclose = match onclose with
| [] -> [%expr None]
| [ {pvb_pat = {ppat_desc = Ppat_var {txt; loc}; _}; _} ] -> esome (evar ~loc txt)
| [ {pvb_pat = {ppat_desc = Ppat_var {txt; loc}; _}; _} ] -> [%expr Some [%e evar ~loc txt]]
| _ -> Location.raise_errorf ~loc "too many value bindings" in
match options.service with
| None -> Location.raise_errorf ~loc "service not defined"
Expand All @@ -229,7 +256,7 @@ let register_ws ~onclose react_name bg_name a =
Labelled "bg", evar ~loc bg_name;
Nolabel, evar ~loc ppx_dir_name ]) in
let str = ppx_dir @ [ pstr_value ~loc Nonrecursive [ register ] ] in
if options.debug then Format.printf "%s@." @@ str_of_structure str;
if options.debug then Format.printf "%a@." Pprintast.structure str;
str

let process name a =
Expand All @@ -243,7 +270,7 @@ let process name a =
value_binding ~loc ~pat:(pvar ~loc ppx_dir_name)
~expr:(eapply ~loc (evar ~loc "EzAPIServerUtils.register") [
evar ~loc service_name; evar ~loc name; evar ~loc ppx_dir_name ]) ] in
if options.debug then Format.printf "%s@." @@ str_of_structure [ register ];
if options.debug then Format.printf "%a@." Pprintast.structure_item register;
ppx_dir @ [ service; register ]

let process_ws ~onclose react_name bg_name a =
Expand All @@ -255,7 +282,7 @@ let process_ws ~onclose react_name bg_name a =
let ppx_dir_name = match options.directory with None -> "ppx_dir" | Some s -> s in
let onclose = match onclose with
| [] -> [%expr None]
| [ {pvb_pat = {ppat_desc = Ppat_var {txt; loc}; _}; _} ] -> esome (evar ~loc txt)
| [ {pvb_pat = {ppat_desc = Ppat_var {txt; loc}; _}; _} ] -> [%expr Some [%e evar ~loc txt]]
| _ -> Location.raise_errorf ~loc "too many value bindings" in
let register =
pstr_value ~loc Nonrecursive [
Expand All @@ -266,7 +293,7 @@ let process_ws ~onclose react_name bg_name a =
Labelled "react", evar ~loc react_name;
Labelled "bg", evar ~loc bg_name;
Nolabel, evar ~loc ppx_dir_name ]) ] in
if options.debug then Format.printf "%s@." @@ str_of_structure [ register ];
if options.debug then Format.printf "%a@." Pprintast.structure_item register;
ppx_dir @ [ service; register ]

let handler_args e =
Expand Down Expand Up @@ -308,11 +335,11 @@ let server_options e =
List.fold_left (fun acc (s, e) -> match s with
| "port" -> { acc with port = e }
| "dir" -> { acc with dir = e }
| "catch" -> { acc with catch = esome e }
| "headers" -> { acc with allow_headers = esome e }
| "methods" -> { acc with allow_methods = esome e }
| "origin" -> { acc with allow_origin = esome e }
| "credentials" -> { acc with allow_credentials = esome e }
| "catch" -> { acc with catch = [%expr Some [%e e]] }
| "headers" -> { acc with allow_headers = [%expr Some [%e e]] }
| "methods" -> { acc with allow_methods = [%expr Some [%e e]] }
| "origin" -> { acc with allow_origin = [%expr Some [%e e]] }
| "credentials" -> { acc with allow_credentials = [%expr Some [%e e]] }
| _ -> acc) (dft (eint ~loc 8080)) l
| _ -> Location.raise_errorf ~loc "server options not understood"

Expand All @@ -322,12 +349,10 @@ let server_aux e =
[%expr
EzAPIServer.server ?catch:[%e options.catch] ?allow_headers:[%e options.allow_headers]
?allow_methods:[%e options.allow_methods] ?allow_origin:[%e options.allow_origin]
?allow_credentials:[%e options.allow_credentials]
[%e elist ~loc [
pexp_tuple ~loc [
options.port;
pexp_construct ~loc (llid ~loc "EzAPIServerUtils.API") (Some options.dir)
] ] ] ]
?allow_credentials:[%e options.allow_credentials] [
([%e options.port], EzAPIServerUtils.API [%e options.dir])
]
]

let server ~loc p =
match p with
Expand Down Expand Up @@ -447,6 +472,22 @@ let transform ?kind () =
let options = { (options loc) with register = [%expr false] } in
let service, _, _ = service_value ~options ~meth:txt ~loc p in
service :: acc
(* globals *)
| Pstr_extension (({txt="service"; _}, PStr [ {pstr_desc=Pstr_eval ({pexp_desc=Pexp_record (l, _); _}, _); _} ]), _) ->
set_globals l;
acc
| Pstr_extension (({txt="service"; _}, PStr [ {pstr_desc=Pstr_value (_, l); _} ]), _) ->
List.iter (fun vb ->
match vb.pvb_pat.ppat_desc with
| Ppat_var {txt="errors"; _} -> set_global_errors vb.pvb_expr
| Ppat_var {txt="security"; _} -> set_global_security vb.pvb_expr
| Ppat_constraint ({ppat_desc = Ppat_var {txt="errors"; _}; _}, typ) ->
set_global_errors ~typ vb.pvb_expr
| Ppat_constraint ({ppat_desc = Ppat_var {txt="security"; _}; _}, typ) ->
set_global_security ~typ vb.pvb_expr
| _ -> ()) l;
acc
(* service deriver *)
| Pstr_type (_rec_flag, [ t ]) ->
let loc = t.ptype_loc in
begin match List.find_opt (fun a -> List.mem a.attr_name.txt methods) t.ptype_attributes with
Expand Down
13 changes: 12 additions & 1 deletion test/ppx/test_ppx_lib.ml
Original file line number Diff line number Diff line change
@@ -1,9 +1,20 @@

type error = {
name: string;
msg: string;
} [@@deriving encoding]

let%service errors = [
EzAPI.Err.make ~code:400 ~name:"Error" ~encoding:error_enc ~select:Option.some ~deselect:Fun.id
]
and security : EzAPI.Security.bearer list = [ `Bearer {EzAPI.Security.bearer_name="Bearer"; format=None} ]

type nonrec test_derive_input = {
foo: string;
bar: int;
}
and test_derive_output = int
[@@post {path="/test/getter"; debug}]
[@@post {path="/test/getter"}]

let%post echo_input = {
path="/echo_input"; raw_input=["text/plain"];
Expand Down

0 comments on commit 5bb6911

Please sign in to comment.