Skip to content

Commit

Permalink
better control of cors headers
Browse files Browse the repository at this point in the history
  • Loading branch information
maxtori committed Oct 1, 2024
1 parent d33f04c commit 3f55f22
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 61 deletions.
24 changes: 16 additions & 8 deletions src/server/cohttp/ezAPIServerCohttp.ml
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@ let set_debug () = Cohttp_lwt_unix.Debug.activate_debug ()

let register_ip req io time =
let open Conduit_lwt_unix in match io with
| Domain_socket _
| Vchan _ -> ()
| Domain_socket _ | Vchan _ | Tunnel _ -> ()
| TCP tcp ->
match[@warning "-42"] Lwt_unix.getpeername tcp.fd with
| Lwt_unix.ADDR_INET (ip,_port) ->
Expand Down Expand Up @@ -60,7 +59,8 @@ let debug_cohttp req =
(String.split_on_char ',' v))
(Request.headers req))

let dispatch ?catch s io req body =
let dispatch ?allow_origin ?allow_headers ?allow_methods ?allow_credentials
?catch s io req body =
let time = GMTime.time () in
register_ip req io time ;
debug_cohttp req;
Expand All @@ -80,24 +80,32 @@ let dispatch ?catch s io req body =
>>= function
| `ws (Ok ra) -> Lwt.return ra
| `ws (Error _) ->
let headers = Header.of_list default_access_control_headers in
let headers = Header.of_list @@
merge_headers_with_default ?allow_origin ?allow_headers ?allow_methods
?allow_credentials [] in
let status = Code.status_of_code 501 in
Server.respond_string ~headers ~status ~body:"" () >|= fun (r, b) ->
`Response (r, b)
| `http {Answer.code; body; headers} ->
let headers = merge_headers_with_default headers in
| `http {Answer.code; body; headers=resp_headers} ->
let origin = match allow_origin with
| Some `origin -> StringMap.find_opt "origin" headers
| _ -> None in
let headers = merge_headers_with_default ?allow_origin ?allow_headers
?allow_methods ?allow_credentials ?origin resp_headers in
let status = Code.status_of_code code in
debug ~v:(if code >= 200 && code < 300 then 1 else 0) "Reply computed to %S: %d" path_str code;
debug ~v:3 "Reply content:\n %s" body;
let headers = Header.of_list headers in
Server.respond_string ~headers ~status ~body () >|= fun (r, b) ->
`Response (r, b)

let create_server ?catch server_port server_kind =
let create_server ?catch ?allow_origin ?allow_headers ?allow_methods
?allow_credentials server_port server_kind =
let s = { server_port; server_kind } in
Timings.init (GMTime.time ()) @@ Doc.nservices ();
ignore @@ Doc.all_services_registered ();
let callback conn req body = dispatch ?catch s (fst conn) req body in
let callback conn req body = dispatch ?allow_origin ?allow_headers
?allow_methods ?allow_credentials ?catch s (fst conn) req body in
let on_exn = function
| Unix.Unix_error (Unix.EPIPE, _, _) -> ()
| exn -> EzDebug.printf "Server Error: %s" (Printexc.to_string exn) in
Expand Down
67 changes: 40 additions & 27 deletions src/server/ezAPIServerUtils.ml
Original file line number Diff line number Diff line change
Expand Up @@ -151,30 +151,43 @@ let handle ?meth ?content_type ?ws s r path body =
| Some ws -> ws ?onclose ?step ~react ~bg r.Req.req_id
end >|= fun ra -> `ws ra

(* Default access control headers *)
let default_access_control_headers = [
"access-control-allow-origin", "*";
"access-control-allow-headers", "accept, content-type"
]

(* merge headers correctly with default one *)
let merge_headers_with_default headers : (string * string) list =
(* combining existing headers *)
let l = List.fold_left
(fun acc ((hn,hv) as h) ->
match List.assoc_opt hn default_access_control_headers with
| None -> h::acc
| Some _ when hn = "access-control-allow-origin" ->
h::acc
| Some v when hn = "access-control-allow-headers" ->
(hn, hv ^ "," ^ v)::acc
| _ -> acc)
[]
headers
in
(* Adding default if not present *)
List.fold_left (fun acc ((hn,_) as h) ->
match List.assoc_opt hn l with
| None -> h::acc
| _ -> acc
) l default_access_control_headers
type allow_kind = [ `all | `default | `custom of string list ]
type allow_kind_with_none = [ `all | `default | `custom of string list ]

let merge_headers_allow ~dft ~key headers = function
| `none -> headers
| #allow_kind as k ->
let v old =
match k, old with
| `all, _ -> "*"
(* restrict headers if former ones are * *)
| `default, None | `default, Some "*" -> dft
| `custom l, None | `custom l, Some "*" -> String.concat "," l
| `default, Some old -> old ^ "," ^ dft
| `custom l, Some old -> String.concat "," (old :: l) in
match List.assoc_opt key headers with
| None -> headers @ [ key, v None ]
| Some old -> List.remove_assoc key headers @ [ key, v (Some old) ]

let merge_headers_allow_origin ?origin headers kind =
let key = "access-control-allow-origin" in
match kind with
| `none -> headers
| `origin -> (match origin with None -> headers | Some o -> headers @ [ key, String.concat "," o ])
| `all | `default -> List.remove_assoc key headers @ [ key, "*" ]
| `custom l -> match List.assoc_opt key headers with
| None -> headers @ [ key, String.concat "," l ]
| Some "*" -> (List.remove_assoc key headers) @ [ key, String.concat "," l ]
| Some v -> (List.remove_assoc key headers) @ [ key, String.concat "," (v :: l) ]

let merge_headers_with_default ?(allow_origin=`default) ?(allow_headers=`default) ?(allow_methods=`default)
?allow_credentials ?origin headers =
let headers = merge_headers_allow_origin ?origin headers allow_origin in
let headers = merge_headers_allow ~dft:"accept,content-type" ~key:"access-control-allow-headers" headers allow_headers in
let headers = merge_headers_allow ~dft:"*" ~key:"access-control-allow-methods" headers allow_methods in
let key = "access-control-allow-credentials" in
match allow_credentials with
| None -> headers
| Some b -> match List.assoc_opt key headers with
| None -> headers @ [ key, string_of_bool b ]
| Some _ -> List.remove_assoc key headers @ [ key, string_of_bool b ]
51 changes: 25 additions & 26 deletions src/server/httpaf/ezAPIServerHttpAf.ml
Original file line number Diff line number Diff line change
Expand Up @@ -240,10 +240,8 @@ let register_ip req time = function
Ip.register time ip
| Unix.ADDR_UNIX _ -> ()

let connection_handler :
?catch:(string -> exn -> string Answer.t Lwt.t) ->
server -> Unix.sockaddr -> Lwt_unix.file_descr -> unit Lwt.t =
fun ?catch s sockaddr fd ->
let connection_handler ?allow_origin ?allow_headers ?allow_methods
?allow_credentials ?catch s sockaddr fd =
let request_handler sockaddr reqd =
let req = Reqd.request reqd in
let time = GMTime.time () in
Expand All @@ -266,17 +264,23 @@ let connection_handler :
| Some c -> c path_str exn >|= fun a -> `http a)
>>= function
| `ws (Error _) ->
let headers = Headers.of_list default_access_control_headers in
let headers = Headers.of_list @@
merge_headers_with_default ?allow_origin ?allow_headers ?allow_methods
?allow_credentials [] in
let status = Status.unsafe_of_code 501 in
let response = Response.create ~headers status in
Reqd.respond_with_string reqd response "";
Lwt.return_unit
| `ws (Ok (_response, _b)) ->
Lwt.return_unit
| `http {Answer.code; body; headers} ->
| `http {Answer.code; body; headers=resp_headers} ->
let status = Status.unsafe_of_code code in
debug ~v:(if code = 200 then 1 else 0) "Reply computed to %S: %d" path_str code;
let headers = merge_headers_with_default headers in
let origin = match allow_origin with
| Some `origin -> StringMap.find_opt "origin" headers
| _ -> None in
let headers = merge_headers_with_default ?allow_origin ?allow_headers
?allow_methods ?allow_credentials ?origin resp_headers in
let headers = Headers.of_list headers in
let len = String.length body in
let headers = Headers.add headers "content-length" (string_of_int len) in
Expand All @@ -285,23 +289,16 @@ let connection_handler :
Lwt.return_unit
in

let error_handler :
Unix.sockaddr ->
?request:Httpaf.Request.t ->
_ ->
(Headers.t -> [`write] Body.t) ->
unit =
fun _client_address ?request:_ error start_response ->

let response_body = start_response Headers.empty in
begin match error with
| `Exn exn ->
Body.write_string response_body (Printexc.to_string exn);
Body.write_string response_body "\n";
| #Status.standard as error ->
Body.write_string response_body (Status.default_reason_phrase error)
end;
Body.flush response_body (fun () -> Body.close_writer response_body)
let error_handler _client_address ?request:_ error start_response =
let response_body = start_response Headers.empty in
begin match error with
| `Exn exn ->
Body.write_string response_body (Printexc.to_string exn);
Body.write_string response_body "\n";
| #Status.standard as error ->
Body.write_string response_body (Status.default_reason_phrase error)
end;
Body.flush response_body (fun () -> Body.close_writer response_body)
in

Httpaf_lwt_unix.Server.create_connection_handler
Expand All @@ -311,7 +308,8 @@ let connection_handler :
sockaddr
fd

let create_server ?catch ~max_connections server_port server_kind =
let create_server ?catch ?allow_origin ?allow_headers ?allow_methods ?allow_credentials
~max_connections server_port server_kind =
let s = { server_port; server_kind } in
Timings.init (GMTime.time ()) @@ Doc.nservices ();
ignore @@ Doc.all_services_registered ();
Expand All @@ -320,7 +318,8 @@ let create_server ?catch ~max_connections server_port server_kind =
establish_server_with_client_socket
~nb_max_connections:max_connections
listen_address (fun sockaddr fd ->
connection_handler ?catch s sockaddr fd) >>= fun _server ->
connection_handler ?catch ?allow_origin ?allow_headers ?allow_methods
?allow_credentials s sockaddr fd) >>= fun _server ->
Lwt.return_unit

let server ?catch servers =
Expand Down

0 comments on commit 3f55f22

Please sign in to comment.