Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix Ipv6 code #428

Merged
merged 10 commits into from
Aug 26, 2020
3 changes: 2 additions & 1 deletion src/ipv6/ipv6.ml
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,8 @@ module Make (E : Mirage_protocols.ETHERNET)

let input t ~tcp ~udp ~default buf =
let now = C.elapsed_ns () in
let _, outs, actions = Ndpv6.handle ~now ~random:R.generate t.ctx buf in
let ctx, outs, actions = Ndpv6.handle ~now ~random:R.generate t.ctx buf in
t.ctx <- ctx;
Lwt_list.iter_s (function
| `Tcp (src, dst, buf) -> tcp ~src ~dst buf
| `Udp (src, dst, buf) -> udp ~src ~dst buf
Expand Down
4 changes: 2 additions & 2 deletions src/ipv6/ndpv6.ml
Original file line number Diff line number Diff line change
Expand Up @@ -158,13 +158,13 @@ module Allocate = struct
let size' = size + Ipv6_wire.sizeof_ipv6 in
let fill ipbuf =
Ipv6_wire.set_ipv6_version_flow ipbuf 0x60000000l; (* IPv6 *)
Ipv6_wire.set_ipv6_len ipbuf size;
ipaddr_to_cstruct_raw src (Ipv6_wire.get_ipv6_src ipbuf) 0;
ipaddr_to_cstruct_raw dst (Ipv6_wire.get_ipv6_dst ipbuf) 0;
Ipv6_wire.set_ipv6_hlim ipbuf hlim;
Ipv6_wire.set_ipv6_nhdr ipbuf (Ipv6_wire.protocol_to_int proto);
let hdr, payload = Cstruct.split ipbuf Ipv6_wire.sizeof_ipv6 in
let len' = fillf hdr payload in
assert (len' <= size') ;
len' + Ipv6_wire.sizeof_ipv6
in
(size', fill)
Expand Down Expand Up @@ -233,8 +233,8 @@ module Allocate = struct
Ipv6_wire.set_pingv6_id icmpbuf id;
Ipv6_wire.set_pingv6_seq icmpbuf seq;
Ipv6_wire.set_pingv6_csum icmpbuf 0;
Ipv6_wire.set_pingv6_csum icmpbuf @@ checksum hdr (icmpbuf :: data :: []);
Cstruct.blit data 0 icmpbuf Ipv6_wire.sizeof_pingv6 (Cstruct.len data);
Ipv6_wire.set_pingv6_csum icmpbuf @@ checksum hdr [ icmpbuf ];
size
in
hdr ~src ~dst ~hlim ~proto:`ICMP ~size fillf
Expand Down
1 change: 1 addition & 0 deletions test/test.ml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ let run test () =
Lwt_main.run (test ())

let () =
Printexc.record_backtrace true;
(* someone has to call Mirage_random_test.initialize () *)
Mirage_random_test.initialize ();
(* enable logging to stdout for all modules *)
Expand Down
13 changes: 8 additions & 5 deletions test/test_iperf.ml
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@ module Test_iperf (B : Vnetif_backends.Backend) = struct
client : V.Stackv4.t;
}

let default_network ?(backend = B.create ()) () =
V.create_stack ~cidr:client_cidr ~gateway backend >>= fun client ->
V.create_stack ~cidr:server_cidr ~gateway backend >>= fun server ->
let default_network ?mtu ?(backend = B.create ()) () =
V.create_stack ?mtu ~cidr:client_cidr ~gateway backend >>= fun client ->
V.create_stack ?mtu ~cidr:server_cidr ~gateway backend >>= fun server ->
Lwt.return {backend; server; client}

let msg =
Expand Down Expand Up @@ -196,8 +196,11 @@ let test_tcp_iperf_two_stacks_basic amt timeout () =
(Test.tcp_iperf ~server ~client amt timeout)

let test_tcp_iperf_two_stacks_mtu amt timeout () =
let module Test = Test_iperf (Vnetif_backends.Mtu_enforced) in
Test.default_network () >>= fun { backend; Test.client; Test.server } ->
let mtu = 1500 in
let module Test = Test_iperf (Vnetif_backends.Frame_size_enforced) in
let backend = Vnetif_backends.Frame_size_enforced.create () in
Vnetif_backends.Frame_size_enforced.set_max_ip_mtu backend mtu;
Test.default_network ?mtu:(Some mtu) ?backend:(Some backend) () >>= fun { backend; Test.client; Test.server } ->
Test.V.record_pcap backend
(Printf.sprintf "tcp_iperf_two_stacks_mtu_%d.pcap" amt)
(Test.tcp_iperf ~server ~client amt timeout)
Expand Down
22 changes: 13 additions & 9 deletions test/test_ipv6.ml
Original file line number Diff line number Diff line change
Expand Up @@ -49,22 +49,26 @@ let listen ?(tcp = noop) ?(udp = noop) ?(default = noop) stack =

let udp_message = Cstruct.of_string "hello on UDP over IPv6"

let check_for_one_udp_packet netif on_received_one ~src ~dst buf =
Alcotest.(check ip) "sender address" (Ipaddr.V6.of_string_exn "fc00::23") src;
Alcotest.(check ip) "receiver address" (Ipaddr.V6.of_string_exn "fc00::45") dst;
let check_for_one_udp_packet on_received_one ~src ~dst buf =
(match Udp_packet.Unmarshal.of_cstruct buf with
| Ok (_, payload) ->
Alcotest.(check ip) "sender address" (Ipaddr.V6.of_string_exn "fc00::23") src;
Alcotest.(check ip) "receiver address" (Ipaddr.V6.of_string_exn "fc00::45") dst;
Alcotest.(check cstruct) "payload is correct" udp_message payload
| Error m -> Alcotest.fail m);
(try Lwt.wakeup_later on_received_one () with _ -> () (* the first succeeds, the rest raise *));
(*after receiving 1 packet, disconnect stack so test can continue*)
V.disconnect netif
Lwt.return_unit

let send_forever sender receiver_address udp_message =
let rec loop () =
Printf.fprintf stderr "Udp.write\n%!";
Udp.write sender.udp ~dst:receiver_address ~dst_port:1234 udp_message
>|= Rresult.R.get_ok >>= fun () ->
(* Check that we have an IP before sending *)
if List.length (Ipv6.get_ip sender.ip) >= 1 then
begin
Udp.write sender.udp ~dst:receiver_address ~dst_port:1234 udp_message
>|= Rresult.R.get_ok
end else
Lwt.return_unit
>>= fun () ->
Time.sleep_ns (Duration.of_ms 50) >>= fun () ->
loop () in
loop ()
Expand All @@ -77,7 +81,7 @@ let pass_udp_traffic () =
get_stack backend receiver_address >>= fun receiver ->
let received_one, on_received_one = Lwt.task () in
Lwt.pick [
listen receiver ~udp:(check_for_one_udp_packet receiver.netif on_received_one);
listen receiver ~udp:(check_for_one_udp_packet on_received_one);
listen sender;
send_forever sender receiver_address udp_message;
received_one; (* stop on the first packet *)
Expand Down
4 changes: 2 additions & 2 deletions test/test_mtus.ml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ let client_cidr = Ipaddr.V4.Prefix.of_string_exn "192.168.1.10/24"

let server_port = 7

module Backend = Vnetif_backends.Mtu_enforced
module Backend = Vnetif_backends.Frame_size_enforced
module Stack = Vnetif_common.VNETIF_STACK(Backend)

let default_mtu = 1500
Expand Down Expand Up @@ -36,7 +36,7 @@ let get_stacks ?client_mtu ?server_mtu backend =
Stack.create_stack ~cidr:client_cidr ~mtu:client_mtu backend >>= fun client ->
Stack.create_stack ~cidr:server_cidr ~mtu:server_mtu backend >>= fun server ->
let max_mtu = max client_mtu server_mtu in
Backend.set_mtu max_mtu;
Backend.set_max_ip_mtu backend max_mtu;
Lwt.return (server, client)

let start_server ~f server =
Expand Down
42 changes: 34 additions & 8 deletions test/vnetif_backends.ml
Original file line number Diff line number Diff line change
Expand Up @@ -21,23 +21,49 @@ module type Backend = sig
val create : unit -> t
end

(** This backend enforces an MTU. *)
module Mtu_enforced = struct
(** This backend enforces an Ethernet frame size. *)
module Frame_size_enforced = struct
module X = Basic_backend.Make
include X
type t = {
xt : X.t;
mutable frame_size : int;
}

let mtu = ref 1500
type macaddr = X.macaddr
type 'a io = 'a X.io
type buffer = X.buffer
type id = X.id

let register t =
X.register t.xt

let unregister t id =
X.unregister t.xt id

let mac t id =
X.mac t.xt id

let set_listen_fn t id buf =
X.set_listen_fn t.xt id buf

let unregister_and_flush t id =
X.unregister_and_flush t.xt id

let write t id ~size fill =
if size > !mtu then
if size > t.frame_size then
Lwt.return (Error `Invalid_length)
else
X.write t id ~size fill
X.write t.xt id ~size fill

let set_frame_size t m = t.frame_size <- m
let set_max_ip_mtu t m = t.frame_size <- m + Ethernet_wire.sizeof_ethernet

let set_mtu m = mtu := m
let create ~frame_size () =
let xt = X.create ~use_async_readers:true ~yield:(fun() -> Lwt_main.yield () ) () in
{ xt ; frame_size }

let create () =
X.create ~use_async_readers:true ~yield:(fun() -> Lwt_main.yield () ) ()
create ~frame_size:(1500 + Ethernet_wire.sizeof_ethernet) ()

end

Expand Down