Skip to content

Commit

Permalink
dns-resolver.mirage: add dns-over-tls support
Browse files Browse the repository at this point in the history
  • Loading branch information
robur-team committed Oct 19, 2021
1 parent aa8b8dc commit 8e863fc
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 15 deletions.
1 change: 1 addition & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
/etc/resolv.conf sequentially (lwt and mirage) (#269 @reynir and @hannesm)
* BREAKING dns-client remove UDP support from lwt (#270 @reynir and @hannesm)

* BREAKING dns-resolver.mirage add DNS-over-TLS support (@reynir @hannesm)
* BREAKING dns-resolver remove "mode" from codebase, default to recursive
(a stub resolver is available as dns-stub) (#260 @hannesm)
* dns-resolver: use dns.cache instead of copy in Dns_resolver_cache
Expand Down
1 change: 1 addition & 0 deletions dns-resolver.opam
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ depends: [
"mirage-random" {>= "2.0.0"}
"mirage-stack" {>= "2.0.0"}
"alcotest" {with-test}
"tls" "tls-mirage"
]

build: [
Expand Down
98 changes: 89 additions & 9 deletions mirage/resolver/dns_resolver_mirage.ml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ module Make (R : Mirage_random.S) (P : Mirage_clock.PCLOCK) (M : Mirage_clock.MC

module T = S.TCP

module TLS = Tls_mirage.Make(T)

type tls_flow = { tls_flow : TLS.flow ; mutable linger : Cstruct.t }

module FM = Map.Make(struct
type t = Ipaddr.t * int
let compare (ip, p) (ip', p') =
Expand All @@ -19,13 +23,24 @@ module Make (R : Mirage_random.S) (P : Mirage_clock.PCLOCK) (M : Mirage_clock.MC
| x -> x
end)

let resolver stack ?(root = false) ?(timer = 500) ?(port = 53) t =
let resolver stack ?(root = false) ?(timer = 500) ?(udp = true) ?(tcp = true) ?tls ?(port = 53) ?(tls_port = 853) t =
(* according to RFC5452 4.5, we can chose source port between 1024-49152 *)
let sport () = 1024 + Randomconv.int ~bound:48128 R.generate in
let state = ref t in
let tcp_in = ref FM.empty in
let tcp_out = ref Ipaddr.Map.empty in

let send_tls flow data =
let len = Cstruct.create 2 in
Cstruct.BE.set_uint16 len 0 (Cstruct.length data);
TLS.writev flow [len; data] >>= function
| Ok () -> Lwt.return (Ok ())
| Error e ->
Log.err (fun m -> m "tls error %a while writing" TLS.pp_write_error e);
TLS.close flow >|= fun () ->
Error ()
in

let rec client_out dst port =
T.create_connection (S.tcp stack) (dst, port) >|= function
| Error e ->
Expand Down Expand Up @@ -94,9 +109,14 @@ module Make (R : Mirage_random.S) (P : Mirage_clock.PCLOCK) (M : Mirage_clock.MC
Log.err (fun m -> m "wanted to answer %a:%d via TCP, but couldn't find a flow"
Ipaddr.pp dst dst_port) ;
Lwt.return_unit
| Some flow -> Dns.send_tcp flow data >|= function
| Ok () -> ()
| Error () -> tcp_in := FM.remove (dst, dst_port) !tcp_in
| Some `Tcp flow ->
(Dns.send_tcp flow data >|= function
| Ok () -> ()
| Error () -> tcp_in := FM.remove (dst, dst_port) !tcp_in)
| Some `Tls flow ->
(send_tls flow data >|= function
| Ok () -> ()
| Error () -> tcp_in := FM.remove (dst, dst_port) !tcp_in)
and udp_cb req ~src ~dst:_ ~src_port buf =
let now = Ptime.v (P.now_d_ps ())
and ts = M.elapsed_ns ()
Expand All @@ -108,13 +128,15 @@ module Make (R : Mirage_random.S) (P : Mirage_clock.PCLOCK) (M : Mirage_clock.MC
Lwt_list.iter_p handle_answer answers >>= fun () ->
Lwt_list.iter_p handle_query queries
in
S.listen_udp stack ~port (udp_cb true) ;
Log.app (fun f -> f "DNS resolver listening on UDP port %d" port);
if udp then begin
S.listen_udp stack ~port (udp_cb true);
Log.app (fun f -> f "DNS resolver listening on UDP port %d" port);
end;

let tcp_cb query flow =
let dst_ip, dst_port = T.dst flow in
Log.info (fun m -> m "tcp connection from %a:%d" Ipaddr.pp dst_ip dst_port) ;
tcp_in := FM.add (dst_ip, dst_port) flow !tcp_in ;
tcp_in := FM.add (dst_ip, dst_port) (`Tcp flow) !tcp_in ;
let f = Dns.of_flow flow in
let rec loop () =
Dns.read_tcp f >>= function
Expand All @@ -134,8 +156,66 @@ module Make (R : Mirage_random.S) (P : Mirage_clock.PCLOCK) (M : Mirage_clock.MC
in
loop ()
in
S.listen_tcp stack ~port (tcp_cb true) ;
Log.info (fun m -> m "DNS resolver listening on TCP port %d" port) ;
if tcp then begin
S.listen_tcp stack ~port (tcp_cb true);
Log.info (fun m -> m "DNS resolver listening on TCP port %d" port);
end;

let rec read_tls ({ tls_flow ; linger } as f) length =
if Cstruct.length linger >= length then
let a, b = Cstruct.split linger length in
f.linger <- b;
Lwt.return (Ok a)
else
TLS.read tls_flow >>= function
| Ok `Eof -> Log.debug (fun m -> m "end of file while reading"); TLS.close tls_flow >|= fun () -> Error ()
| Error e -> Log.warn (fun m -> m "error reading TLS: %a" TLS.pp_error e); TLS.close tls_flow >|= fun () -> Error ()
| Ok (`Data d) ->
f.linger <- Cstruct.append linger d;
read_tls f length
in
let read_tls_packet f =
read_tls f 2 >>= function
| Error () -> Lwt.return (Error ())
| Ok k ->
let len = Cstruct.BE.get_uint16 k 0 in
read_tls f len
in

let tls_cb cfg flow =
let dst_ip, dst_port = T.dst flow in
TLS.server_of_flow cfg flow >>= function
| Error e ->
Log.warn (fun m -> m "TLS error (from %a:%d): %a" Ipaddr.pp dst_ip dst_port
TLS.pp_write_error e);
Lwt.return_unit
| Ok tls ->
Log.info (fun m -> m "tls connection from %a:%d" Ipaddr.pp dst_ip dst_port);
tcp_in := FM.add (dst_ip, dst_port) (`Tls tls) !tcp_in ;
let tls_and_linger = { tls_flow = tls ; linger = Cstruct.empty } in
let rec loop () =
read_tls_packet tls_and_linger >>= function
| Error () ->
tcp_in := FM.remove (dst_ip, dst_port) !tcp_in ;
Lwt.return_unit
| Ok data ->
let now = Ptime.v (P.now_d_ps ()) in
let ts = M.elapsed_ns () in
let new_state, answers, queries =
Dns_resolver.handle_buf !state now ts true `Tcp dst_ip dst_port data
in
state := new_state ;
Lwt_list.iter_p handle_answer answers >>= fun () ->
Lwt_list.iter_p handle_query queries >>= fun () ->
loop ()
in
loop ()
in
(match tls with
| None -> ()
| Some cfg ->
S.listen_tcp stack ~port:tls_port (tls_cb cfg);
Log.info (fun m -> m "DNS resolver listening on TLS port %d" port));

let rec time () =
let new_state, answers, queries =
Expand Down
11 changes: 6 additions & 5 deletions mirage/resolver/dns_resolver_mirage.mli
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@

module Make (R : Mirage_random.S) (P : Mirage_clock.PCLOCK) (M : Mirage_clock.MCLOCK) (T : Mirage_time.S) (S : Mirage_stack.V4V6) : sig

val resolver : S.t -> ?root:bool -> ?timer:int -> ?port:int -> Dns_resolver.t -> unit
(** [resolver stack ~root ~timer ~port resolver] registers a caching resolver
on the provided [port] (both udp and tcp) using the [resolver]
configuration. The [timer] is in milliseconds and defaults to 500
milliseconds.*)
val resolver : S.t -> ?root:bool -> ?timer:int -> ?udp:bool -> ?tcp:bool -> ?tls:Tls.Config.server -> ?port:int -> ?tls_port:int -> Dns_resolver.t -> unit
(** [resolver stack ~root ~timer ~udp ~tcp ~tls ~port ~tls_port resolver]
registers a caching resolver on the provided protocols [udp], [tcp], [tls]
using [port] for udp and tcp (defaults to 53), [tls_port] for tls (defaults
to 853) using the [resolver] configuration. The [timer] is in milliseconds
and defaults to 500 milliseconds.*)
end
2 changes: 1 addition & 1 deletion mirage/resolver/dune
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
(name dns_resolver_mirage)
(public_name dns-resolver.mirage)
(wrapped false)
(libraries dns dns-resolver dns-server dns-mirage lwt duration mirage-time mirage-clock mirage-stack mirage-random))
(libraries dns dns-resolver dns-server dns-mirage lwt duration mirage-time mirage-clock mirage-stack mirage-random tls tls-mirage))

0 comments on commit 8e863fc

Please sign in to comment.