Skip to content

Commit

Permalink
ppx_deriving_err_case for type extension
Browse files Browse the repository at this point in the history
  • Loading branch information
maxtori committed Oct 4, 2024
1 parent 0d5153f commit fb18ff0
Show file tree
Hide file tree
Showing 6 changed files with 168 additions and 123 deletions.
8 changes: 8 additions & 0 deletions src/common/err.ml
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,11 @@ let get ~code l =
Json_encoding.case enc select deselect
) l in
Some (Json_encoding.union cases)

let merge_selects l e =
let rec aux = function
| [] -> None
| f :: tl -> match f e with
| Some e -> Some e
| None -> aux tl in
aux l
5 changes: 3 additions & 2 deletions src/ppx/dune
Original file line number Diff line number Diff line change
Expand Up @@ -39,5 +39,6 @@
(optional)
(modules ppx_deriving_err_case)
(preprocess (pps ppxlib.metaquot))
(kind ppx_deriver)
(libraries ppx_deriving_encoding.lib))
(kind ppx_rewriter)
(ppx_runtime_libraries ez_api)
(libraries ppx_deriving_encoding.lib ez_api))
106 changes: 0 additions & 106 deletions src/ppx/ppx_common.ml
Original file line number Diff line number Diff line change
Expand Up @@ -353,112 +353,6 @@ let deprecate =
Format.eprintf "deprecated: [@@@@@@%s ...] -> [%%%%%s ...]@." s s
| Some () -> ()

(* let rec impl ?kind str = *)
(* let rec pmod_impl pmod = match pmod.pmod_desc with *)
(* | Pmod_structure str -> {pmod with pmod_desc = Pmod_structure (impl ?kind str)} *)
(* | Pmod_functor (f, m) -> {pmod with pmod_desc = Pmod_functor (f, pmod_impl m)} *)
(* | Pmod_apply (m1, m2) -> {pmod with pmod_desc = Pmod_apply (pmod_impl m1, pmod_impl m2)} *)
(* | Pmod_constraint (m, mt) -> {pmod with pmod_desc = Pmod_constraint (pmod_impl m, mt)} *)
(* | _ -> pmod in *)
(* List.rev @@ *)
(* List.fold_left (fun acc str -> *)
(* match str.pstr_desc with *)
(* | Pstr_value (rflag, [ v ]) when kind <> Some `client -> *)
(* begin match List.partition (fun a -> List.mem a.attr_name.txt methods) v.pvb_attributes with *)
(* (\* service for handler *\) *)
(* | [ a ], pvb_attributes -> *)
(* begin match v.pvb_pat.ppat_desc with *)
(* | Ppat_var {txt=name;_} -> *)
(* let pvb_expr = handler_args v.pvb_expr in *)
(* let str = {str with pstr_desc = Pstr_value (rflag, [ {v with pvb_expr; pvb_attributes }])} in *)
(* (List.rev @@ process name a) @ str :: acc *)
(* | _ -> *)
(* str :: acc *)
(* end *)
(* (\* link service *\) *)
(* | [], attributes -> *)
(* begin match List.partition (fun a -> a.attr_name.txt = "service") attributes with *)
(* | [ a ], pvb_attributes -> *)
(* begin match v.pvb_pat.ppat_desc with *)
(* | Ppat_var {txt=name;_} -> *)
(* let pvb_expr = handler_args v.pvb_expr in *)
(* let str = {str with pstr_desc = Pstr_value (rflag, [ {v with pvb_expr; pvb_attributes }])} in *)
(* (List.rev @@ register name a) @ str :: acc *)
(* | _ -> str :: acc *)
(* end *)
(* | _ -> str :: acc *)
(* end *)
(* | _ -> str :: acc *)
(* end *)
(* | Pstr_value (rflag, (v_react :: v_bg :: onclose)) when kind <> Some `client -> *)
(* let attributes = match onclose with *)
(* | [] -> v_bg.pvb_attributes *)
(* | v :: _ -> v.pvb_attributes in *)
(* begin match List.partition (fun a -> a.attr_name.txt = "ws" || a.attr_name.txt = "websocket") attributes with *)
(* (\* service for websocket handlers *\) *)
(* | [ a ], pvb_attributes -> *)
(* begin match v_react.pvb_pat.ppat_desc, v_bg.pvb_pat.ppat_desc with *)
(* | Ppat_var {txt=name_react;_}, Ppat_var {txt=name_bg;_} -> *)
(* let pvb_expr_react = handler_args v_react.pvb_expr in *)
(* let pvb_expr_bg = handler_args v_bg.pvb_expr in *)
(* let pvb_attributes, vs = match onclose with *)
(* | [] -> pvb_attributes, [] *)
(* | v :: t -> v_bg.pvb_attributes, {v with pvb_attributes} :: t in *)
(* let str = {str with pstr_desc = Pstr_value (rflag, ( *)
(* {v_react with pvb_expr = pvb_expr_react } :: *)
(* {v_bg with pvb_expr = pvb_expr_bg; pvb_attributes } :: *)
(* vs )) } in *)
(* (List.rev @@ process_ws ~onclose name_react name_bg a) @ str :: acc *)
(* | _ -> str :: acc *)
(* end *)
(* (\* link websocket service *\) *)
(* | [], attributes -> *)
(* begin match List.partition (fun a -> a.attr_name.txt = "service") attributes with *)
(* | [ a ], pvb_attributes -> *)
(* begin match v_react.pvb_pat.ppat_desc, v_bg.pvb_pat.ppat_desc with *)
(* | Ppat_var {txt=name_react;_}, Ppat_var {txt=name_bg;_} -> *)
(* let pvb_expr_react = handler_args v_react.pvb_expr in *)
(* let pvb_expr_bg = handler_args v_bg.pvb_expr in *)
(* let pvb_attributes, vs = match onclose with *)
(* | [] -> pvb_attributes, [] *)
(* | v :: t -> v_bg.pvb_attributes, {v with pvb_attributes} :: t in *)
(* let str = {str with pstr_desc = Pstr_value (rflag, ( *)
(* {v_react with pvb_expr = pvb_expr_react } :: *)
(* {v_bg with pvb_expr = pvb_expr_bg; pvb_attributes } :: *)
(* vs )) } in *)
(* (List.rev @@ register_ws ~onclose name_react name_bg a) @ str :: acc *)
(* | _ -> str :: acc *)
(* end *)
(* | _ -> str :: acc *)
(* end *)
(* | _ -> str :: acc *)
(* end *)
(* (\* server main *\) *)
(* | Pstr_attribute a when a.attr_name.txt = "server" && kind = Some `server -> *)
(* deprecate "server"; *)
(* let loc = a.attr_loc in *)
(* let expr = server ~loc a.attr_payload in *)
(* pstr_value ~loc Nonrecursive [ value_binding ~loc ~pat:(punit ~loc) ~expr ] :: acc *)
(* | Pstr_extension (({txt="server"; loc}, p), _) when kind = Some `server -> *)
(* let expr = server ~loc p in *)
(* pstr_value ~loc Nonrecursive [ value_binding ~loc ~pat:(punit ~loc) ~expr ] :: acc *)
(* (\* client service *\) *)
(* | Pstr_attribute a when List.mem a.attr_name.txt methods -> *)
(* deprecate a.attr_name.txt; *)
(* let service, _, _ = service_value ~client:true ~meth:a.attr_name.txt ~loc:a.attr_loc a.attr_payload in *)
(* service :: acc *)
(* | Pstr_extension (({txt; loc}, PStr [ { pstr_desc = Pstr_value (_, [ { pvb_expr; pvb_pat= {ppat_desc=Ppat_var {txt=name; _}; _}; _} ]); _} ]), _) when List.mem txt methods -> *)
(* let service, _, _ = service_value ~name ~client:true ~meth:txt ~loc @@ PStr [ pstr_eval ~loc pvb_expr [] ] in *)
(* service :: acc *)
(* | Pstr_extension (({txt; loc}, p), _) when List.mem txt methods -> *)
(* let service, _, _ = service_value ~client:true ~meth:txt ~loc p in *)
(* service :: acc *)
(* | Pstr_module ({pmb_expr; _} as m) -> *)
(* {str with pstr_desc = Pstr_module {m with pmb_expr = pmod_impl pmb_expr}} :: acc *)
(* | _ -> str :: acc *)
(* ) [] str *)


let transform ?kind () =
object(self)
inherit Ast_traverse.map as super
Expand Down
157 changes: 142 additions & 15 deletions src/ppx/ppx_deriving_err_case.ml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ let row ?kind_label ~title prf =
match prf.prf_desc with
| Rtag ({txt; loc}, _, []) -> txt, mk ~loc ?kind_label ~title txt code
| Rtag ({txt; loc}, _, (h :: _)) ->
let enc = Encoding.core ~wrap:false h in
let enc = Encoding.core h in
txt, mk ~loc ~enc ?kind_label ~title txt code
| _ ->
Location.raise_errorf ~loc "inherit not handled"
Expand All @@ -63,29 +63,156 @@ let expressions ?kind_label ~title t =
let loc = t.ptype_loc in
match t.ptype_kind, t.ptype_manifest with
| Ptype_abstract, Some {ptyp_desc=Ptyp_variant (l, _, _); _} ->
List.map (row ?kind_label ~title) l
| _ -> Location.raise_errorf ~loc "error cases only from variants"
`variant (List.map (row ?kind_label ~title) l)
| Ptype_open, None -> `type_ext t.ptype_name.txt
| _ -> Location.raise_errorf ~loc "error cases only from variants and type extension"

let str_gen ~loc ~path:_ (rec_flag, l) debug title kind_label =
let str_gen ~loc ~path:_ (_rec_flag, l) debug title kind_label =
let l = List.map (fun t ->
let loc = t.ptype_loc in
let cases = expressions ?kind_label ~title t in
List.map (fun (name, expr) ->
let pat = ppat_constraint ~loc (pvar ~loc (String.lowercase_ascii name ^ "_case"))
[%type: [%t ptyp_constr ~loc (Utils.llid ~loc t.ptype_name.txt) []] EzAPI.Err.case] in
value_binding ~loc ~pat ~expr) cases) l in
let r = expressions ?kind_label ~title t in
match r with
| `variant cases ->
List.map (fun (name, expr) ->
let pat = ppat_constraint ~loc (pvar ~loc (String.lowercase_ascii name ^ "_case"))
[%type: [%t ptyp_constr ~loc (Utils.llid ~loc t.ptype_name.txt) []] EzAPI.Err.case] in
value_binding ~loc ~pat ~expr) cases
| `type_ext name ->
let t = ptyp_constr ~loc (Utils.llid ~loc name) [] in
let pat = [%pat? ([%p pvar ~loc ("_error_selects_" ^ name)] : (int * ([%t t] -> [%t t] option)) list ref)] in
let selects = value_binding ~loc ~pat ~expr:[%expr ref []] in
let pat = [%pat? ([%p pvar ~loc ("_error_cases_" ^ name)] : (int * [%t t] Json_encoding.case) list ref)] in
let cases = value_binding ~loc ~pat ~expr:[%expr ref []] in
[ selects; cases ]
) l in
let l = List.flatten l in
let rec_flag = if List.length l < 2 then Nonrecursive else rec_flag in
let s = [ pstr_value ~loc rec_flag l ] in
if debug then Format.printf "%s@." (Pprintast.string_of_structure s);
let s = [ pstr_value ~loc Nonrecursive l ] in
if debug then Format.printf "%a@." Pprintast.structure s;
s

let attribute_code ~code attrs =
let c = List.find_map (fun a -> match a.attr_name.txt, a.attr_payload with
| "code", PStr [ { pstr_desc = Pstr_eval ({ pexp_desc = Pexp_constant Pconst_integer (s, _); _ }, _); _ } ] ->
Some (int_of_string s)
| _ -> None) attrs in
match c, code with Some c, _ | _, Some c -> c | _ -> 500

let str_type_ext ~loc:_ ~path:_ t debug code =
let loc = t.ptyext_loc in
let name = Longident.name t.ptyext_path.txt in
let l = List.filter_map (fun pext ->
let loc = pext.pext_loc in
match pext.pext_kind with
| Pext_decl ([], args, None) ->
let code = attribute_code ~code pext.pext_attributes in
let case = Encoding.resolve_case ~loc @@ Encoding.constructor_label ~wrap:true ~case:`snake
~loc ~name:pext.pext_name.txt ~attrs:pext.pext_attributes args in
let select = pext.pext_name.txt, (match args with Pcstr_tuple [] -> false | _ -> true) in
Some (code, case, select)
| _ -> None
) t.ptyext_constructors in
let cases = elist ~loc @@ List.map (fun (code, case, _) ->
[%expr [%e eint ~loc code], [%e case]]) l in
let select_grouped = List.fold_left (fun acc (code, _, select) ->
match List.assoc_opt code acc with
| None -> acc @ [code, [ select ]]
| Some l -> (List.remove_assoc code acc) @ [ code, l @ [ select ] ]
) [] l in
let select_merged cons = pexp_function ~loc (
(List.map (fun (name, has_arg) ->
case ~guard:None
~lhs:(ppat_alias ~loc (ppat_construct ~loc (Utils.llid ~loc name) (if has_arg then Some [%pat? _] else None)) {txt="x"; loc})
~rhs:[%expr Some x]) cons) @ [
case ~guard:None ~lhs:[%pat? _] ~rhs:[%expr None]
]) in
let selects = elist ~loc @@ List.map (fun (code, cons) ->
[%expr [%e eint ~loc code], [%e select_merged cons] ]) select_grouped in
let cases_name = "_error_cases_" ^ name in
let selects_name = "_error_selects_" ^ name in
let expr = [%expr
[%e evar ~loc cases_name] := ![%e evar ~loc cases_name] @ [%e cases];
[%e evar ~loc selects_name] := ![%e evar ~loc selects_name] @ [%e selects];
] in
let s = [
pstr_value ~loc Nonrecursive [ value_binding ~loc ~pat:[%pat? ()] ~expr ]
] in
if debug then Format.printf "%a@." Pprintast.structure s;
s

let remove_spaces s =
let b = Bytes.create (String.length s) in
let n = String.fold_left (fun i -> function ' ' -> i | c -> Bytes.set b i c; i+1) 0 s in
Bytes.(to_string @@ sub b 0 n)

let type_ext_err_case ~loc ~typ ?(def=true)code =
match EzAPI.Error_codes.error code with
| None -> Location.raise_errorf ~loc "code is not standard"
| Some name ->
let enc = [%expr
Json_encoding.union @@ List.filter_map (fun (code, case) ->
if code = [%e eint ~loc code] then Some case else None) ![%e evar ~loc ("_error_cases_" ^ typ)]
] in
let enc =
if not def then enc
else [%expr Json_encoding.def [%e estring ~loc (remove_spaces name)] [%e enc]] in
[%expr
let select = EzAPI.Err.merge_selects @@ List.filter_map (fun (code, case) ->
if code = [%e eint ~loc code] then Some case else None) ![%e evar ~loc ("_error_selects_" ^ typ)] in
EzAPI.Err.make ~code:[%e eint ~loc code] ~name:[%e estring ~loc name]
~encoding:[%e enc] ~select ~deselect:Fun.id ]

let remove_poly c = match c.ptyp_desc with Ptyp_poly (_, c) -> c | _ -> c

let get_err_case_options ~loc l =
let code, debug, def = List.fold_left (fun (code, debug, def) (lid, e) -> match Longident.name lid.txt, e.pexp_desc with
| "code", Pexp_constant Pconst_integer (s, _) -> Some (int_of_string s), debug, def
| "debug", _ -> code, true, def
| "nodef", _ -> code, debug, false
| "def", Pexp_construct ({txt=Lident "false"; _}, None) -> code, debug, false
| s, _ -> Format.eprintf "%s option not understood@." s; code, debug, def
) (None, false, true) l in
match code with
| None -> Location.raise_errorf ~loc "code not found"
| Some code -> code, debug, def

let transform =
object
inherit Ast_traverse.map
method! structure_item it = match it.pstr_desc with
| Pstr_extension (({txt="err_case"; _}, PStr [{pstr_desc=Pstr_value (_, [ vb ]); pstr_loc=loc; _}]), _) ->
let typ, e, pat = match vb.pvb_expr.pexp_desc, vb.pvb_pat.ppat_desc with
| Pexp_constraint (e, typ), (Ppat_constraint ({ppat_desc=p; _}, _) | p) ->
remove_poly typ, e, { vb.pvb_pat with ppat_desc=p }
| _, Ppat_constraint (p, typ) ->
remove_poly typ, vb.pvb_expr, p
| _ -> Location.raise_errorf ~loc "no error type given to derive the error case" in
let code, debug, def = match e.pexp_desc with
| Pexp_constant Pconst_integer (s, _) -> int_of_string s, false, true
| Pexp_record (l, None) -> get_err_case_options ~loc:e.pexp_loc l
| _ -> Location.raise_errorf ~loc:e.pexp_loc "code not found" in
let typ = match typ.ptyp_desc with
| Ptyp_constr ({txt; _}, [])
| Ptyp_constr ({txt=(Ldot (Ldot (Lident "EzAPI", "Err"), "case") | Ldot (Lident "Err", "case")) ; _}, [
{ ptyp_desc = Ptyp_constr ({txt; _}, []); _ }
]) -> Longident.name txt
| _ -> Location.raise_errorf ~loc:typ.ptyp_loc "couldn't find type to derive error case" in
let expr = type_ext_err_case ~loc ~typ ~def code in
let it = pstr_value ~loc Nonrecursive [ value_binding ~loc ~pat ~expr ] in
if debug then Format.printf "%a@." Pprintast.structure_item it;
it
| _ -> it
end

let () =
let args_str = Deriving.Args.(
let open Deriving in
let args_str = Args.(
empty
+> flag "debug"
+> flag "title"
+> arg "kind_label" (estring __)
) in
let str_type_decl = Deriving.Generator.make args_str str_gen in
Deriving.ignore @@ Deriving.add "err_case" ~str_type_decl
let str_type_decl = Generator.make args_str str_gen in
let args_type_ext = Args.(empty +> flag "debug" +> arg "code" (eint __)) in
let str_type_ext = Generator.make args_type_ext str_type_ext in
ignore @@ add "err_case" ~str_type_decl ~str_type_ext;
Driver.register_transformation "err_case" ~impl:transform#structure
5 changes: 5 additions & 0 deletions test/ppx/dune
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,8 @@
(modules test_ppx_client)
(libraries test_ppx_lib ez_api.ixhr_lwt)
(modes js))

(library
(name test_ppx_err_case)
(modules test_ppx_err_case)
(preprocess (pps ez_api.ppx ez_api.ppx_err_case)))
10 changes: 10 additions & 0 deletions test/ppx/test_ppx_err_case.ml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
type error = .. [@@deriving err_case]

type error +=
| NotFound of string [@code 404]
| Unauthorized of string [@code 401]
| Error1
| Error2
[@@deriving err_case {code=400}]

let%err_case generic_case : error EzAPI.Err.case = 400

0 comments on commit fb18ff0

Please sign in to comment.