From fb18ff03f834c8deeaa3efea104de08f4216cef5 Mon Sep 17 00:00:00 2001 From: Maxime Levillain Date: Fri, 4 Oct 2024 18:50:22 +0200 Subject: [PATCH] ppx_deriving_err_case for type extension --- src/common/err.ml | 8 ++ src/ppx/dune | 5 +- src/ppx/ppx_common.ml | 106 --------------------- src/ppx/ppx_deriving_err_case.ml | 157 ++++++++++++++++++++++++++++--- test/ppx/dune | 5 + test/ppx/test_ppx_err_case.ml | 10 ++ 6 files changed, 168 insertions(+), 123 deletions(-) create mode 100644 test/ppx/test_ppx_err_case.ml diff --git a/src/common/err.ml b/src/common/err.ml index d8738c8..1c6df60 100644 --- a/src/common/err.ml +++ b/src/common/err.ml @@ -88,3 +88,11 @@ let get ~code l = Json_encoding.case enc select deselect ) l in Some (Json_encoding.union cases) + +let merge_selects l e = + let rec aux = function + | [] -> None + | f :: tl -> match f e with + | Some e -> Some e + | None -> aux tl in + aux l diff --git a/src/ppx/dune b/src/ppx/dune index d2601e7..9c5a5d7 100644 --- a/src/ppx/dune +++ b/src/ppx/dune @@ -39,5 +39,6 @@ (optional) (modules ppx_deriving_err_case) (preprocess (pps ppxlib.metaquot)) - (kind ppx_deriver) - (libraries ppx_deriving_encoding.lib)) + (kind ppx_rewriter) + (ppx_runtime_libraries ez_api) + (libraries ppx_deriving_encoding.lib ez_api)) diff --git a/src/ppx/ppx_common.ml b/src/ppx/ppx_common.ml index c9e94f7..5f1dbd5 100644 --- a/src/ppx/ppx_common.ml +++ b/src/ppx/ppx_common.ml @@ -353,112 +353,6 @@ let deprecate = Format.eprintf "deprecated: [@@@@@@%s ...] -> [%%%%%s ...]@." s s | Some () -> () -(* let rec impl ?kind str = *) -(* let rec pmod_impl pmod = match pmod.pmod_desc with *) -(* | Pmod_structure str -> {pmod with pmod_desc = Pmod_structure (impl ?kind str)} *) -(* | Pmod_functor (f, m) -> {pmod with pmod_desc = Pmod_functor (f, pmod_impl m)} *) -(* | Pmod_apply (m1, m2) -> {pmod with pmod_desc = Pmod_apply (pmod_impl m1, pmod_impl m2)} *) -(* | Pmod_constraint (m, mt) -> {pmod with pmod_desc = Pmod_constraint (pmod_impl m, mt)} *) -(* | _ -> pmod in *) -(* List.rev @@ *) -(* List.fold_left (fun acc str -> *) -(* match str.pstr_desc with *) -(* | Pstr_value (rflag, [ v ]) when kind <> Some `client -> *) -(* begin match List.partition (fun a -> List.mem a.attr_name.txt methods) v.pvb_attributes with *) -(* (\* service for handler *\) *) -(* | [ a ], pvb_attributes -> *) -(* begin match v.pvb_pat.ppat_desc with *) -(* | Ppat_var {txt=name;_} -> *) -(* let pvb_expr = handler_args v.pvb_expr in *) -(* let str = {str with pstr_desc = Pstr_value (rflag, [ {v with pvb_expr; pvb_attributes }])} in *) -(* (List.rev @@ process name a) @ str :: acc *) -(* | _ -> *) -(* str :: acc *) -(* end *) -(* (\* link service *\) *) -(* | [], attributes -> *) -(* begin match List.partition (fun a -> a.attr_name.txt = "service") attributes with *) -(* | [ a ], pvb_attributes -> *) -(* begin match v.pvb_pat.ppat_desc with *) -(* | Ppat_var {txt=name;_} -> *) -(* let pvb_expr = handler_args v.pvb_expr in *) -(* let str = {str with pstr_desc = Pstr_value (rflag, [ {v with pvb_expr; pvb_attributes }])} in *) -(* (List.rev @@ register name a) @ str :: acc *) -(* | _ -> str :: acc *) -(* end *) -(* | _ -> str :: acc *) -(* end *) -(* | _ -> str :: acc *) -(* end *) -(* | Pstr_value (rflag, (v_react :: v_bg :: onclose)) when kind <> Some `client -> *) -(* let attributes = match onclose with *) -(* | [] -> v_bg.pvb_attributes *) -(* | v :: _ -> v.pvb_attributes in *) -(* begin match List.partition (fun a -> a.attr_name.txt = "ws" || a.attr_name.txt = "websocket") attributes with *) -(* (\* service for websocket handlers *\) *) -(* | [ a ], pvb_attributes -> *) -(* begin match v_react.pvb_pat.ppat_desc, v_bg.pvb_pat.ppat_desc with *) -(* | Ppat_var {txt=name_react;_}, Ppat_var {txt=name_bg;_} -> *) -(* let pvb_expr_react = handler_args v_react.pvb_expr in *) -(* let pvb_expr_bg = handler_args v_bg.pvb_expr in *) -(* let pvb_attributes, vs = match onclose with *) -(* | [] -> pvb_attributes, [] *) -(* | v :: t -> v_bg.pvb_attributes, {v with pvb_attributes} :: t in *) -(* let str = {str with pstr_desc = Pstr_value (rflag, ( *) -(* {v_react with pvb_expr = pvb_expr_react } :: *) -(* {v_bg with pvb_expr = pvb_expr_bg; pvb_attributes } :: *) -(* vs )) } in *) -(* (List.rev @@ process_ws ~onclose name_react name_bg a) @ str :: acc *) -(* | _ -> str :: acc *) -(* end *) -(* (\* link websocket service *\) *) -(* | [], attributes -> *) -(* begin match List.partition (fun a -> a.attr_name.txt = "service") attributes with *) -(* | [ a ], pvb_attributes -> *) -(* begin match v_react.pvb_pat.ppat_desc, v_bg.pvb_pat.ppat_desc with *) -(* | Ppat_var {txt=name_react;_}, Ppat_var {txt=name_bg;_} -> *) -(* let pvb_expr_react = handler_args v_react.pvb_expr in *) -(* let pvb_expr_bg = handler_args v_bg.pvb_expr in *) -(* let pvb_attributes, vs = match onclose with *) -(* | [] -> pvb_attributes, [] *) -(* | v :: t -> v_bg.pvb_attributes, {v with pvb_attributes} :: t in *) -(* let str = {str with pstr_desc = Pstr_value (rflag, ( *) -(* {v_react with pvb_expr = pvb_expr_react } :: *) -(* {v_bg with pvb_expr = pvb_expr_bg; pvb_attributes } :: *) -(* vs )) } in *) -(* (List.rev @@ register_ws ~onclose name_react name_bg a) @ str :: acc *) -(* | _ -> str :: acc *) -(* end *) -(* | _ -> str :: acc *) -(* end *) -(* | _ -> str :: acc *) -(* end *) -(* (\* server main *\) *) -(* | Pstr_attribute a when a.attr_name.txt = "server" && kind = Some `server -> *) -(* deprecate "server"; *) -(* let loc = a.attr_loc in *) -(* let expr = server ~loc a.attr_payload in *) -(* pstr_value ~loc Nonrecursive [ value_binding ~loc ~pat:(punit ~loc) ~expr ] :: acc *) -(* | Pstr_extension (({txt="server"; loc}, p), _) when kind = Some `server -> *) -(* let expr = server ~loc p in *) -(* pstr_value ~loc Nonrecursive [ value_binding ~loc ~pat:(punit ~loc) ~expr ] :: acc *) -(* (\* client service *\) *) -(* | Pstr_attribute a when List.mem a.attr_name.txt methods -> *) -(* deprecate a.attr_name.txt; *) -(* let service, _, _ = service_value ~client:true ~meth:a.attr_name.txt ~loc:a.attr_loc a.attr_payload in *) -(* service :: acc *) -(* | Pstr_extension (({txt; loc}, PStr [ { pstr_desc = Pstr_value (_, [ { pvb_expr; pvb_pat= {ppat_desc=Ppat_var {txt=name; _}; _}; _} ]); _} ]), _) when List.mem txt methods -> *) -(* let service, _, _ = service_value ~name ~client:true ~meth:txt ~loc @@ PStr [ pstr_eval ~loc pvb_expr [] ] in *) -(* service :: acc *) -(* | Pstr_extension (({txt; loc}, p), _) when List.mem txt methods -> *) -(* let service, _, _ = service_value ~client:true ~meth:txt ~loc p in *) -(* service :: acc *) -(* | Pstr_module ({pmb_expr; _} as m) -> *) -(* {str with pstr_desc = Pstr_module {m with pmb_expr = pmod_impl pmb_expr}} :: acc *) -(* | _ -> str :: acc *) -(* ) [] str *) - - let transform ?kind () = object(self) inherit Ast_traverse.map as super diff --git a/src/ppx/ppx_deriving_err_case.ml b/src/ppx/ppx_deriving_err_case.ml index 5d4c671..7542412 100644 --- a/src/ppx/ppx_deriving_err_case.ml +++ b/src/ppx/ppx_deriving_err_case.ml @@ -54,7 +54,7 @@ let row ?kind_label ~title prf = match prf.prf_desc with | Rtag ({txt; loc}, _, []) -> txt, mk ~loc ?kind_label ~title txt code | Rtag ({txt; loc}, _, (h :: _)) -> - let enc = Encoding.core ~wrap:false h in + let enc = Encoding.core h in txt, mk ~loc ~enc ?kind_label ~title txt code | _ -> Location.raise_errorf ~loc "inherit not handled" @@ -63,29 +63,156 @@ let expressions ?kind_label ~title t = let loc = t.ptype_loc in match t.ptype_kind, t.ptype_manifest with | Ptype_abstract, Some {ptyp_desc=Ptyp_variant (l, _, _); _} -> - List.map (row ?kind_label ~title) l - | _ -> Location.raise_errorf ~loc "error cases only from variants" + `variant (List.map (row ?kind_label ~title) l) + | Ptype_open, None -> `type_ext t.ptype_name.txt + | _ -> Location.raise_errorf ~loc "error cases only from variants and type extension" -let str_gen ~loc ~path:_ (rec_flag, l) debug title kind_label = +let str_gen ~loc ~path:_ (_rec_flag, l) debug title kind_label = let l = List.map (fun t -> let loc = t.ptype_loc in - let cases = expressions ?kind_label ~title t in - List.map (fun (name, expr) -> - let pat = ppat_constraint ~loc (pvar ~loc (String.lowercase_ascii name ^ "_case")) - [%type: [%t ptyp_constr ~loc (Utils.llid ~loc t.ptype_name.txt) []] EzAPI.Err.case] in - value_binding ~loc ~pat ~expr) cases) l in + let r = expressions ?kind_label ~title t in + match r with + | `variant cases -> + List.map (fun (name, expr) -> + let pat = ppat_constraint ~loc (pvar ~loc (String.lowercase_ascii name ^ "_case")) + [%type: [%t ptyp_constr ~loc (Utils.llid ~loc t.ptype_name.txt) []] EzAPI.Err.case] in + value_binding ~loc ~pat ~expr) cases + | `type_ext name -> + let t = ptyp_constr ~loc (Utils.llid ~loc name) [] in + let pat = [%pat? ([%p pvar ~loc ("_error_selects_" ^ name)] : (int * ([%t t] -> [%t t] option)) list ref)] in + let selects = value_binding ~loc ~pat ~expr:[%expr ref []] in + let pat = [%pat? ([%p pvar ~loc ("_error_cases_" ^ name)] : (int * [%t t] Json_encoding.case) list ref)] in + let cases = value_binding ~loc ~pat ~expr:[%expr ref []] in + [ selects; cases ] + ) l in let l = List.flatten l in - let rec_flag = if List.length l < 2 then Nonrecursive else rec_flag in - let s = [ pstr_value ~loc rec_flag l ] in - if debug then Format.printf "%s@." (Pprintast.string_of_structure s); + let s = [ pstr_value ~loc Nonrecursive l ] in + if debug then Format.printf "%a@." Pprintast.structure s; s +let attribute_code ~code attrs = + let c = List.find_map (fun a -> match a.attr_name.txt, a.attr_payload with + | "code", PStr [ { pstr_desc = Pstr_eval ({ pexp_desc = Pexp_constant Pconst_integer (s, _); _ }, _); _ } ] -> + Some (int_of_string s) + | _ -> None) attrs in + match c, code with Some c, _ | _, Some c -> c | _ -> 500 + +let str_type_ext ~loc:_ ~path:_ t debug code = + let loc = t.ptyext_loc in + let name = Longident.name t.ptyext_path.txt in + let l = List.filter_map (fun pext -> + let loc = pext.pext_loc in + match pext.pext_kind with + | Pext_decl ([], args, None) -> + let code = attribute_code ~code pext.pext_attributes in + let case = Encoding.resolve_case ~loc @@ Encoding.constructor_label ~wrap:true ~case:`snake + ~loc ~name:pext.pext_name.txt ~attrs:pext.pext_attributes args in + let select = pext.pext_name.txt, (match args with Pcstr_tuple [] -> false | _ -> true) in + Some (code, case, select) + | _ -> None + ) t.ptyext_constructors in + let cases = elist ~loc @@ List.map (fun (code, case, _) -> + [%expr [%e eint ~loc code], [%e case]]) l in + let select_grouped = List.fold_left (fun acc (code, _, select) -> + match List.assoc_opt code acc with + | None -> acc @ [code, [ select ]] + | Some l -> (List.remove_assoc code acc) @ [ code, l @ [ select ] ] + ) [] l in + let select_merged cons = pexp_function ~loc ( + (List.map (fun (name, has_arg) -> + case ~guard:None + ~lhs:(ppat_alias ~loc (ppat_construct ~loc (Utils.llid ~loc name) (if has_arg then Some [%pat? _] else None)) {txt="x"; loc}) + ~rhs:[%expr Some x]) cons) @ [ + case ~guard:None ~lhs:[%pat? _] ~rhs:[%expr None] + ]) in + let selects = elist ~loc @@ List.map (fun (code, cons) -> + [%expr [%e eint ~loc code], [%e select_merged cons] ]) select_grouped in + let cases_name = "_error_cases_" ^ name in + let selects_name = "_error_selects_" ^ name in + let expr = [%expr + [%e evar ~loc cases_name] := ![%e evar ~loc cases_name] @ [%e cases]; + [%e evar ~loc selects_name] := ![%e evar ~loc selects_name] @ [%e selects]; + ] in + let s = [ + pstr_value ~loc Nonrecursive [ value_binding ~loc ~pat:[%pat? ()] ~expr ] + ] in + if debug then Format.printf "%a@." Pprintast.structure s; + s + +let remove_spaces s = + let b = Bytes.create (String.length s) in + let n = String.fold_left (fun i -> function ' ' -> i | c -> Bytes.set b i c; i+1) 0 s in + Bytes.(to_string @@ sub b 0 n) + +let type_ext_err_case ~loc ~typ ?(def=true)code = + match EzAPI.Error_codes.error code with + | None -> Location.raise_errorf ~loc "code is not standard" + | Some name -> + let enc = [%expr + Json_encoding.union @@ List.filter_map (fun (code, case) -> + if code = [%e eint ~loc code] then Some case else None) ![%e evar ~loc ("_error_cases_" ^ typ)] + ] in + let enc = + if not def then enc + else [%expr Json_encoding.def [%e estring ~loc (remove_spaces name)] [%e enc]] in + [%expr + let select = EzAPI.Err.merge_selects @@ List.filter_map (fun (code, case) -> + if code = [%e eint ~loc code] then Some case else None) ![%e evar ~loc ("_error_selects_" ^ typ)] in + EzAPI.Err.make ~code:[%e eint ~loc code] ~name:[%e estring ~loc name] + ~encoding:[%e enc] ~select ~deselect:Fun.id ] + +let remove_poly c = match c.ptyp_desc with Ptyp_poly (_, c) -> c | _ -> c + +let get_err_case_options ~loc l = + let code, debug, def = List.fold_left (fun (code, debug, def) (lid, e) -> match Longident.name lid.txt, e.pexp_desc with + | "code", Pexp_constant Pconst_integer (s, _) -> Some (int_of_string s), debug, def + | "debug", _ -> code, true, def + | "nodef", _ -> code, debug, false + | "def", Pexp_construct ({txt=Lident "false"; _}, None) -> code, debug, false + | s, _ -> Format.eprintf "%s option not understood@." s; code, debug, def + ) (None, false, true) l in + match code with + | None -> Location.raise_errorf ~loc "code not found" + | Some code -> code, debug, def + +let transform = + object + inherit Ast_traverse.map + method! structure_item it = match it.pstr_desc with + | Pstr_extension (({txt="err_case"; _}, PStr [{pstr_desc=Pstr_value (_, [ vb ]); pstr_loc=loc; _}]), _) -> + let typ, e, pat = match vb.pvb_expr.pexp_desc, vb.pvb_pat.ppat_desc with + | Pexp_constraint (e, typ), (Ppat_constraint ({ppat_desc=p; _}, _) | p) -> + remove_poly typ, e, { vb.pvb_pat with ppat_desc=p } + | _, Ppat_constraint (p, typ) -> + remove_poly typ, vb.pvb_expr, p + | _ -> Location.raise_errorf ~loc "no error type given to derive the error case" in + let code, debug, def = match e.pexp_desc with + | Pexp_constant Pconst_integer (s, _) -> int_of_string s, false, true + | Pexp_record (l, None) -> get_err_case_options ~loc:e.pexp_loc l + | _ -> Location.raise_errorf ~loc:e.pexp_loc "code not found" in + let typ = match typ.ptyp_desc with + | Ptyp_constr ({txt; _}, []) + | Ptyp_constr ({txt=(Ldot (Ldot (Lident "EzAPI", "Err"), "case") | Ldot (Lident "Err", "case")) ; _}, [ + { ptyp_desc = Ptyp_constr ({txt; _}, []); _ } + ]) -> Longident.name txt + | _ -> Location.raise_errorf ~loc:typ.ptyp_loc "couldn't find type to derive error case" in + let expr = type_ext_err_case ~loc ~typ ~def code in + let it = pstr_value ~loc Nonrecursive [ value_binding ~loc ~pat ~expr ] in + if debug then Format.printf "%a@." Pprintast.structure_item it; + it + | _ -> it + end + let () = - let args_str = Deriving.Args.( + let open Deriving in + let args_str = Args.( empty +> flag "debug" +> flag "title" +> arg "kind_label" (estring __) ) in - let str_type_decl = Deriving.Generator.make args_str str_gen in - Deriving.ignore @@ Deriving.add "err_case" ~str_type_decl + let str_type_decl = Generator.make args_str str_gen in + let args_type_ext = Args.(empty +> flag "debug" +> arg "code" (eint __)) in + let str_type_ext = Generator.make args_type_ext str_type_ext in + ignore @@ add "err_case" ~str_type_decl ~str_type_ext; + Driver.register_transformation "err_case" ~impl:transform#structure diff --git a/test/ppx/dune b/test/ppx/dune index ac7199d..f536fa1 100644 --- a/test/ppx/dune +++ b/test/ppx/dune @@ -14,3 +14,8 @@ (modules test_ppx_client) (libraries test_ppx_lib ez_api.ixhr_lwt) (modes js)) + +(library + (name test_ppx_err_case) + (modules test_ppx_err_case) + (preprocess (pps ez_api.ppx ez_api.ppx_err_case))) diff --git a/test/ppx/test_ppx_err_case.ml b/test/ppx/test_ppx_err_case.ml new file mode 100644 index 0000000..c78e7bc --- /dev/null +++ b/test/ppx/test_ppx_err_case.ml @@ -0,0 +1,10 @@ +type error = .. [@@deriving err_case] + +type error += + | NotFound of string [@code 404] + | Unauthorized of string [@code 401] + | Error1 + | Error2 +[@@deriving err_case {code=400}] + +let%err_case generic_case : error EzAPI.Err.case = 400