Skip to content

Commit a4304c9

Browse files
committed
ppx_deriving_err_case for type extension
1 parent da5f2b1 commit a4304c9

File tree

6 files changed

+168
-123
lines changed

6 files changed

+168
-123
lines changed

src/common/err.ml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,3 +88,11 @@ let get ~code l =
8888
Json_encoding.case enc select deselect
8989
) l in
9090
Some (Json_encoding.union cases)
91+
92+
let merge_selects l e =
93+
let rec aux = function
94+
| [] -> None
95+
| f :: tl -> match f e with
96+
| Some e -> Some e
97+
| None -> aux tl in
98+
aux l

src/ppx/dune

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,5 +39,6 @@
3939
(optional)
4040
(modules ppx_deriving_err_case)
4141
(preprocess (pps ppxlib.metaquot))
42-
(kind ppx_deriver)
43-
(libraries ppx_deriving_encoding.lib))
42+
(kind ppx_rewriter)
43+
(ppx_runtime_libraries ez_api)
44+
(libraries ppx_deriving_encoding.lib ez_api))

src/ppx/ppx_common.ml

Lines changed: 0 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -353,112 +353,6 @@ let deprecate =
353353
Format.eprintf "deprecated: [@@@@@@%s ...] -> [%%%%%s ...]@." s s
354354
| Some () -> ()
355355

356-
(* let rec impl ?kind str = *)
357-
(* let rec pmod_impl pmod = match pmod.pmod_desc with *)
358-
(* | Pmod_structure str -> {pmod with pmod_desc = Pmod_structure (impl ?kind str)} *)
359-
(* | Pmod_functor (f, m) -> {pmod with pmod_desc = Pmod_functor (f, pmod_impl m)} *)
360-
(* | Pmod_apply (m1, m2) -> {pmod with pmod_desc = Pmod_apply (pmod_impl m1, pmod_impl m2)} *)
361-
(* | Pmod_constraint (m, mt) -> {pmod with pmod_desc = Pmod_constraint (pmod_impl m, mt)} *)
362-
(* | _ -> pmod in *)
363-
(* List.rev @@ *)
364-
(* List.fold_left (fun acc str -> *)
365-
(* match str.pstr_desc with *)
366-
(* | Pstr_value (rflag, [ v ]) when kind <> Some `client -> *)
367-
(* begin match List.partition (fun a -> List.mem a.attr_name.txt methods) v.pvb_attributes with *)
368-
(* (\* service for handler *\) *)
369-
(* | [ a ], pvb_attributes -> *)
370-
(* begin match v.pvb_pat.ppat_desc with *)
371-
(* | Ppat_var {txt=name;_} -> *)
372-
(* let pvb_expr = handler_args v.pvb_expr in *)
373-
(* let str = {str with pstr_desc = Pstr_value (rflag, [ {v with pvb_expr; pvb_attributes }])} in *)
374-
(* (List.rev @@ process name a) @ str :: acc *)
375-
(* | _ -> *)
376-
(* str :: acc *)
377-
(* end *)
378-
(* (\* link service *\) *)
379-
(* | [], attributes -> *)
380-
(* begin match List.partition (fun a -> a.attr_name.txt = "service") attributes with *)
381-
(* | [ a ], pvb_attributes -> *)
382-
(* begin match v.pvb_pat.ppat_desc with *)
383-
(* | Ppat_var {txt=name;_} -> *)
384-
(* let pvb_expr = handler_args v.pvb_expr in *)
385-
(* let str = {str with pstr_desc = Pstr_value (rflag, [ {v with pvb_expr; pvb_attributes }])} in *)
386-
(* (List.rev @@ register name a) @ str :: acc *)
387-
(* | _ -> str :: acc *)
388-
(* end *)
389-
(* | _ -> str :: acc *)
390-
(* end *)
391-
(* | _ -> str :: acc *)
392-
(* end *)
393-
(* | Pstr_value (rflag, (v_react :: v_bg :: onclose)) when kind <> Some `client -> *)
394-
(* let attributes = match onclose with *)
395-
(* | [] -> v_bg.pvb_attributes *)
396-
(* | v :: _ -> v.pvb_attributes in *)
397-
(* begin match List.partition (fun a -> a.attr_name.txt = "ws" || a.attr_name.txt = "websocket") attributes with *)
398-
(* (\* service for websocket handlers *\) *)
399-
(* | [ a ], pvb_attributes -> *)
400-
(* begin match v_react.pvb_pat.ppat_desc, v_bg.pvb_pat.ppat_desc with *)
401-
(* | Ppat_var {txt=name_react;_}, Ppat_var {txt=name_bg;_} -> *)
402-
(* let pvb_expr_react = handler_args v_react.pvb_expr in *)
403-
(* let pvb_expr_bg = handler_args v_bg.pvb_expr in *)
404-
(* let pvb_attributes, vs = match onclose with *)
405-
(* | [] -> pvb_attributes, [] *)
406-
(* | v :: t -> v_bg.pvb_attributes, {v with pvb_attributes} :: t in *)
407-
(* let str = {str with pstr_desc = Pstr_value (rflag, ( *)
408-
(* {v_react with pvb_expr = pvb_expr_react } :: *)
409-
(* {v_bg with pvb_expr = pvb_expr_bg; pvb_attributes } :: *)
410-
(* vs )) } in *)
411-
(* (List.rev @@ process_ws ~onclose name_react name_bg a) @ str :: acc *)
412-
(* | _ -> str :: acc *)
413-
(* end *)
414-
(* (\* link websocket service *\) *)
415-
(* | [], attributes -> *)
416-
(* begin match List.partition (fun a -> a.attr_name.txt = "service") attributes with *)
417-
(* | [ a ], pvb_attributes -> *)
418-
(* begin match v_react.pvb_pat.ppat_desc, v_bg.pvb_pat.ppat_desc with *)
419-
(* | Ppat_var {txt=name_react;_}, Ppat_var {txt=name_bg;_} -> *)
420-
(* let pvb_expr_react = handler_args v_react.pvb_expr in *)
421-
(* let pvb_expr_bg = handler_args v_bg.pvb_expr in *)
422-
(* let pvb_attributes, vs = match onclose with *)
423-
(* | [] -> pvb_attributes, [] *)
424-
(* | v :: t -> v_bg.pvb_attributes, {v with pvb_attributes} :: t in *)
425-
(* let str = {str with pstr_desc = Pstr_value (rflag, ( *)
426-
(* {v_react with pvb_expr = pvb_expr_react } :: *)
427-
(* {v_bg with pvb_expr = pvb_expr_bg; pvb_attributes } :: *)
428-
(* vs )) } in *)
429-
(* (List.rev @@ register_ws ~onclose name_react name_bg a) @ str :: acc *)
430-
(* | _ -> str :: acc *)
431-
(* end *)
432-
(* | _ -> str :: acc *)
433-
(* end *)
434-
(* | _ -> str :: acc *)
435-
(* end *)
436-
(* (\* server main *\) *)
437-
(* | Pstr_attribute a when a.attr_name.txt = "server" && kind = Some `server -> *)
438-
(* deprecate "server"; *)
439-
(* let loc = a.attr_loc in *)
440-
(* let expr = server ~loc a.attr_payload in *)
441-
(* pstr_value ~loc Nonrecursive [ value_binding ~loc ~pat:(punit ~loc) ~expr ] :: acc *)
442-
(* | Pstr_extension (({txt="server"; loc}, p), _) when kind = Some `server -> *)
443-
(* let expr = server ~loc p in *)
444-
(* pstr_value ~loc Nonrecursive [ value_binding ~loc ~pat:(punit ~loc) ~expr ] :: acc *)
445-
(* (\* client service *\) *)
446-
(* | Pstr_attribute a when List.mem a.attr_name.txt methods -> *)
447-
(* deprecate a.attr_name.txt; *)
448-
(* let service, _, _ = service_value ~client:true ~meth:a.attr_name.txt ~loc:a.attr_loc a.attr_payload in *)
449-
(* service :: acc *)
450-
(* | Pstr_extension (({txt; loc}, PStr [ { pstr_desc = Pstr_value (_, [ { pvb_expr; pvb_pat= {ppat_desc=Ppat_var {txt=name; _}; _}; _} ]); _} ]), _) when List.mem txt methods -> *)
451-
(* let service, _, _ = service_value ~name ~client:true ~meth:txt ~loc @@ PStr [ pstr_eval ~loc pvb_expr [] ] in *)
452-
(* service :: acc *)
453-
(* | Pstr_extension (({txt; loc}, p), _) when List.mem txt methods -> *)
454-
(* let service, _, _ = service_value ~client:true ~meth:txt ~loc p in *)
455-
(* service :: acc *)
456-
(* | Pstr_module ({pmb_expr; _} as m) -> *)
457-
(* {str with pstr_desc = Pstr_module {m with pmb_expr = pmod_impl pmb_expr}} :: acc *)
458-
(* | _ -> str :: acc *)
459-
(* ) [] str *)
460-
461-
462356
let transform ?kind () =
463357
object(self)
464358
inherit Ast_traverse.map as super

src/ppx/ppx_deriving_err_case.ml

Lines changed: 142 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ let row ?kind_label ~title prf =
5454
match prf.prf_desc with
5555
| Rtag ({txt; loc}, _, []) -> txt, mk ~loc ?kind_label ~title txt code
5656
| Rtag ({txt; loc}, _, (h :: _)) ->
57-
let enc = Encoding.core ~wrap:false h in
57+
let enc = Encoding.core h in
5858
txt, mk ~loc ~enc ?kind_label ~title txt code
5959
| _ ->
6060
Location.raise_errorf ~loc "inherit not handled"
@@ -63,29 +63,156 @@ let expressions ?kind_label ~title t =
6363
let loc = t.ptype_loc in
6464
match t.ptype_kind, t.ptype_manifest with
6565
| Ptype_abstract, Some {ptyp_desc=Ptyp_variant (l, _, _); _} ->
66-
List.map (row ?kind_label ~title) l
67-
| _ -> Location.raise_errorf ~loc "error cases only from variants"
66+
`variant (List.map (row ?kind_label ~title) l)
67+
| Ptype_open, None -> `type_ext t.ptype_name.txt
68+
| _ -> Location.raise_errorf ~loc "error cases only from variants and type extension"
6869

69-
let str_gen ~loc ~path:_ (rec_flag, l) debug title kind_label =
70+
let str_gen ~loc ~path:_ (_rec_flag, l) debug title kind_label =
7071
let l = List.map (fun t ->
7172
let loc = t.ptype_loc in
72-
let cases = expressions ?kind_label ~title t in
73-
List.map (fun (name, expr) ->
74-
let pat = ppat_constraint ~loc (pvar ~loc (String.lowercase_ascii name ^ "_case"))
75-
[%type: [%t ptyp_constr ~loc (Utils.llid ~loc t.ptype_name.txt) []] EzAPI.Err.case] in
76-
value_binding ~loc ~pat ~expr) cases) l in
73+
let r = expressions ?kind_label ~title t in
74+
match r with
75+
| `variant cases ->
76+
List.map (fun (name, expr) ->
77+
let pat = ppat_constraint ~loc (pvar ~loc (String.lowercase_ascii name ^ "_case"))
78+
[%type: [%t ptyp_constr ~loc (Utils.llid ~loc t.ptype_name.txt) []] EzAPI.Err.case] in
79+
value_binding ~loc ~pat ~expr) cases
80+
| `type_ext name ->
81+
let t = ptyp_constr ~loc (Utils.llid ~loc name) [] in
82+
let pat = [%pat? ([%p pvar ~loc ("_error_selects_" ^ name)] : (int * ([%t t] -> [%t t] option)) list ref)] in
83+
let selects = value_binding ~loc ~pat ~expr:[%expr ref []] in
84+
let pat = [%pat? ([%p pvar ~loc ("_error_cases_" ^ name)] : (int * [%t t] Json_encoding.case) list ref)] in
85+
let cases = value_binding ~loc ~pat ~expr:[%expr ref []] in
86+
[ selects; cases ]
87+
) l in
7788
let l = List.flatten l in
78-
let rec_flag = if List.length l < 2 then Nonrecursive else rec_flag in
79-
let s = [ pstr_value ~loc rec_flag l ] in
80-
if debug then Format.printf "%s@." (Pprintast.string_of_structure s);
89+
let s = [ pstr_value ~loc Nonrecursive l ] in
90+
if debug then Format.printf "%a@." Pprintast.structure s;
8191
s
8292

93+
let attribute_code ~code attrs =
94+
let c = List.find_map (fun a -> match a.attr_name.txt, a.attr_payload with
95+
| "code", PStr [ { pstr_desc = Pstr_eval ({ pexp_desc = Pexp_constant Pconst_integer (s, _); _ }, _); _ } ] ->
96+
Some (int_of_string s)
97+
| _ -> None) attrs in
98+
match c, code with Some c, _ | _, Some c -> c | _ -> 500
99+
100+
let str_type_ext ~loc:_ ~path:_ t debug code =
101+
let loc = t.ptyext_loc in
102+
let name = Longident.name t.ptyext_path.txt in
103+
let l = List.filter_map (fun pext ->
104+
let loc = pext.pext_loc in
105+
match pext.pext_kind with
106+
| Pext_decl ([], args, None) ->
107+
let code = attribute_code ~code pext.pext_attributes in
108+
let case = Encoding.resolve_case ~loc @@ Encoding.constructor_label ~wrap:true ~case:`snake
109+
~loc ~name:pext.pext_name.txt ~attrs:pext.pext_attributes args in
110+
let select = pext.pext_name.txt, (match args with Pcstr_tuple [] -> false | _ -> true) in
111+
Some (code, case, select)
112+
| _ -> None
113+
) t.ptyext_constructors in
114+
let cases = elist ~loc @@ List.map (fun (code, case, _) ->
115+
[%expr [%e eint ~loc code], [%e case]]) l in
116+
let select_grouped = List.fold_left (fun acc (code, _, select) ->
117+
match List.assoc_opt code acc with
118+
| None -> acc @ [code, [ select ]]
119+
| Some l -> (List.remove_assoc code acc) @ [ code, l @ [ select ] ]
120+
) [] l in
121+
let select_merged cons = pexp_function ~loc (
122+
(List.map (fun (name, has_arg) ->
123+
case ~guard:None
124+
~lhs:(ppat_alias ~loc (ppat_construct ~loc (Utils.llid ~loc name) (if has_arg then Some [%pat? _] else None)) {txt="x"; loc})
125+
~rhs:[%expr Some x]) cons) @ [
126+
case ~guard:None ~lhs:[%pat? _] ~rhs:[%expr None]
127+
]) in
128+
let selects = elist ~loc @@ List.map (fun (code, cons) ->
129+
[%expr [%e eint ~loc code], [%e select_merged cons] ]) select_grouped in
130+
let cases_name = "_error_cases_" ^ name in
131+
let selects_name = "_error_selects_" ^ name in
132+
let expr = [%expr
133+
[%e evar ~loc cases_name] := ![%e evar ~loc cases_name] @ [%e cases];
134+
[%e evar ~loc selects_name] := ![%e evar ~loc selects_name] @ [%e selects];
135+
] in
136+
let s = [
137+
pstr_value ~loc Nonrecursive [ value_binding ~loc ~pat:[%pat? ()] ~expr ]
138+
] in
139+
if debug then Format.printf "%a@." Pprintast.structure s;
140+
s
141+
142+
let remove_spaces s =
143+
let b = Bytes.create (String.length s) in
144+
let n = String.fold_left (fun i -> function ' ' -> i | c -> Bytes.set b i c; i+1) 0 s in
145+
Bytes.(to_string @@ sub b 0 n)
146+
147+
let type_ext_err_case ~loc ~typ ?(def=true)code =
148+
match EzAPI.Error_codes.error code with
149+
| None -> Location.raise_errorf ~loc "code is not standard"
150+
| Some name ->
151+
let enc = [%expr
152+
Json_encoding.union @@ List.filter_map (fun (code, case) ->
153+
if code = [%e eint ~loc code] then Some case else None) ![%e evar ~loc ("_error_cases_" ^ typ)]
154+
] in
155+
let enc =
156+
if not def then enc
157+
else [%expr Json_encoding.def [%e estring ~loc (remove_spaces name)] [%e enc]] in
158+
[%expr
159+
let select = EzAPI.Err.merge_selects @@ List.filter_map (fun (code, case) ->
160+
if code = [%e eint ~loc code] then Some case else None) ![%e evar ~loc ("_error_selects_" ^ typ)] in
161+
EzAPI.Err.make ~code:[%e eint ~loc code] ~name:[%e estring ~loc name]
162+
~encoding:[%e enc] ~select ~deselect:Fun.id ]
163+
164+
let remove_poly c = match c.ptyp_desc with Ptyp_poly (_, c) -> c | _ -> c
165+
166+
let get_err_case_options ~loc l =
167+
let code, debug, def = List.fold_left (fun (code, debug, def) (lid, e) -> match Longident.name lid.txt, e.pexp_desc with
168+
| "code", Pexp_constant Pconst_integer (s, _) -> Some (int_of_string s), debug, def
169+
| "debug", _ -> code, true, def
170+
| "nodef", _ -> code, debug, false
171+
| "def", Pexp_construct ({txt=Lident "false"; _}, None) -> code, debug, false
172+
| s, _ -> Format.eprintf "%s option not understood@." s; code, debug, def
173+
) (None, false, true) l in
174+
match code with
175+
| None -> Location.raise_errorf ~loc "code not found"
176+
| Some code -> code, debug, def
177+
178+
let transform =
179+
object
180+
inherit Ast_traverse.map
181+
method! structure_item it = match it.pstr_desc with
182+
| Pstr_extension (({txt="err_case"; _}, PStr [{pstr_desc=Pstr_value (_, [ vb ]); pstr_loc=loc; _}]), _) ->
183+
let typ, e, pat = match vb.pvb_expr.pexp_desc, vb.pvb_pat.ppat_desc with
184+
| Pexp_constraint (e, typ), (Ppat_constraint ({ppat_desc=p; _}, _) | p) ->
185+
remove_poly typ, e, { vb.pvb_pat with ppat_desc=p }
186+
| _, Ppat_constraint (p, typ) ->
187+
remove_poly typ, vb.pvb_expr, p
188+
| _ -> Location.raise_errorf ~loc "no error type given to derive the error case" in
189+
let code, debug, def = match e.pexp_desc with
190+
| Pexp_constant Pconst_integer (s, _) -> int_of_string s, false, true
191+
| Pexp_record (l, None) -> get_err_case_options ~loc:e.pexp_loc l
192+
| _ -> Location.raise_errorf ~loc:e.pexp_loc "code not found" in
193+
let typ = match typ.ptyp_desc with
194+
| Ptyp_constr ({txt; _}, [])
195+
| Ptyp_constr ({txt=(Ldot (Ldot (Lident "EzAPI", "Err"), "case") | Ldot (Lident "Err", "case")) ; _}, [
196+
{ ptyp_desc = Ptyp_constr ({txt; _}, []); _ }
197+
]) -> Longident.name txt
198+
| _ -> Location.raise_errorf ~loc:typ.ptyp_loc "couldn't find type to derive error case" in
199+
let expr = type_ext_err_case ~loc ~typ ~def code in
200+
let it = pstr_value ~loc Nonrecursive [ value_binding ~loc ~pat ~expr ] in
201+
if debug then Format.printf "%a@." Pprintast.structure_item it;
202+
it
203+
| _ -> it
204+
end
205+
83206
let () =
84-
let args_str = Deriving.Args.(
207+
let open Deriving in
208+
let args_str = Args.(
85209
empty
86210
+> flag "debug"
87211
+> flag "title"
88212
+> arg "kind_label" (estring __)
89213
) in
90-
let str_type_decl = Deriving.Generator.make args_str str_gen in
91-
Deriving.ignore @@ Deriving.add "err_case" ~str_type_decl
214+
let str_type_decl = Generator.make args_str str_gen in
215+
let args_type_ext = Args.(empty +> flag "debug" +> arg "code" (eint __)) in
216+
let str_type_ext = Generator.make args_type_ext str_type_ext in
217+
ignore @@ add "err_case" ~str_type_decl ~str_type_ext;
218+
Driver.register_transformation "err_case" ~impl:transform#structure

test/ppx/dune

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,8 @@
1414
(modules test_ppx_client)
1515
(libraries test_ppx_lib ez_api.ixhr_lwt)
1616
(modes js))
17+
18+
(library
19+
(name test_ppx_err_case)
20+
(modules test_ppx_err_case)
21+
(preprocess (pps ez_api.ppx ez_api.ppx_err_case)))

test/ppx/test_ppx_err_case.ml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
type error = .. [@@deriving err_case]
2+
3+
type error +=
4+
| NotFound of string [@code 404]
5+
| Unauthorized of string [@code 401]
6+
| Error1
7+
| Error2
8+
[@@deriving err_case {code=400}]
9+
10+
let%err_case generic_case : error EzAPI.Err.case = 400

0 commit comments

Comments
 (0)