Skip to content

Commit

Permalink
dns-client(eio): improve performance
Browse files Browse the repository at this point in the history
  • Loading branch information
bikallem committed Dec 11, 2022
1 parent 2fa2f98 commit 16bad51
Showing 1 changed file with 52 additions and 62 deletions.
114 changes: 52 additions & 62 deletions eio/client/dns_client_eio.ml
Original file line number Diff line number Diff line change
Expand Up @@ -30,19 +30,19 @@ module Transport : Dns_client.S
type nonrec stack = stack
type +'a io = 'a

type t = {
type t = {
nameservers : Dns.proto * nameservers ;
stack : stack ;
timeout : Eio.Time.Timeout.t ;
mutable ns_connection_condition : Eio.Condition.t option ;
mutable ctx : (Dns.proto * context) option ;
}

and context = {
and context = {
t : t ;
mutable requests : Cstruct.t Eio.Promise.u IM.t ;
mutable ns_connection: <Eio.Flow.two_way> ;
mutable buf : Cstruct.t ;
mutable recv_buf : Cstruct.t ;
}

(* DNS nameservers. *)
Expand Down Expand Up @@ -161,10 +161,7 @@ module Transport : Dns_client.S
let he, actions = Happy_eyeballs.event he (clock ()) event in
he_handle_actions t he actions
end
| Connect_failed _ ->
fun () ->
Log.debug (fun m -> m "[he_handle_actions] connection failed");
None
| Connect_failed _ -> fun () -> None
| Connect_cancelled _ | Resolve_a _ | Resolve_aaaa _ as a ->
fun () ->
Log.warn (fun m -> m "[he_handle_actions] ignoring action %a" Happy_eyeballs.pp_action a);
Expand All @@ -185,7 +182,6 @@ module Transport : Dns_client.S
| Error `Msg m -> invalid_arg ("failed to load trust anchors: " ^ m)

let rec connect t =
Log.debug (fun m -> m "connect : establishing connection to nameservers");
match t.ctx, t.ns_connection_condition with
| Some ctx, _ -> Ok ctx
| None, Some condition ->
Expand All @@ -209,16 +205,17 @@ module Transport : Dns_client.S
let config = Tls.Config.(client ~authenticator ()) in
(Tls_eio.client_of_flow config conn :> Eio.Flow.two_way)
in
let context =
let ctx =
{ t = t
; requests = IM.empty
; ns_connection = conn
; buf = Cstruct.empty
; recv_buf = Cstruct.create 2048
}
in
t.ctx <- Some (`Tcp, context);
t.ctx <- Some (`Tcp, ctx);
Eio.Fiber.fork ~sw:ctx.t.stack.sw ( fun () -> recv_dns_packets ctx );
Eio.Condition.broadcast ns_connection_condition;
Ok (`Tcp, context)
Ok (`Tcp, ctx)
| None ->
t.ns_connection_condition <- None;
Eio.Condition.broadcast ns_connection_condition;
Expand All @@ -231,47 +228,46 @@ module Transport : Dns_client.S
Error (`Msg error_msg)
end

let recv_data t flow id : unit =
let buf = Cstruct.create 512 in
Log.debug (fun m -> m "recv_data (%X): t.buf.len %d" id (Cstruct.length t.buf));
let got = Eio.Flow.single_read flow buf in
Log.debug (fun m -> m "recv_data (%X): got %d" id got);
let buf = Cstruct.sub buf 0 got in
t.buf <- if Cstruct.length t.buf = 0 then buf else Cstruct.append t.buf buf;
Log.debug (fun m -> m "recv_data (%X): t.buf.len %d" id (Cstruct.length t.buf))
and recv_dns_packets ?(recv_data = Cstruct.empty) (ctx : context) =

let rec recv_packet t ns_connection request_id =
Log.debug (fun m -> m "recv_packet (%X)" request_id);
let buf_len = Cstruct.length t.buf in
if buf_len > 2 then (
let packet_len = Cstruct.BE.get_uint16 t.buf 0 in
Log.debug (fun m -> m "recv_packet (%X): packet_len %d" request_id (Cstruct.length t.buf));
if buf_len - 2 >= packet_len then
let packet, rest =
if buf_len - 2 = packet_len
then t.buf, Cstruct.empty
else Cstruct.split t.buf (packet_len + 2)
in
t.buf <- rest;
let response_id = Cstruct.BE.get_uint16 packet 2 in
Log.debug (fun m -> m "recv_packet (%X): got response %X" request_id response_id);
if response_id = request_id
then packet
else begin
(match IM.find response_id t.requests with
| r -> Eio.Promise.resolve r packet
| exception Not_found -> ());
recv_packet t ns_connection request_id
end
else begin
recv_data t ns_connection request_id;
recv_packet t ns_connection request_id
end
)
else begin
recv_data t ns_connection request_id;
recv_packet t ns_connection request_id
end
let append_recv_buf ctx got recv_data =
let buf = Cstruct.sub ctx.recv_buf 0 got in
if Cstruct.is_empty recv_data
then buf
else Cstruct.append recv_data buf
in

let rec handle_data recv_data =
let recv_data_len = Cstruct.length recv_data in
if recv_data_len < 2
then recv_dns_packets ~recv_data ctx
else
match Cstruct.BE.get_uint16 recv_data 0 with
| packet_len when recv_data_len - 2 >= packet_len ->
let packet, recv_data = Cstruct.split recv_data @@ packet_len + 2 in
let response_id = Cstruct.BE.get_uint16 packet 2 in
(match IM.find response_id ctx.requests with
| r ->
ctx.requests <- IM.remove response_id ctx.requests ;
Eio.Promise.resolve r packet
| exception Not_found -> () (* spurious data, ignore *)
);
if not @@ IM.is_empty ctx.requests then handle_data recv_data else ()
| _ -> recv_dns_packets ~recv_data ctx
in

match Eio.Flow.single_read ctx.ns_connection ctx.recv_buf with
| got ->
let recv_data = append_recv_buf ctx got recv_data in
handle_data recv_data
| exception End_of_file ->
ctx.t.ns_connection_condition <- None ;
ctx.t.ctx <- None ;
if not @@ IM.is_empty ctx.requests then
(match connect ctx.t with
| Ok _ -> recv_dns_packets ~recv_data ctx
| Error _ -> Log.warn (fun m -> m "[recv_dns_packets] connection closed while processing dns requests") )
else ()

let validate_query_packet tx =
if Cstruct.length tx > 4 then Ok () else
Expand All @@ -281,22 +277,16 @@ module Transport : Dns_client.S
let* () = validate_query_packet packet in
try
let request_id = Cstruct.BE.get_uint16 packet 2 in
let response_p, response_r = Eio.Promise.create () in
ctx.requests <- IM.add request_id response_r ctx.requests;
Eio.Time.Timeout.run_exn ctx.t.timeout (fun () ->
Eio.Flow.write ctx.ns_connection [packet];
Log.debug (fun m -> m "send_recv (%X): wrote request" request_id);
let response_p, response_r = Eio.Promise.create () in
ctx.requests <- IM.add request_id response_r ctx.requests;
let response =
Eio.Fiber.first
(fun () -> recv_packet ctx ctx.ns_connection request_id)
(fun () -> Eio.Promise.await response_p)
in
Log.debug (fun m -> m "send_recv (%X): got response" request_id);
let response = Eio.Promise.await response_p in
Ok response
)
with
| Eio.Time.Timeout -> Error (`Msg "DNS request timeout")
(* | exn -> Error (`Msg (Printexc.to_string exn)) *)
| End_of_file -> Error (`Msg "Nameserver closed connection")

let close _ = ()
let bind a f = f a
Expand Down

0 comments on commit 16bad51

Please sign in to comment.