From 8e863fc5fce3788871b0a7e2b832f460fee19775 Mon Sep 17 00:00:00 2001 From: Robur Date: Tue, 19 Oct 2021 16:08:47 +0000 Subject: [PATCH] dns-resolver.mirage: add dns-over-tls support --- CHANGES.md | 1 + dns-resolver.opam | 1 + mirage/resolver/dns_resolver_mirage.ml | 98 ++++++++++++++++++++++--- mirage/resolver/dns_resolver_mirage.mli | 11 +-- mirage/resolver/dune | 2 +- 5 files changed, 98 insertions(+), 15 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index cf84457f9..3a3cdeafb 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -40,6 +40,7 @@ /etc/resolv.conf sequentially (lwt and mirage) (#269 @reynir and @hannesm) * BREAKING dns-client remove UDP support from lwt (#270 @reynir and @hannesm) +* BREAKING dns-resolver.mirage add DNS-over-TLS support (@reynir @hannesm) * BREAKING dns-resolver remove "mode" from codebase, default to recursive (a stub resolver is available as dns-stub) (#260 @hannesm) * dns-resolver: use dns.cache instead of copy in Dns_resolver_cache diff --git a/dns-resolver.opam b/dns-resolver.opam index 806c73271..993051fd0 100644 --- a/dns-resolver.opam +++ b/dns-resolver.opam @@ -22,6 +22,7 @@ depends: [ "mirage-random" {>= "2.0.0"} "mirage-stack" {>= "2.0.0"} "alcotest" {with-test} + "tls" "tls-mirage" ] build: [ diff --git a/mirage/resolver/dns_resolver_mirage.ml b/mirage/resolver/dns_resolver_mirage.ml index 80a9a8773..3ad881b86 100644 --- a/mirage/resolver/dns_resolver_mirage.ml +++ b/mirage/resolver/dns_resolver_mirage.ml @@ -11,6 +11,10 @@ module Make (R : Mirage_random.S) (P : Mirage_clock.PCLOCK) (M : Mirage_clock.MC module T = S.TCP + module TLS = Tls_mirage.Make(T) + + type tls_flow = { tls_flow : TLS.flow ; mutable linger : Cstruct.t } + module FM = Map.Make(struct type t = Ipaddr.t * int let compare (ip, p) (ip', p') = @@ -19,13 +23,24 @@ module Make (R : Mirage_random.S) (P : Mirage_clock.PCLOCK) (M : Mirage_clock.MC | x -> x end) - let resolver stack ?(root = false) ?(timer = 500) ?(port = 53) t = + let resolver stack ?(root = false) ?(timer = 500) ?(udp = true) ?(tcp = true) ?tls ?(port = 53) ?(tls_port = 853) t = (* according to RFC5452 4.5, we can chose source port between 1024-49152 *) let sport () = 1024 + Randomconv.int ~bound:48128 R.generate in let state = ref t in let tcp_in = ref FM.empty in let tcp_out = ref Ipaddr.Map.empty in + let send_tls flow data = + let len = Cstruct.create 2 in + Cstruct.BE.set_uint16 len 0 (Cstruct.length data); + TLS.writev flow [len; data] >>= function + | Ok () -> Lwt.return (Ok ()) + | Error e -> + Log.err (fun m -> m "tls error %a while writing" TLS.pp_write_error e); + TLS.close flow >|= fun () -> + Error () + in + let rec client_out dst port = T.create_connection (S.tcp stack) (dst, port) >|= function | Error e -> @@ -94,9 +109,14 @@ module Make (R : Mirage_random.S) (P : Mirage_clock.PCLOCK) (M : Mirage_clock.MC Log.err (fun m -> m "wanted to answer %a:%d via TCP, but couldn't find a flow" Ipaddr.pp dst dst_port) ; Lwt.return_unit - | Some flow -> Dns.send_tcp flow data >|= function - | Ok () -> () - | Error () -> tcp_in := FM.remove (dst, dst_port) !tcp_in + | Some `Tcp flow -> + (Dns.send_tcp flow data >|= function + | Ok () -> () + | Error () -> tcp_in := FM.remove (dst, dst_port) !tcp_in) + | Some `Tls flow -> + (send_tls flow data >|= function + | Ok () -> () + | Error () -> tcp_in := FM.remove (dst, dst_port) !tcp_in) and udp_cb req ~src ~dst:_ ~src_port buf = let now = Ptime.v (P.now_d_ps ()) and ts = M.elapsed_ns () @@ -108,13 +128,15 @@ module Make (R : Mirage_random.S) (P : Mirage_clock.PCLOCK) (M : Mirage_clock.MC Lwt_list.iter_p handle_answer answers >>= fun () -> Lwt_list.iter_p handle_query queries in - S.listen_udp stack ~port (udp_cb true) ; - Log.app (fun f -> f "DNS resolver listening on UDP port %d" port); + if udp then begin + S.listen_udp stack ~port (udp_cb true); + Log.app (fun f -> f "DNS resolver listening on UDP port %d" port); + end; let tcp_cb query flow = let dst_ip, dst_port = T.dst flow in Log.info (fun m -> m "tcp connection from %a:%d" Ipaddr.pp dst_ip dst_port) ; - tcp_in := FM.add (dst_ip, dst_port) flow !tcp_in ; + tcp_in := FM.add (dst_ip, dst_port) (`Tcp flow) !tcp_in ; let f = Dns.of_flow flow in let rec loop () = Dns.read_tcp f >>= function @@ -134,8 +156,66 @@ module Make (R : Mirage_random.S) (P : Mirage_clock.PCLOCK) (M : Mirage_clock.MC in loop () in - S.listen_tcp stack ~port (tcp_cb true) ; - Log.info (fun m -> m "DNS resolver listening on TCP port %d" port) ; + if tcp then begin + S.listen_tcp stack ~port (tcp_cb true); + Log.info (fun m -> m "DNS resolver listening on TCP port %d" port); + end; + + let rec read_tls ({ tls_flow ; linger } as f) length = + if Cstruct.length linger >= length then + let a, b = Cstruct.split linger length in + f.linger <- b; + Lwt.return (Ok a) + else + TLS.read tls_flow >>= function + | Ok `Eof -> Log.debug (fun m -> m "end of file while reading"); TLS.close tls_flow >|= fun () -> Error () + | Error e -> Log.warn (fun m -> m "error reading TLS: %a" TLS.pp_error e); TLS.close tls_flow >|= fun () -> Error () + | Ok (`Data d) -> + f.linger <- Cstruct.append linger d; + read_tls f length + in + let read_tls_packet f = + read_tls f 2 >>= function + | Error () -> Lwt.return (Error ()) + | Ok k -> + let len = Cstruct.BE.get_uint16 k 0 in + read_tls f len + in + + let tls_cb cfg flow = + let dst_ip, dst_port = T.dst flow in + TLS.server_of_flow cfg flow >>= function + | Error e -> + Log.warn (fun m -> m "TLS error (from %a:%d): %a" Ipaddr.pp dst_ip dst_port + TLS.pp_write_error e); + Lwt.return_unit + | Ok tls -> + Log.info (fun m -> m "tls connection from %a:%d" Ipaddr.pp dst_ip dst_port); + tcp_in := FM.add (dst_ip, dst_port) (`Tls tls) !tcp_in ; + let tls_and_linger = { tls_flow = tls ; linger = Cstruct.empty } in + let rec loop () = + read_tls_packet tls_and_linger >>= function + | Error () -> + tcp_in := FM.remove (dst_ip, dst_port) !tcp_in ; + Lwt.return_unit + | Ok data -> + let now = Ptime.v (P.now_d_ps ()) in + let ts = M.elapsed_ns () in + let new_state, answers, queries = + Dns_resolver.handle_buf !state now ts true `Tcp dst_ip dst_port data + in + state := new_state ; + Lwt_list.iter_p handle_answer answers >>= fun () -> + Lwt_list.iter_p handle_query queries >>= fun () -> + loop () + in + loop () + in + (match tls with + | None -> () + | Some cfg -> + S.listen_tcp stack ~port:tls_port (tls_cb cfg); + Log.info (fun m -> m "DNS resolver listening on TLS port %d" port)); let rec time () = let new_state, answers, queries = diff --git a/mirage/resolver/dns_resolver_mirage.mli b/mirage/resolver/dns_resolver_mirage.mli index dcb513d38..4c036ff8d 100644 --- a/mirage/resolver/dns_resolver_mirage.mli +++ b/mirage/resolver/dns_resolver_mirage.mli @@ -2,9 +2,10 @@ module Make (R : Mirage_random.S) (P : Mirage_clock.PCLOCK) (M : Mirage_clock.MCLOCK) (T : Mirage_time.S) (S : Mirage_stack.V4V6) : sig - val resolver : S.t -> ?root:bool -> ?timer:int -> ?port:int -> Dns_resolver.t -> unit - (** [resolver stack ~root ~timer ~port resolver] registers a caching resolver - on the provided [port] (both udp and tcp) using the [resolver] - configuration. The [timer] is in milliseconds and defaults to 500 - milliseconds.*) + val resolver : S.t -> ?root:bool -> ?timer:int -> ?udp:bool -> ?tcp:bool -> ?tls:Tls.Config.server -> ?port:int -> ?tls_port:int -> Dns_resolver.t -> unit + (** [resolver stack ~root ~timer ~udp ~tcp ~tls ~port ~tls_port resolver] + registers a caching resolver on the provided protocols [udp], [tcp], [tls] + using [port] for udp and tcp (defaults to 53), [tls_port] for tls (defaults + to 853) using the [resolver] configuration. The [timer] is in milliseconds + and defaults to 500 milliseconds.*) end diff --git a/mirage/resolver/dune b/mirage/resolver/dune index 0b34eb835..2b29388cd 100644 --- a/mirage/resolver/dune +++ b/mirage/resolver/dune @@ -2,4 +2,4 @@ (name dns_resolver_mirage) (public_name dns-resolver.mirage) (wrapped false) - (libraries dns dns-resolver dns-server dns-mirage lwt duration mirage-time mirage-clock mirage-stack mirage-random)) + (libraries dns dns-resolver dns-server dns-mirage lwt duration mirage-time mirage-clock mirage-stack mirage-random tls tls-mirage))