From 757d120ed51cdc8a3eed4988b44fd7c3161061b9 Mon Sep 17 00:00:00 2001 From: Hannes Mehnert Date: Fri, 2 Dec 2022 17:36:02 +0100 Subject: [PATCH] adapt to new happy eyeballs (#329) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * adapt to new happy eyeballs: handle cancellation Co-authored-by: Reynir Björnsson --- dns-client.opam | 2 +- lwt/client/dns_client_lwt.ml | 92 +++++++++----- mirage/client/dns_client_mirage.ml | 185 +++++++++++++++++------------ 3 files changed, 168 insertions(+), 111 deletions(-) diff --git a/dns-client.opam b/dns-client.opam index 9ae3620ef..a524691c9 100644 --- a/dns-client.opam +++ b/dns-client.opam @@ -29,7 +29,7 @@ depends: [ "mirage-clock" {>= "3.0.0"} "mtime" {>= "1.2.0"} "mirage-crypto-rng" {>= "0.8.0"} - "happy-eyeballs" {>= "0.1.0"} + "happy-eyeballs" {>= "0.4.0"} "alcotest" {with-test} "tls" {>= "0.15.0"} "tls-mirage" {>= "0.15.0"} diff --git a/lwt/client/dns_client_lwt.ml b/lwt/client/dns_client_lwt.ml index 05c3f0590..217849a69 100644 --- a/lwt/client/dns_client_lwt.ml +++ b/lwt/client/dns_client_lwt.ml @@ -27,6 +27,7 @@ module Transport : Dns_client.S mutable connected_condition : unit Lwt_condition.t option ; mutable requests : (Cstruct.t * (Cstruct.t, [ `Msg of string ]) result Lwt_condition.t) IM.t ; mutable he : Happy_eyeballs.t ; + mutable cancel_connecting : unit Lwt.u Happy_eyeballs.Waiter_map.t; mutable waiters : ((Ipaddr.t * int) * Lwt_unix.file_descr, [ `Msg of string ]) result Lwt.u Happy_eyeballs.Waiter_map.t ; timer_condition : unit Lwt_condition.t ; } @@ -55,45 +56,71 @@ module Transport : Dns_client.S let close_socket fd = Lwt.catch (fun () -> Lwt_unix.close fd) (fun _ -> Lwt.return_unit) - let rec handle_action t action = - (match action with - | Happy_eyeballs.Connect (host, id, (ip, port)) -> - Lwt_unix.(getprotobyname "tcp" >|= fun x -> x.p_proto) >>= fun proto_number -> - let fam = - Ipaddr.(Lwt_unix.(match ip with V4 _ -> PF_INET | V6 _ -> PF_INET6)) - in - let socket = Lwt_unix.socket fam Lwt_unix.SOCK_STREAM proto_number in - let addr = Lwt_unix.ADDR_INET (Ipaddr_unix.to_inet_addr ip, port) in - Lwt.catch (fun () -> - Lwt_unix.connect socket addr >>= fun () -> - let waiters, r = Happy_eyeballs.Waiter_map.find_and_remove id t.waiters in - t.waiters <- waiters; - begin match r with - | Some waiter -> Lwt.wakeup_later waiter (Ok ((ip, port), socket)); Lwt.return_unit - | None -> close_socket socket - end >|= fun () -> - Some (Happy_eyeballs.Connected (host, id, (ip, port)))) - (fun e -> - Log.err (fun m -> m "connection to %a:%d failed: %s" Ipaddr.pp ip port - (Printexc.to_string e)); - close_socket socket >|= fun () -> - Some (Happy_eyeballs.Connection_failed (host, id, (ip, port)))) - | Connect_failed (_host, id) -> + let handle_one_action t = function + | Happy_eyeballs.Connect (host, id, (ip, port)) -> + let cancelled, cancel = Lwt.task () in + t.cancel_connecting <- Happy_eyeballs.Waiter_map.add id cancel t.cancel_connecting; + Lwt_unix.(getprotobyname "tcp" >|= fun x -> x.p_proto) >>= fun proto_number -> + let fam = + Ipaddr.(Lwt_unix.(match ip with V4 _ -> PF_INET | V6 _ -> PF_INET6)) + in + let socket = Lwt_unix.socket fam Lwt_unix.SOCK_STREAM proto_number in + let addr = Lwt_unix.ADDR_INET (Ipaddr_unix.to_inet_addr ip, port) in + Lwt.pick [ + Lwt.try_bind + (fun () -> Lwt_unix.connect socket addr) + Lwt.return_ok + (fun e -> + let err = + Fmt.str "error %s connecting to nameserver %a:%d" + (Printexc.to_string e) Ipaddr.pp ip port + in + Lwt.return (Error (`Msg err))); + (cancelled >|= fun () -> Error (`Msg "cancelled")); + ] >>= fun r -> + t.cancel_connecting <- Happy_eyeballs.Waiter_map.remove id t.cancel_connecting; + begin match r with + | Ok () -> + let waiters, r = Happy_eyeballs.Waiter_map.find_and_remove id t.waiters in + t.waiters <- waiters; + begin match r with + | Some waiter -> + Lwt.wakeup_later waiter (Ok ((ip, port), socket)); + Lwt.return_unit + | None -> close_socket socket + end >|= fun () -> + Some (Happy_eyeballs.Connected (host, id, (ip, port))) + | Error `Msg err -> + close_socket socket >|= fun () -> + Some (Happy_eyeballs.Connection_failed (host, id, (ip, port), err)) + end + | Connect_failed (host, id, reason) -> let waiters, r = Happy_eyeballs.Waiter_map.find_and_remove id t.waiters in t.waiters <- waiters; begin match r with - | Some waiter -> Lwt.wakeup_later waiter (Error (`Msg "connection failed")) + | Some waiter -> + let err = + Fmt.str "connection to %a failed: %s" Domain_name.pp host reason + in + Lwt.wakeup_later waiter (Error (`Msg err)) | None -> () end; Lwt.return None - | a -> + | Connect_cancelled (_host, id) -> + (match Happy_eyeballs.Waiter_map.find_opt id t.cancel_connecting with + | None -> Lwt.return_none + | Some cancel -> Lwt.wakeup cancel (); Lwt.return_none) + | Resolve_a _ | Resolve_aaaa _ as a -> Log.warn (fun m -> m "ignoring action %a" Happy_eyeballs.pp_action a); - Lwt.return None) >>= function - | None -> Lwt.return_unit - | Some event -> - let he, actions = Happy_eyeballs.event t.he (clock ()) event in - t.he <- he; - Lwt_list.iter_p (handle_action t) actions + Lwt.return None + + let rec handle_action t action = + handle_one_action t action >>= function + | None -> Lwt.return_unit + | Some event -> + let he, actions = Happy_eyeballs.event t.he (clock ()) event in + t.he <- he; + Lwt_list.iter_p (handle_action t) actions let handle_timer_actions t actions = Lwt.async (fun () -> Lwt_list.iter_p (fun a -> handle_action t a) actions) @@ -208,6 +235,7 @@ module Transport : Dns_client.S connected_condition = None ; requests = IM.empty ; he = Happy_eyeballs.create (clock ()) ; + cancel_connecting = Happy_eyeballs.Waiter_map.empty ; waiters = Happy_eyeballs.Waiter_map.empty ; timer_condition = Lwt_condition.create () ; } in diff --git a/mirage/client/dns_client_mirage.ml b/mirage/client/dns_client_mirage.ml index 6ef578c97..5c10260e7 100644 --- a/mirage/client/dns_client_mirage.ml +++ b/mirage/client/dns_client_mirage.ml @@ -66,58 +66,58 @@ The format of a nameserver is: let nameserver_of_string str = let ( let* ) = Result.bind in begin match String.split_on_char ':' str with - | "tls" :: rest -> - let str = String.concat ":" rest in - ( match String.split_on_char '!' str with - | [ nameserver ] -> - let* ipaddr, port = Ipaddr.with_port_of_string ~default:853 nameserver in - let* authenticator = CA.authenticator () in - let tls = Tls.Config.client ~authenticator () in - Ok (`Tcp, `Tls (tls, ipaddr, port)) - | nameserver :: opt_hostname :: authenticator -> - let* ipaddr, port = Ipaddr.with_port_of_string ~default:853 nameserver in - let peer_name, data = - match - let* dn = Domain_name.of_string opt_hostname in - Domain_name.host dn - with - | Ok hostname -> Some hostname, String.concat "!" authenticator - | Error _ -> None, String.concat "!" (opt_hostname :: authenticator) - in - let* authenticator = - if data = "" then - CA.authenticator () - else - let* a = X509.Authenticator.of_string data in - Ok (a (fun () -> Some (Ptime.v (P.now_d_ps ())))) - in - let tls = Tls.Config.client ~authenticator ?peer_name () in - Ok (`Tcp, `Tls (tls, ipaddr, port)) - | [] -> assert false ) - | "tcp" :: nameserver -> - let str = String.concat ":" nameserver in - let* ipaddr, port = Ipaddr.with_port_of_string ~default:53 str in - Ok (`Tcp, `Plaintext (ipaddr, port)) - | "udp" :: nameserver -> - let str = String.concat ":" nameserver in - let* ipaddr, port = Ipaddr.with_port_of_string ~default:53 str in - Ok (`Udp, `Plaintext (ipaddr, port)) - | _ -> - Error (`Msg ("Unable to decode nameserver " ^ str)) - end |> Result.map_error (function `Msg e -> `Msg (e ^ format)) + | "tls" :: rest -> + let str = String.concat ":" rest in + ( match String.split_on_char '!' str with + | [ nameserver ] -> + let* ipaddr, port = Ipaddr.with_port_of_string ~default:853 nameserver in + let* authenticator = CA.authenticator () in + let tls = Tls.Config.client ~authenticator () in + Ok (`Tcp, `Tls (tls, ipaddr, port)) + | nameserver :: opt_hostname :: authenticator -> + let* ipaddr, port = Ipaddr.with_port_of_string ~default:853 nameserver in + let peer_name, data = + match + let* dn = Domain_name.of_string opt_hostname in + Domain_name.host dn + with + | Ok hostname -> Some hostname, String.concat "!" authenticator + | Error _ -> None, String.concat "!" (opt_hostname :: authenticator) + in + let* authenticator = + if data = "" then + CA.authenticator () + else + let* a = X509.Authenticator.of_string data in + Ok (a (fun () -> Some (Ptime.v (P.now_d_ps ())))) + in + let tls = Tls.Config.client ~authenticator ?peer_name () in + Ok (`Tcp, `Tls (tls, ipaddr, port)) + | [] -> assert false ) + | "tcp" :: nameserver -> + let str = String.concat ":" nameserver in + let* ipaddr, port = Ipaddr.with_port_of_string ~default:53 str in + Ok (`Tcp, `Plaintext (ipaddr, port)) + | "udp" :: nameserver -> + let str = String.concat ":" nameserver in + let* ipaddr, port = Ipaddr.with_port_of_string ~default:53 str in + Ok (`Udp, `Plaintext (ipaddr, port)) + | _ -> + Error (`Msg ("Unable to decode nameserver " ^ str)) + end |> Result.map_error (function `Msg e -> `Msg (e ^ format)) module Transport : Dns_client.S with type stack = S.t and type +'a io = 'a Lwt.t and type io_addr = [ - | `Plaintext of Ipaddr.t * int - | `Tls of Tls.Config.client * Ipaddr.t * int - ] = struct + | `Plaintext of Ipaddr.t * int + | `Tls of Tls.Config.client * Ipaddr.t * int + ] = struct type stack = S.t type io_addr = [ - | `Plaintext of Ipaddr.t * int - | `Tls of Tls.Config.client * Ipaddr.t * int - ] + | `Plaintext of Ipaddr.t * int + | `Tls of Tls.Config.client * Ipaddr.t * int + ] type +'a io = 'a Lwt.t module IS = Set.Make(Int) type t = { @@ -130,6 +130,7 @@ The format of a nameserver is: mutable connected_condition : unit Lwt_condition.t option ; mutable requests : (Cstruct.t * (Cstruct.t, [ `Msg of string ]) result Lwt_condition.t) IM.t ; mutable he : Happy_eyeballs.t ; + mutable cancel_connecting : unit Lwt.u Happy_eyeballs.Waiter_map.t ; mutable waiters : ((Ipaddr.t * int) * S.TCP.flow, [ `Msg of string ]) result Lwt.u Happy_eyeballs.Waiter_map.t ; timer_condition : unit Lwt_condition.t ; } @@ -138,40 +139,67 @@ The format of a nameserver is: let clock = M.elapsed_ns let he_timer_interval = Duration.of_ms 500 + let handle_one_action t = function + | Happy_eyeballs.Connect (host, id, addr) -> + let cancelled, cancel = Lwt.task () in + t.cancel_connecting <- Happy_eyeballs.Waiter_map.add id cancel t.cancel_connecting; + Lwt.pick [ + begin + S.TCP.create_connection (S.tcp t.stack) addr >>= function + | Error e -> + let err = + Fmt.str "error connecting to nameserver %a: %a" + Ipaddr.pp (fst addr) S.TCP.pp_error e + in + Lwt.return_error (`Msg err) + | Ok flow -> + Lwt.return_ok flow + end; + begin + cancelled >|= fun () -> Error (`Msg "cancelled") + end; + ] >>= fun r -> + begin match r with + | Ok flow -> + let waiters, r = Happy_eyeballs.Waiter_map.find_and_remove id t.waiters in + t.waiters <- waiters; + begin match r with + | Some waiter -> + Lwt.wakeup_later waiter (Ok (addr, flow)); + Lwt.return_unit + | None -> S.TCP.close flow + end >|= fun () -> + Some (Happy_eyeballs.Connected (host, id, addr)) + | Error `Msg err -> + Lwt.return (Some (Happy_eyeballs.Connection_failed (host, id, addr, err))) + end + | Connect_failed (host, id, reason) -> + let waiters, r = Happy_eyeballs.Waiter_map.find_and_remove id t.waiters in + t.waiters <- waiters; + begin match r with + | Some waiter -> + let err = + Fmt.str "connection to %a failed: %s" Domain_name.pp host reason + in + Lwt.wakeup_later waiter (Error (`Msg err)) + | None -> () + end; + Lwt.return None + | Connect_cancelled (_host, id) -> + (match Happy_eyeballs.Waiter_map.find_opt id t.cancel_connecting with + | None -> Lwt.return_none + | Some cancel -> Lwt.wakeup cancel (); Lwt.return_none) + | Resolve_a _ | Resolve_aaaa _ as a -> + Log.warn (fun m -> m "ignoring action %a" Happy_eyeballs.pp_action a); + Lwt.return None + let rec handle_action t action = - (match action with - | Happy_eyeballs.Connect (host, id, addr) -> - begin - S.TCP.create_connection (S.tcp t.stack) addr >>= function - | Error e -> - Log.err (fun m -> m "error connecting to nameserver %a: %a" - Ipaddr.pp (fst addr) S.TCP.pp_error e) ; - Lwt.return (Some (Happy_eyeballs.Connection_failed (host, id, addr))) - | Ok flow -> - let waiters, r = Happy_eyeballs.Waiter_map.find_and_remove id t.waiters in - t.waiters <- waiters; - begin match r with - | Some waiter -> Lwt.wakeup_later waiter (Ok (addr, flow)); Lwt.return_unit - | None -> S.TCP.close flow - end >|= fun () -> - Some (Happy_eyeballs.Connected (host, id, addr)) - end - | Connect_failed (_host, id) -> - let waiters, r = Happy_eyeballs.Waiter_map.find_and_remove id t.waiters in - t.waiters <- waiters; - begin match r with - | Some waiter -> Lwt.wakeup_later waiter (Error (`Msg "connection failed")) - | None -> () - end; - Lwt.return None - | a -> - Log.warn (fun m -> m "ignoring action %a" Happy_eyeballs.pp_action a); - Lwt.return None) >>= function - | None -> Lwt.return_unit - | Some event -> - let he, actions = Happy_eyeballs.event t.he (clock ()) event in - t.he <- he; - Lwt_list.iter_p (handle_action t) actions + handle_one_action t action >>= function + | None -> Lwt.return_unit + | Some event -> + let he, actions = Happy_eyeballs.event t.he (clock ()) event in + t.he <- he; + Lwt_list.iter_p (handle_action t) actions let handle_timer_actions t actions = Lwt.async (fun () -> Lwt_list.iter_p (fun a -> handle_action t a) actions) @@ -243,6 +271,7 @@ The format of a nameserver is: connected_condition = None ; requests = IM.empty ; he = Happy_eyeballs.create (clock ()) ; + cancel_connecting = Happy_eyeballs.Waiter_map.empty ; waiters = Happy_eyeballs.Waiter_map.empty ; timer_condition = Lwt_condition.create () ; } in