Skip to content

Commit

Permalink
Enable netstack save/restore in cloud/gvisor by default.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 703593364
  • Loading branch information
nybidari authored and gvisor-bot committed Dec 16, 2024
1 parent 0f8216c commit af7f267
Show file tree
Hide file tree
Showing 18 changed files with 116 additions and 52 deletions.
3 changes: 3 additions & 0 deletions pkg/sentry/inet/inet.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,9 @@ type Stack interface {

// IsSaveRestoreEnabled returns true when netstack s/r is enabled.
IsSaveRestoreEnabled() bool

// Stats returns the network stats.
Stats() tcpip.Stats
}

// Interface contains information about a network interface.
Expand Down
6 changes: 6 additions & 0 deletions pkg/sentry/inet/test_stack.go
Original file line number Diff line number Diff line change
Expand Up @@ -235,3 +235,9 @@ func (*TestStack) IsSaveRestoreEnabled() bool {
// No-op.
return false
}

// Stats implements Stack.Stats.
//
// It is a no-op for TestStack: the result is always a zero-valued tcpip.Stats.
func (*TestStack) Stats() tcpip.Stats {
	var zero tcpip.Stats
	return zero
}
5 changes: 5 additions & 0 deletions pkg/sentry/kernel/timekeeper_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ func (t *Timekeeper) beforeSave() {
panic("pauseUpdates must be called before Save")
}

if t.clocks == nil {
t.restored = nil
return
}

// N.B. we want the *offset* monotonic time.
var err error
if t.saveMonotonic, err = t.GetTime(time.Monotonic); err != nil {
Expand Down
5 changes: 5 additions & 0 deletions pkg/sentry/socket/hostinet/stack.go
Original file line number Diff line number Diff line change
Expand Up @@ -438,3 +438,8 @@ func (*Stack) EnableSaveRestore() error {
func (s *Stack) IsSaveRestoreEnabled() bool {
return false
}

// Stats implements inet.Stack.Stats. It always reports a zero-valued
// tcpip.Stats.
func (s *Stack) Stats() tcpip.Stats {
	var zero tcpip.Stats
	return zero
}
16 changes: 11 additions & 5 deletions pkg/sentry/socket/netstack/stack.go
Original file line number Diff line number Diff line change
Expand Up @@ -592,6 +592,7 @@ func (s *Stack) SetTCPRecovery(recovery inet.TCPLossRecovery) error {

// Statistics implements inet.Stack.Statistics.
func (s *Stack) Statistics(stat any, arg string) error {
netStats := s.Stats()
switch stats := stat.(type) {
case *inet.StatDev:
for _, ni := range s.Stack.NICInfo() {
Expand Down Expand Up @@ -622,7 +623,7 @@ func (s *Stack) Statistics(stat any, arg string) error {
break
}
case *inet.StatSNMPIP:
ip := Metrics.IP
ip := netStats.IP
// TODO(gvisor.dev/issue/969) Support stubbed stats.
*stats = inet.StatSNMPIP{
0, // Ip/Forwarding.
Expand All @@ -646,8 +647,8 @@ func (s *Stack) Statistics(stat any, arg string) error {
0, // Support Ip/FragCreates.
}
case *inet.StatSNMPICMP:
in := Metrics.ICMP.V4.PacketsReceived.ICMPv4PacketStats
out := Metrics.ICMP.V4.PacketsSent.ICMPv4PacketStats
in := netStats.ICMP.V4.PacketsReceived.ICMPv4PacketStats
out := netStats.ICMP.V4.PacketsSent.ICMPv4PacketStats
// TODO(gvisor.dev/issue/969) Support stubbed stats.
*stats = inet.StatSNMPICMP{
0, // Icmp/InMsgs.
Expand Down Expand Up @@ -679,7 +680,7 @@ func (s *Stack) Statistics(stat any, arg string) error {
out.InfoReply.Value(), // OutAddrMaskReps.
}
case *inet.StatSNMPTCP:
tcp := Metrics.TCP
tcp := netStats.TCP
// RFC 2012 (updates 1213): SNMPv2-MIB-TCP.
*stats = inet.StatSNMPTCP{
1, // RtoAlgorithm.
Expand All @@ -699,7 +700,7 @@ func (s *Stack) Statistics(stat any, arg string) error {
tcp.ChecksumErrors.Value(), // InCsumErrors.
}
case *inet.StatSNMPUDP:
udp := Metrics.UDP
udp := netStats.UDP
// TODO(gvisor.dev/issue/969) Support stubbed stats.
*stats = inet.StatSNMPUDP{
udp.PacketsReceived.Value(), // InDatagrams.
Expand All @@ -717,6 +718,11 @@ func (s *Stack) Statistics(stat any, arg string) error {
return nil
}

// Stats implements inet.Stack.Stats by returning whatever s.Stack.Stats()
// reports.
func (s *Stack) Stats() tcpip.Stats {
	netStats := s.Stack.Stats()
	return netStats
}

// RouteTable implements inet.Stack.RouteTable.
func (s *Stack) RouteTable() []inet.Route {
var routeTable []inet.Route
Expand Down
4 changes: 1 addition & 3 deletions pkg/tcpip/stack/addressable_endpoint_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -738,8 +738,6 @@ func (a *AddressableEndpointState) Cleanup() {
var _ AddressEndpoint = (*addressState)(nil)

// addressState holds state for an address.
//
// +stateify savable
type addressState struct {
addressableEndpointState *AddressableEndpointState
addr tcpip.AddressWithPrefix
Expand All @@ -750,7 +748,7 @@ type addressState struct {
//
// AddressableEndpointState.mu
// addressState.mu
mu addressStateRWMutex `state:"nosave"`
mu addressStateRWMutex
refs addressStateRefs
// checklocks:mu
kind AddressKind
Expand Down
1 change: 1 addition & 0 deletions pkg/tcpip/stack/stack.go
Original file line number Diff line number Diff line change
Expand Up @@ -1986,6 +1986,7 @@ func (s *Stack) ReplaceConfig(st *Stack) {
s.nics[id] = nic
_ = s.NextNICID()
}
s.tables = st.tables
}

// Restore restarts the stack after a restore. This must be called after the
Expand Down
2 changes: 1 addition & 1 deletion pkg/tcpip/transport/icmp/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ type endpoint struct {

// The following fields are initialized at creation time and are
// immutable.
stack *stack.Stack `state:"manual"`
stack *stack.Stack
transProto tcpip.TransportProtocolNumber
waiterQueue *waiter.Queue
net network.Endpoint
Expand Down
10 changes: 9 additions & 1 deletion pkg/tcpip/transport/icmp/endpoint_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,11 @@ func (p *icmpPacket) loadReceivedAt(_ context.Context, nsec int64) {

// afterLoad is invoked by stateify.
//
// NOTE(review): this text is a rendered diff without +/- markers; the
// unconditional RegisterRestoredEndpoint call directly below looks like the
// removed pre-change line that the if/else after it replaces — confirm against
// the actual file.
func (e *endpoint) afterLoad(ctx context.Context) {
stack.RestoreStackFromContext(ctx).RegisterRestoredEndpoint(e)
// With netstack save/restore enabled, register on the endpoint's own e.stack;
// otherwise register on the stack obtained from the restore context.
if e.stack.IsSaveRestoreEnabled() {
e.stack.RegisterRestoredEndpoint(e)
} else {
stack.RestoreStackFromContext(ctx).RegisterRestoredEndpoint(e)
}
}

// beforeSave is invoked by stateify.
Expand All @@ -50,6 +54,10 @@ func (e *endpoint) Restore(s *stack.Stack) {
e.thaw()

e.net.Resume(s)
if e.stack.IsSaveRestoreEnabled() {
e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
return
}

e.stack = s
e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
Expand Down
4 changes: 2 additions & 2 deletions pkg/tcpip/transport/internal/network/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ import (
// +stateify savable
type Endpoint struct {
// The following fields must only be set once then never changed.
stack *stack.Stack `state:"manual"`
stack *stack.Stack
ops *tcpip.SocketOptions
netProto tcpip.NetworkProtocolNumber
transProto tcpip.TransportProtocolNumber
Expand All @@ -53,7 +53,7 @@ type Endpoint struct {
// +checklocks:mu
effectiveNetProto tcpip.NetworkProtocolNumber
// +checklocks:mu
connectedRoute *stack.Route `state:"manual"`
connectedRoute *stack.Route `state:"nosave"`
// +checklocks:mu
multicastMemberships map[multicastMembership]struct{}
// +checklocks:mu
Expand Down
2 changes: 1 addition & 1 deletion pkg/tcpip/transport/packet/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ type endpoint struct {

// The following fields are initialized at creation time and are
// immutable.
stack *stack.Stack `state:"manual"`
stack *stack.Stack
waiterQueue *waiter.Queue
cooked bool
ops tcpip.SocketOptions
Expand Down
12 changes: 10 additions & 2 deletions pkg/tcpip/transport/packet/endpoint_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,20 @@ func (ep *endpoint) beforeSave() {

// afterLoad is invoked by stateify.
//
// It registers ep as a restored endpoint. When netstack save/restore is
// enabled, ep.stack is used as loaded; otherwise ep.stack is first replaced,
// under ep.mu, with the stack taken from the restore context.
func (ep *endpoint) afterLoad(ctx context.Context) {
	if ep.stack.IsSaveRestoreEnabled() {
		ep.stack.RegisterRestoredEndpoint(ep)
		return
	}

	ep.mu.Lock()
	ep.stack = stack.RestoreStackFromContext(ctx)
	ep.mu.Unlock()

	ep.stack.RegisterRestoredEndpoint(ep)
}

// Restore implements tcpip.RestoredEndpoint.Restore.
func (ep *endpoint) Restore(_ *stack.Stack) {
ep.mu.Lock()
defer ep.mu.Unlock()

ep.stack = stack.RestoreStackFromContext(ctx)
ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)

if err := ep.stack.RegisterPacketEndpoint(ep.boundNIC, ep.boundNetProto, ep); err != nil {
panic(fmt.Sprintf("RegisterPacketEndpoint(%d, %d, _): %s", ep.boundNIC, ep.boundNetProto, err))
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/tcpip/transport/raw/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ type endpoint struct {

// The following fields are initialized at creation time and are
// immutable.
stack *stack.Stack `state:"manual"`
stack *stack.Stack
transProto tcpip.TransportProtocolNumber
waiterQueue *waiter.Queue
associated bool
Expand Down
15 changes: 11 additions & 4 deletions pkg/tcpip/transport/raw/endpoint_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ package raw

import (
"context"
"fmt"
"time"

"gvisor.dev/gvisor/pkg/tcpip"
Expand All @@ -35,7 +34,11 @@ func (p *rawPacket) loadReceivedAt(_ context.Context, nsec int64) {

// afterLoad is invoked by stateify.
//
// NOTE(review): this text is a rendered diff without +/- markers; the
// unconditional RegisterRestoredEndpoint call directly below looks like the
// removed pre-change line that the if/else after it replaces — confirm against
// the actual file.
func (e *endpoint) afterLoad(ctx context.Context) {
stack.RestoreStackFromContext(ctx).RegisterRestoredEndpoint(e)
// With netstack save/restore enabled, register on the endpoint's own e.stack;
// otherwise register on the stack obtained from the restore context.
if e.stack.IsSaveRestoreEnabled() {
e.stack.RegisterRestoredEndpoint(e)
} else {
stack.RestoreStackFromContext(ctx).RegisterRestoredEndpoint(e)
}
}

// beforeSave is invoked by stateify.
Expand All @@ -46,16 +49,20 @@ func (e *endpoint) beforeSave() {

// Restore implements tcpip.RestoredEndpoint.Restore.
//
// NOTE(review): this text is a rendered diff without +/- markers, so removed
// and added lines are interleaved: e.setReceiveDisabled(false) and the panic
// call each appear twice, presumably once as the old line and once as the new
// one — confirm statement order against the actual file.
func (e *endpoint) Restore(s *stack.Stack) {
e.setReceiveDisabled(false)
e.net.Resume(s)
// With netstack save/restore enabled, e.stack is left unchanged and only
// e.ops.InitHandler is called before returning.
if e.stack.IsSaveRestoreEnabled() {
e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)
return
}

e.setReceiveDisabled(false)
// Otherwise adopt the stack passed in by the caller before initializing the
// handler with it.
e.stack = s
e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits)

// Associated endpoints are registered again as raw transport endpoints on the
// new stack; a registration error panics.
if e.associated {
netProto := e.net.NetProto()
if err := e.stack.RegisterRawTransportEndpoint(netProto, e.transProto, e); err != nil {
panic(fmt.Sprintf("e.stack.RegisterRawTransportEndpoint(%d, %d, _): %s", netProto, e.transProto, err))
panic("RegisterRawTransportEndpoint failed during restore")
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/tcpip/transport/tcp/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,7 @@ type Endpoint struct {
isPortReserved bool
isRegistered bool
boundNICID tcpip.NICID
route *stack.Route `state:"manual"`
route *stack.Route `state:"nosave"`
ipv4TTL uint8
ipv6HopLimit int16
isConnectNotified bool
Expand Down
61 changes: 36 additions & 25 deletions pkg/tcpip/transport/tcp/endpoint_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,21 +159,23 @@ func (e *Endpoint) Restore(s *stack.Stack) {
bind := func() {
e.mu.Lock()
defer e.mu.Unlock()
addr, _, err := e.checkV4MappedLocked(tcpip.FullAddress{Addr: e.BindAddr, Port: e.TransportEndpointInfo.ID.LocalPort}, true /* bind */)
if err != nil {
panic("unable to parse BindAddr: " + err.String())
}
portRes := ports.Reservation{
Networks: e.effectiveNetProtos,
Transport: ProtocolNumber,
Addr: addr.Addr,
Port: addr.Port,
Flags: e.boundPortFlags,
BindToDevice: e.boundBindToDevice,
Dest: e.boundDest,
}
if ok := e.stack.ReserveTuple(portRes); !ok {
panic(fmt.Sprintf("unable to re-reserve tuple (%v, %q, %d, %+v, %d, %v)", e.effectiveNetProtos, addr.Addr, addr.Port, e.boundPortFlags, e.boundBindToDevice, e.boundDest))
if !saveRestoreEnabled {
addr, _, err := e.checkV4MappedLocked(tcpip.FullAddress{Addr: e.BindAddr, Port: e.TransportEndpointInfo.ID.LocalPort}, true /* bind */)
if err != nil {
panic("unable to parse BindAddr: " + err.String())
}
portRes := ports.Reservation{
Networks: e.effectiveNetProtos,
Transport: ProtocolNumber,
Addr: addr.Addr,
Port: addr.Port,
Flags: e.boundPortFlags,
BindToDevice: e.boundBindToDevice,
Dest: e.boundDest,
}
if ok := e.stack.ReserveTuple(portRes); !ok {
panic(fmt.Sprintf("unable to re-reserve tuple (%v, %q, %d, %+v, %d, %v)", e.effectiveNetProtos, addr.Addr, addr.Port, e.boundPortFlags, e.boundBindToDevice, e.boundDest))
}
}
e.isPortReserved = true

Expand All @@ -183,7 +185,7 @@ func (e *Endpoint) Restore(s *stack.Stack) {

epState := EndpointState(e.origEndpointState)
switch {
case epState.connected():
case epState.connected() || epState == StateTimeWait:
bind()
if e.connectingAddress.BitLen() == 0 {
e.connectingAddress = e.TransportEndpointInfo.ID.RemoteAddress
Expand All @@ -201,6 +203,10 @@ func (e *Endpoint) Restore(s *stack.Stack) {
// Reset the scoreboard to reinitialize the sack information as
// we do not restore SACK information.
e.scoreboard.Reset()
if saveRestoreEnabled {
// Unregister the endpoint before registering again during Connect.
e.stack.UnregisterTransportEndpoint(e.effectiveNetProtos, header.TCPProtocolNumber, e.TransportEndpointInfo.ID, e, e.boundPortFlags, e.boundBindToDevice)
}
e.mu.Lock()
err := e.connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.TransportEndpointInfo.ID.RemotePort}, false /* handshake */)
if _, ok := err.(*tcpip.ErrConnectStarted); !ok {
Expand All @@ -224,8 +230,8 @@ func (e *Endpoint) Restore(s *stack.Stack) {
e.mu.Unlock()
connectedLoading.Done()
case epState == StateListen:
tcpip.AsyncLoading.Add(1)
if !saveRestoreEnabled {
tcpip.AsyncLoading.Add(1)
go func() {
connectedLoading.Wait()
bind()
Expand All @@ -244,14 +250,19 @@ func (e *Endpoint) Restore(s *stack.Stack) {
tcpip.AsyncLoading.Done()
}()
} else {
e.LockUser()
// All endpoints will be moved to initial state after
// restore. Set endpoint to its original listen state.
e.setEndpointState(StateListen)
// Initialize the listening context.
rcvWnd := seqnum.Size(e.receiveBufferAvailable())
e.listenCtx = newListenContext(e.stack, e.protocol, e, rcvWnd, e.ops.GetV6Only(), e.NetProto)
e.UnlockUser()
go func() {
connectedLoading.Wait()
e.LockUser()
// All endpoints will be moved to initial state after
// restore. Set endpoint to its original listen state.
e.setEndpointState(StateListen)
// Initialize the listening context.
rcvWnd := seqnum.Size(e.receiveBufferAvailable())
e.listenCtx = newListenContext(e.stack, e.protocol, e, rcvWnd, e.ops.GetV6Only(), e.NetProto)
e.UnlockUser()
listenLoading.Done()
tcpip.AsyncLoading.Done()
}()
}
case epState == StateConnecting:
// Initial SYN hasn't been sent yet so initiate a connect.
Expand Down
2 changes: 1 addition & 1 deletion pkg/tcpip/transport/udp/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ type endpoint struct {

// The following fields are initialized at creation time and do not
// change throughout the lifetime of the endpoint.
stack *stack.Stack `state:"manual"`
stack *stack.Stack
waiterQueue *waiter.Queue
net network.Endpoint
stats tcpip.TransportEndpointStats
Expand Down
Loading

0 comments on commit af7f267

Please sign in to comment.