Skip to content

Commit

Permalink
Merge pull request #429 from hannesm/tcp-disconnect
Browse files Browse the repository at this point in the history
stack-direct & tcp: implement disconnect
  • Loading branch information
hannesm authored Aug 26, 2020
2 parents 56b36bf + e55c844 commit 8dbb9c5
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 55 deletions.
81 changes: 44 additions & 37 deletions src/stack-direct/tcpip_stack_direct.ml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ module Make
tcpv4 : Tcpv4.t;
udpv4_listeners: (int, Udpv4.callback) Hashtbl.t;
tcpv4_listeners: (int, Tcpv4.listener) Hashtbl.t;
mutable task : unit Lwt.t option;
}

let pp fmt t =
Expand Down Expand Up @@ -85,51 +86,57 @@ module Make
with Not_found -> None

let listen t =
Log.debug (fun f -> f "Establishing or updating listener for stack %a" pp t);
let ethif_listener = Ethernet.input
~arpv4:(Arpv4.input t.arpv4)
~ipv4:(
Ipv4.input
~tcp:(Tcpv4.input t.tcpv4
~listeners:(tcpv4_listeners t))
~udp:(Udpv4.input t.udpv4
~listeners:(udpv4_listeners t))
~default:(fun ~proto ~src ~dst buf ->
match proto with
| 1 -> Icmpv4.input t.icmpv4 ~src ~dst buf
| _ -> Lwt.return_unit)
t.ipv4)
~ipv6:(fun _ -> Lwt.return_unit)
t.ethif
in
Netif.listen t.netif ~header_size:Ethernet_wire.sizeof_ethernet ethif_listener
>>= function
| Error e ->
Log.warn (fun p -> p "%a" Netif.pp_error e) ;
(* XXX: error should be passed to the caller *)
Lwt.return_unit
| Ok _res ->
let nstat = Netif.get_stats_counters t.netif in
let open Mirage_net in
Log.info (fun f ->
f "listening loop of interface %s terminated regularly:@ %Lu bytes \
(%lu packets) received, %Lu bytes (%lu packets) sent@ "
(Macaddr.to_string (Netif.mac t.netif))
nstat.rx_bytes nstat.rx_pkts
nstat.tx_bytes nstat.tx_pkts) ;
Lwt.return_unit
Lwt.catch (fun () ->
Log.debug (fun f -> f "Establishing or updating listener for stack %a" pp t);
let ethif_listener = Ethernet.input
~arpv4:(Arpv4.input t.arpv4)
~ipv4:(
Ipv4.input
~tcp:(Tcpv4.input t.tcpv4
~listeners:(tcpv4_listeners t))
~udp:(Udpv4.input t.udpv4
~listeners:(udpv4_listeners t))
~default:(fun ~proto ~src ~dst buf ->
match proto with
| 1 -> Icmpv4.input t.icmpv4 ~src ~dst buf
| _ -> Lwt.return_unit)
t.ipv4)
~ipv6:(fun _ -> Lwt.return_unit)
t.ethif
in
Netif.listen t.netif ~header_size:Ethernet_wire.sizeof_ethernet ethif_listener
>>= function
| Error e ->
Log.warn (fun p -> p "%a" Netif.pp_error e) ;
(* XXX: error should be passed to the caller *)
Lwt.return_unit
| Ok _res ->
let nstat = Netif.get_stats_counters t.netif in
let open Mirage_net in
Log.info (fun f ->
f "listening loop of interface %s terminated regularly:@ %Lu bytes \
(%lu packets) received, %Lu bytes (%lu packets) sent@ "
(Macaddr.to_string (Netif.mac t.netif))
nstat.rx_bytes nstat.rx_pkts
nstat.tx_bytes nstat.tx_pkts) ;
Lwt.return_unit)
(function
| Lwt.Canceled ->
Log.info (fun f -> f "listen of %a cancelled" pp t);
Lwt.return_unit
| e -> Lwt.fail e)

let connect netif ethif arpv4 ipv4 icmpv4 udpv4 tcpv4 =
let udpv4_listeners = Hashtbl.create 7 in
let tcpv4_listeners = Hashtbl.create 7 in
let t = { netif; ethif; arpv4; ipv4; icmpv4; tcpv4; udpv4;
udpv4_listeners; tcpv4_listeners } in
udpv4_listeners; tcpv4_listeners; task = None } in
Log.info (fun f -> f "stack assembled: %a" pp t);
Lwt.async (fun () -> listen t);
Lwt.async (fun () -> let task = listen t in t.task <- Some task; task);
Lwt.return t

let disconnect t =
(* TODO: kill the listening thread *)
Log.info (fun f -> f "disconnect called (currently a noop): %a" pp t);
Log.info (fun f -> f "disconnect called: %a" pp t);
(match t.task with None -> () | Some task -> Lwt.cancel task);
Lwt.return_unit
end
50 changes: 32 additions & 18 deletions src/tcp/flow.ml
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ struct

type t = {
ip : Ip.t;
mutable active : bool ;
mutable localport : int;
channels: (WIRE.t, connection) Hashtbl.t;
(* server connections the process of connecting - SYN-ACK sent
Expand Down Expand Up @@ -537,19 +538,23 @@ struct
>>= fun _ -> Lwt.return_unit (* if send fails, who cares *)

let input_no_pcb t listeners (parsed, payload) id =
let { sequence; Tcp_packet.ack_number; window; options; syn; fin; rst; ack; _ } = parsed in
match rst, syn, ack with
| true, _, _ -> process_reset t id ~ack ~ack_number
| false, true, true ->
process_synack t id ~ack_number ~sequence ~tx_wnd:window ~options ~syn ~fin
| false, true , false -> process_syn t id ~listeners ~tx_wnd:window
~ack_number ~sequence ~options ~syn ~fin
| false, false, true ->
let open RXS in
process_ack t id ~pkt:{ header = parsed; payload}
| false, false, false ->
Log.debug (fun f -> f "incoming packet matches no connection table entry and has no useful flags set; dropping it");
if not t.active then
(* TODO: eventually send an RST? *)
Lwt.return_unit
else
let { sequence; Tcp_packet.ack_number; window; options; syn; fin; rst; ack; _ } = parsed in
match rst, syn, ack with
| true, _, _ -> process_reset t id ~ack ~ack_number
| false, true, true ->
process_synack t id ~ack_number ~sequence ~tx_wnd:window ~options ~syn ~fin
| false, true , false -> process_syn t id ~listeners ~tx_wnd:window
~ack_number ~sequence ~options ~syn ~fin
| false, false, true ->
let open RXS in
process_ack t id ~pkt:{ header = parsed; payload}
| false, false, false ->
Log.debug (fun f -> f "incoming packet matches no connection table entry and has no useful flags set; dropping it");
Lwt.return_unit

(* Main input function for TCP packets *)
let input t ~listeners ~src ~dst data =
Expand Down Expand Up @@ -714,9 +719,12 @@ struct
pp_error e Ip.pp_ipaddr daddr dport)

let create_connection ?keepalive tcp (daddr, dport) =
connect ?keepalive tcp ~dst:daddr ~dst_port:dport >>= function
| Error e -> log_failure daddr dport e; Lwt.return @@ Error e
| Ok (fl, _) -> Lwt.return (Ok fl)
if not tcp.active then
Lwt.return (Error `Timeout) (* TODO: custom error variant *)
else
connect ?keepalive tcp ~dst:daddr ~dst_port:dport >>= function
| Error e -> log_failure daddr dport e; Lwt.return @@ Error e
| Ok (fl, _) -> Lwt.return (Ok fl)

(* Construct the main TCP thread *)
let connect ip =
Expand All @@ -726,7 +734,13 @@ struct
let listens = Hashtbl.create 1 in
let connects = Hashtbl.create 1 in
let channels = Hashtbl.create 7 in
Lwt.return { ip; localport; channels; listens; connects }

let disconnect _ = Lwt.return_unit
Lwt.return { ip; active = true; localport; channels; listens; connects }

let disconnect t =
t.active <- false;
let conns = Hashtbl.fold (fun _ (pcb, _) acc -> pcb :: acc) t.channels [] in
Lwt_list.iter_p close conns >|= fun () ->
Hashtbl.reset t.listens;
Hashtbl.reset t.connects
(* TODO: should there be Lwt tasks being cancelled? *)
end

0 comments on commit 8dbb9c5

Please sign in to comment.