Skip to content

Commit

Permalink
Enable netstack s/r for all tests only.
Browse files Browse the repository at this point in the history
Set TESTONLY-save-restore-netstack flag to true for all tests. This will make
all save restore tests run with netstack s/r.

PiperOrigin-RevId: 703593364
  • Loading branch information
nybidari authored and gvisor-bot committed Dec 7, 2024
1 parent 22b95a8 commit 6babbc9
Show file tree
Hide file tree
Showing 19 changed files with 158 additions and 91 deletions.
3 changes: 3 additions & 0 deletions pkg/sentry/inet/inet.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,9 @@ type Stack interface {

// IsSaveRestoreEnabled returns true when netstack s/r is enabled.
IsSaveRestoreEnabled() bool

// Stats returns the network stats.
Stats() tcpip.Stats
}

// Interface contains information about a network interface.
Expand Down
6 changes: 6 additions & 0 deletions pkg/sentry/inet/test_stack.go
Original file line number Diff line number Diff line change
Expand Up @@ -235,3 +235,9 @@ func (*TestStack) IsSaveRestoreEnabled() bool {
// No-op.
return false
}

// Stats implements Stack.
func (*TestStack) Stats() tcpip.Stats {
// No-op.
return tcpip.Stats{}
}
5 changes: 5 additions & 0 deletions pkg/sentry/kernel/timekeeper_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ func (t *Timekeeper) beforeSave() {
panic("pauseUpdates must be called before Save")
}

if t.clocks == nil {
t.restored = nil
return
}

// N.B. we want the *offset* monotonic time.
var err error
if t.saveMonotonic, err = t.GetTime(time.Monotonic); err != nil {
Expand Down
5 changes: 5 additions & 0 deletions pkg/sentry/socket/hostinet/stack.go
Original file line number Diff line number Diff line change
Expand Up @@ -438,3 +438,8 @@ func (*Stack) EnableSaveRestore() error {
func (s *Stack) IsSaveRestoreEnabled() bool {
return false
}

// Stats implements inet.Stack.Stats.
func (s *Stack) Stats() tcpip.Stats {
return tcpip.Stats{}
}
16 changes: 11 additions & 5 deletions pkg/sentry/socket/netstack/stack.go
Original file line number Diff line number Diff line change
Expand Up @@ -592,6 +592,7 @@ func (s *Stack) SetTCPRecovery(recovery inet.TCPLossRecovery) error {

// Statistics implements inet.Stack.Statistics.
func (s *Stack) Statistics(stat any, arg string) error {
netStats := s.Stats()
switch stats := stat.(type) {
case *inet.StatDev:
for _, ni := range s.Stack.NICInfo() {
Expand Down Expand Up @@ -622,7 +623,7 @@ func (s *Stack) Statistics(stat any, arg string) error {
break
}
case *inet.StatSNMPIP:
ip := Metrics.IP
ip := netStats.IP
// TODO(gvisor.dev/issue/969) Support stubbed stats.
*stats = inet.StatSNMPIP{
0, // Ip/Forwarding.
Expand All @@ -646,8 +647,8 @@ func (s *Stack) Statistics(stat any, arg string) error {
0, // Support Ip/FragCreates.
}
case *inet.StatSNMPICMP:
in := Metrics.ICMP.V4.PacketsReceived.ICMPv4PacketStats
out := Metrics.ICMP.V4.PacketsSent.ICMPv4PacketStats
in := netStats.ICMP.V4.PacketsReceived.ICMPv4PacketStats
out := netStats.ICMP.V4.PacketsSent.ICMPv4PacketStats
// TODO(gvisor.dev/issue/969) Support stubbed stats.
*stats = inet.StatSNMPICMP{
0, // Icmp/InMsgs.
Expand Down Expand Up @@ -679,7 +680,7 @@ func (s *Stack) Statistics(stat any, arg string) error {
out.InfoReply.Value(), // OutAddrMaskReps.
}
case *inet.StatSNMPTCP:
tcp := Metrics.TCP
tcp := netStats.TCP
// RFC 2012 (updates 1213): SNMPv2-MIB-TCP.
*stats = inet.StatSNMPTCP{
1, // RtoAlgorithm.
Expand All @@ -699,7 +700,7 @@ func (s *Stack) Statistics(stat any, arg string) error {
tcp.ChecksumErrors.Value(), // InCsumErrors.
}
case *inet.StatSNMPUDP:
udp := Metrics.UDP
udp := netStats.UDP
// TODO(gvisor.dev/issue/969) Support stubbed stats.
*stats = inet.StatSNMPUDP{
udp.PacketsReceived.Value(), // InDatagrams.
Expand All @@ -717,6 +718,11 @@ func (s *Stack) Statistics(stat any, arg string) error {
return nil
}

// Stats implements inet.Stack.Stats.
func (s *Stack) Stats() tcpip.Stats {
return s.Stack.Stats()
}

// RouteTable implements inet.Stack.RouteTable.
func (s *Stack) RouteTable() []inet.Route {
var routeTable []inet.Route
Expand Down
4 changes: 1 addition & 3 deletions pkg/tcpip/stack/addressable_endpoint_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -738,8 +738,6 @@ func (a *AddressableEndpointState) Cleanup() {
var _ AddressEndpoint = (*addressState)(nil)

// addressState holds state for an address.
//
// +stateify savable
type addressState struct {
addressableEndpointState *AddressableEndpointState
addr tcpip.AddressWithPrefix
Expand All @@ -750,7 +748,7 @@ type addressState struct {
//
// AddressableEndpointState.mu
// addressState.mu
mu addressStateRWMutex `state:"nosave"`
mu addressStateRWMutex
refs addressStateRefs
// checklocks:mu
kind AddressKind
Expand Down
1 change: 1 addition & 0 deletions pkg/tcpip/stack/stack.go
Original file line number Diff line number Diff line change
Expand Up @@ -1988,6 +1988,7 @@ func (s *Stack) ReplaceConfig(st *Stack) {
s.nics[id] = nic
_ = s.NextNICID()
}
s.tables = st.tables
}

// Restore restarts the stack after a restore. This must be called after the
Expand Down
2 changes: 1 addition & 1 deletion pkg/tcpip/transport/icmp/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ type endpoint struct {

// The following fields are initialized at creation time and are
// immutable.
stack *stack.Stack `state:"manual"`
stack *stack.Stack
transProto tcpip.TransportProtocolNumber
waiterQueue *waiter.Queue
net network.Endpoint
Expand Down
40 changes: 24 additions & 16 deletions pkg/tcpip/transport/icmp/endpoint_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,11 @@ func (p *icmpPacket) loadReceivedAt(_ context.Context, nsec int64) {

// afterLoad is invoked by stateify.
func (e *endpoint) afterLoad(ctx context.Context) {
stack.RestoreStackFromContext(ctx).RegisterRestoredEndpoint(e)
if e.stack.IsSaveRestoreEnabled() {
e.stack.RegisterRestoredEndpoint(e)
} else {
stack.RestoreStackFromContext(ctx).RegisterRestoredEndpoint(e)
}
}

// beforeSave is invoked by stateify.
Expand All @@ -49,24 +53,28 @@ func (e *endpoint) beforeSave() {
func (e *endpoint) Restore(s *stack.Stack) {
e.thaw()

saveRestoreEnabled := e.stack.IsSaveRestoreEnabled()
e.net.Resume(s)
if saveRestoreEnabled {
e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
} else {
e.stack = s
e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)

e.stack = s
e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)

switch state := e.net.State(); state {
case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed:
case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected:
var err tcpip.Error
info := e.net.Info()
info.ID.LocalPort = e.ident
info.ID, err = e.registerWithStack(info.NetProto, info.ID)
if err != nil {
panic(fmt.Sprintf("e.registerWithStack(%d, %#v): %s", info.NetProto, info.ID, err))
switch state := e.net.State(); state {
case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed:
case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected:
var err tcpip.Error
info := e.net.Info()
info.ID.LocalPort = e.ident
info.ID, err = e.registerWithStack(info.NetProto, info.ID)
if err != nil {
panic(fmt.Sprintf("e.registerWithStack(%d, %#v): %s", info.NetProto, info.ID, err))
}
e.ident = info.ID.LocalPort
default:
panic(fmt.Sprintf("unhandled state = %s", state))
}
e.ident = info.ID.LocalPort
default:
panic(fmt.Sprintf("unhandled state = %s", state))
}
}

Expand Down
4 changes: 2 additions & 2 deletions pkg/tcpip/transport/internal/network/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ import (
// +stateify savable
type Endpoint struct {
// The following fields must only be set once then never changed.
stack *stack.Stack `state:"manual"`
stack *stack.Stack
ops *tcpip.SocketOptions
netProto tcpip.NetworkProtocolNumber
transProto tcpip.TransportProtocolNumber
Expand All @@ -53,7 +53,7 @@ type Endpoint struct {
// +checklocks:mu
effectiveNetProto tcpip.NetworkProtocolNumber
// +checklocks:mu
connectedRoute *stack.Route `state:"manual"`
connectedRoute *stack.Route `state:"nosave"`
// +checklocks:mu
multicastMemberships map[multicastMembership]struct{}
// +checklocks:mu
Expand Down
2 changes: 1 addition & 1 deletion pkg/tcpip/transport/packet/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ type endpoint struct {

// The following fields are initialized at creation time and are
// immutable.
stack *stack.Stack `state:"manual"`
stack *stack.Stack
waiterQueue *waiter.Queue
cooked bool
ops tcpip.SocketOptions
Expand Down
12 changes: 10 additions & 2 deletions pkg/tcpip/transport/packet/endpoint_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,20 @@ func (ep *endpoint) beforeSave() {

// afterLoad is invoked by stateify.
func (ep *endpoint) afterLoad(ctx context.Context) {
if !ep.stack.IsSaveRestoreEnabled() {
ep.mu.Lock()
ep.stack = stack.RestoreStackFromContext(ctx)
ep.mu.Unlock()
}
ep.stack.RegisterRestoredEndpoint(ep)
}

// Restore implements tcpip.RestoredEndpoint.Restore.
func (ep *endpoint) Restore(_ *stack.Stack) {
ep.mu.Lock()
defer ep.mu.Unlock()

ep.stack = stack.RestoreStackFromContext(ctx)
ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)

if err := ep.stack.RegisterPacketEndpoint(ep.boundNIC, ep.boundNetProto, ep); err != nil {
panic(fmt.Sprintf("RegisterPacketEndpoint(%d, %d, _): %s", ep.boundNIC, ep.boundNetProto, err))
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/tcpip/transport/raw/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ type endpoint struct {

// The following fields are initialized at creation time and are
// immutable.
stack *stack.Stack `state:"manual"`
stack *stack.Stack
transProto tcpip.TransportProtocolNumber
waiterQueue *waiter.Queue
associated bool
Expand Down
27 changes: 17 additions & 10 deletions pkg/tcpip/transport/raw/endpoint_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ package raw

import (
"context"
"fmt"
"time"

"gvisor.dev/gvisor/pkg/tcpip"
Expand All @@ -35,7 +34,11 @@ func (p *rawPacket) loadReceivedAt(_ context.Context, nsec int64) {

// afterLoad is invoked by stateify.
func (e *endpoint) afterLoad(ctx context.Context) {
stack.RestoreStackFromContext(ctx).RegisterRestoredEndpoint(e)
if e.stack.IsSaveRestoreEnabled() {
e.stack.RegisterRestoredEndpoint(e)
} else {
stack.RestoreStackFromContext(ctx).RegisterRestoredEndpoint(e)
}
}

// beforeSave is invoked by stateify.
Expand All @@ -46,16 +49,20 @@ func (e *endpoint) beforeSave() {

// Restore implements tcpip.RestoredEndpoint.Restore.
func (e *endpoint) Restore(s *stack.Stack) {
e.net.Resume(s)

e.setReceiveDisabled(false)
e.stack = s
e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
if saveRestoredEnabled := e.stack.IsSaveRestoreEnabled(); saveRestoredEnabled {
e.net.Resume(s)
e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
} else {
e.net.Resume(s)
e.stack = s
e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)

if e.associated {
netProto := e.net.NetProto()
if err := e.stack.RegisterRawTransportEndpoint(netProto, e.transProto, e); err != nil {
panic(fmt.Sprintf("e.stack.RegisterRawTransportEndpoint(%d, %d, _): %s", netProto, e.transProto, err))
if e.associated {
netProto := e.net.NetProto()
if err := e.stack.RegisterRawTransportEndpoint(netProto, e.transProto, e); err != nil {
panic("RegisterRawTransportEndpoint failed during restore")
}
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/tcpip/transport/tcp/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,7 @@ type Endpoint struct {
isPortReserved bool
isRegistered bool
boundNICID tcpip.NICID
route *stack.Route `state:"manual"`
route *stack.Route `state:"nosave"`
ipv4TTL uint8
ipv6HopLimit int16
isConnectNotified bool
Expand Down
Loading

0 comments on commit 6babbc9

Please sign in to comment.