diff --git a/pkg/context/context.go b/pkg/context/context.go
index 7f94da2478..d7fe3fc2d8 100644
--- a/pkg/context/context.go
+++ b/pkg/context/context.go
@@ -43,6 +43,9 @@ type Blocker interface {
 	// Interrupted notes whether this context is Interrupted.
 	Interrupted() bool
 
+	// Killed returns true if this context is interrupted by a fatal signal.
+	Killed() bool
+
 	// BlockOn blocks until one of the previously registered events occurs,
 	// or some external interrupt (cancellation).
 	//
@@ -94,6 +97,11 @@ func (nt *NoTask) Interrupted() bool {
 	return nt.cancel != nil && len(nt.cancel) > 0
 }
 
+// Killed implements Blocker.Killed.
+func (nt *NoTask) Killed() bool {
+	return false
+}
+
 // Block implements Blocker.Block.
 func (nt *NoTask) Block(C <-chan struct{}) error {
 	if nt.cancel == nil {
diff --git a/pkg/sentry/kernel/pending_signals.go b/pkg/sentry/kernel/pending_signals.go
index af455c434b..5886834e0a 100644
--- a/pkg/sentry/kernel/pending_signals.go
+++ b/pkg/sentry/kernel/pending_signals.go
@@ -16,6 +16,7 @@ package kernel
 
 import (
 	"gvisor.dev/gvisor/pkg/abi/linux"
+	"gvisor.dev/gvisor/pkg/atomicbitops"
 	"gvisor.dev/gvisor/pkg/bits"
 )
 
@@ -49,7 +50,7 @@ type pendingSignals struct {
 
 	// Bit i of pendingSet is set iff there is at least one signal with signo
 	// i+1 pending.
-	pendingSet linux.SignalSet `state:"manual"`
+	pendingSet atomicbitops.Uint64 `state:"manual"`
 }
 
 // pendingSignalQueue holds a pendingSignalList for a single signal number.
@@ -86,7 +87,7 @@ func (p *pendingSignals) enqueue(info *linux.SignalInfo, timer *IntervalTimer) b
 	}
 	q.pendingSignalList.PushBack(&pendingSignal{SignalInfo: info, timer: timer})
 	q.length++
-	p.pendingSet |= linux.SignalSetOf(sig)
+	p.pendingSet.Store(p.pendingSet.RacyLoad() | uint64(linux.SignalSetOf(sig)))
 	return true
 }
 
@@ -103,7 +104,7 @@ func (p *pendingSignals) dequeue(mask linux.SignalSet) *linux.SignalInfo {
 	// process, POSIX leaves it unspecified which is delivered first. Linux,
 	// like many other implementations, gives priority to standard signals in
 	// this case." - signal(7)
-	lowestPendingUnblockedBit := bits.TrailingZeros64(uint64(p.pendingSet &^ mask))
+	lowestPendingUnblockedBit := bits.TrailingZeros64(p.pendingSet.RacyLoad() &^ uint64(mask))
 	if lowestPendingUnblockedBit >= linux.SignalMaximum {
 		return nil
 	}
@@ -119,7 +120,7 @@ func (p *pendingSignals) dequeueSpecific(sig linux.Signal) *linux.SignalInfo {
 	q.pendingSignalList.Remove(ps)
 	q.length--
 	if q.length == 0 {
-		p.pendingSet &^= linux.SignalSetOf(sig)
+		p.pendingSet.Store(p.pendingSet.RacyLoad() &^ uint64(linux.SignalSetOf(sig)))
 	}
 	if ps.timer != nil {
 		ps.timer.updateDequeuedSignalLocked(ps.SignalInfo)
@@ -137,5 +138,5 @@ func (p *pendingSignals) discardSpecific(sig linux.Signal) {
 	}
 	q.pendingSignalList.Reset()
 	q.length = 0
-	p.pendingSet &^= linux.SignalSetOf(sig)
+	p.pendingSet.Store(p.pendingSet.RacyLoad() &^ uint64(linux.SignalSetOf(sig)))
 }
diff --git a/pkg/sentry/kernel/task_block.go b/pkg/sentry/kernel/task_block.go
index b65dea55cf..f5c537b32d 100644
--- a/pkg/sentry/kernel/task_block.go
+++ b/pkg/sentry/kernel/task_block.go
@@ -234,7 +234,7 @@ func (t *Task) Interrupted() bool {
 	}
 	// Indicate that t's task goroutine is still responsive (i.e. reset the
 	// watchdog timer).
-	t.accountTaskGoroutineRunning()
+	t.touchGostateTime()
 	return false
 }
 
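A note on the pending_signals.go conversion above: pendingSet writers still hold the signal mutex, which is why the read-modify-writes can pair RacyLoad() with Store() instead of a compare-and-swap loop; only the new lock-free readers (Task.killed() and PendingSignals(), later in this diff) need a genuine atomic Load(). A minimal sketch of that pattern using only the standard library; the pendingSet type and its method names here are illustrative, not gVisor's:

    package main

    import (
        "fmt"
        "sync"
        "sync/atomic"
    )

    type pendingSet struct {
        mu  sync.Mutex    // serializes all writers
        set atomic.Uint64 // may be read without holding mu
    }

    func (p *pendingSet) add(signo uint) {
        p.mu.Lock()
        defer p.mu.Unlock()
        // The load side of this read-modify-write needs no atomicity (mu
        // excludes other writers); only the Store must be atomic so that
        // lock-free readers never see a torn value. This mirrors
        // atomicbitops.Uint64.Store(RacyLoad() | bit) above.
        p.set.Store(p.set.Load() | (1 << (signo - 1)))
    }

    func (p *pendingSet) killed() bool {
        // Lock-free reader: may observe a slightly stale set, which is
        // acceptable for a poll that will be repeated.
        const sigkill = 9
        return p.set.Load()&(1<<(sigkill-1)) != 0
    }

    func main() {
        var p pendingSet
        p.add(9)
        fmt.Println(p.killed()) // true
    }

The trade-off is that a racing reader may briefly miss a just-queued signal, which is fine for a check that callers repeat.
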
diff --git a/pkg/sentry/kernel/task_exit.go b/pkg/sentry/kernel/task_exit.go
index 69a1741808..02c4a34c24 100644
--- a/pkg/sentry/kernel/task_exit.go
+++ b/pkg/sentry/kernel/task_exit.go
@@ -100,19 +100,28 @@ func (t *Task) killLocked() {
 	t.interrupt()
 }
 
+// Killed implements context.Blocker.Killed.
+func (t *Task) Killed() bool {
+	if t.killed() {
+		return true
+	}
+	// Indicate that t's task goroutine is still responsive (i.e. reset the
+	// watchdog timer).
+	t.touchGostateTime()
+	return false
+}
+
 // killed returns true if t has a SIGKILL pending. killed is analogous to
 // Linux's fatal_signal_pending().
 //
 // Preconditions: The caller must be running on the task goroutine.
 func (t *Task) killed() bool {
-	t.tg.signalHandlers.mu.Lock()
-	defer t.tg.signalHandlers.mu.Unlock()
-	return t.killedLocked()
+	return linux.SignalSet(t.pendingSignals.pendingSet.Load())&linux.SignalSetOf(linux.SIGKILL) != 0
 }
 
 // Preconditions: The signal mutex must be locked.
 func (t *Task) killedLocked() bool {
-	return t.pendingSignals.pendingSet&linux.SignalSetOf(linux.SIGKILL) != 0
+	return linux.SignalSet(t.pendingSignals.pendingSet.RacyLoad())&linux.SignalSetOf(linux.SIGKILL) != 0
 }
 
 // PrepareExit indicates an exit with the given status.
diff --git a/pkg/sentry/kernel/task_sched.go b/pkg/sentry/kernel/task_sched.go
index 8f3de982d0..5cce565f59 100644
--- a/pkg/sentry/kernel/task_sched.go
+++ b/pkg/sentry/kernel/task_sched.go
@@ -116,14 +116,6 @@ func (t *Task) accountTaskGoroutineLeave(state TaskGoroutineState) {
 	t.gostateSeq.EndWrite()
 }
 
-// Preconditions: The caller must be running on the task goroutine.
-func (t *Task) accountTaskGoroutineRunning() {
-	if oldState := t.TaskGoroutineState(); oldState != TaskGoroutineRunningSys {
-		panic(fmt.Sprintf("Task goroutine in state %v (expected %v)", oldState, TaskGoroutineRunningSys))
-	}
-	t.touchGostateTime()
-}
-
 // Preconditions: The caller must be running on the task goroutine.
 func (t *Task) touchGostateTime() {
 	t.gostateTime.Store(t.k.cpuClock.Load())
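Task.Killed() gives long-running sentry work a way to poll for a pending fatal signal without taking the signal mutex, while also petting the watchdog, mirroring Interrupted() above. A sketch of the intended consumption pattern; doChunkedWork and processChunk are hypothetical and not part of this change, and this diff itself demonstrates that ctx.Killed() is reachable through gVisor's context.Context:

    package example

    import (
        "gvisor.dev/gvisor/pkg/context"
        "gvisor.dev/gvisor/pkg/errors/linuxerr"
    )

    // doChunkedWork performs work in bounded, non-blocking chunks,
    // checking for a pending fatal signal between chunks so that a
    // killed task cannot remain stuck here.
    func doChunkedWork(ctx context.Context, chunks int) error {
        for i := 0; i < chunks; i++ {
            processChunk(i) // one bounded unit of work
            if ctx.Killed() {
                // SIGKILL is pending; give up with EINTR, as
                // mm.mapASLocked() does later in this diff, so the
                // task can exit promptly.
                return linuxerr.EINTR
            }
        }
        return nil
    }

    func processChunk(int) { /* placeholder */ }
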
diff --git a/pkg/sentry/kernel/task_signals.go b/pkg/sentry/kernel/task_signals.go
index 36d02b6cd1..fd944a39ca 100644
--- a/pkg/sentry/kernel/task_signals.go
+++ b/pkg/sentry/kernel/task_signals.go
@@ -146,9 +146,7 @@ func (tg *ThreadGroup) discardSpecificLocked(sig linux.Signal) {
 
 // PendingSignals returns the set of pending signals.
 func (t *Task) PendingSignals() linux.SignalSet {
-	sh := t.tg.signalLock()
-	defer sh.mu.Unlock()
-	return t.pendingSignals.pendingSet | t.tg.pendingSignals.pendingSet
+	return linux.SignalSet(t.pendingSignals.pendingSet.Load() | t.tg.pendingSignals.pendingSet.Load())
 }
 
 // deliverSignal delivers the given signal and returns the following run state.
@@ -612,7 +610,7 @@ func (t *Task) setSignalMaskLocked(mask linux.SignalSet) {
 	// signal, but will no longer do so as a result of its new signal mask, so
 	// we have to pick a replacement.
 	blocked := mask &^ oldMask
-	blockedGroupPending := blocked & t.tg.pendingSignals.pendingSet
+	blockedGroupPending := blocked & linux.SignalSet(t.tg.pendingSignals.pendingSet.RacyLoad())
 	if blockedGroupPending != 0 && t.interrupted() {
 		linux.ForEachSignal(blockedGroupPending, func(sig linux.Signal) {
 			if nt := t.tg.findSignalReceiverLocked(sig); nt != nil {
@@ -626,7 +624,7 @@ func (t *Task) setSignalMaskLocked(mask linux.SignalSet) {
 	// the old mask, and at least one such signal is pending, we may now need
 	// to handle that signal.
 	unblocked := oldMask &^ mask
-	unblockedPending := unblocked & (t.pendingSignals.pendingSet | t.tg.pendingSignals.pendingSet)
+	unblockedPending := unblocked & linux.SignalSet(t.pendingSignals.pendingSet.RacyLoad()|t.tg.pendingSignals.pendingSet.RacyLoad())
 	if unblockedPending != 0 {
 		t.interruptSelf()
 	}
diff --git a/pkg/sentry/mm/address_space.go b/pkg/sentry/mm/address_space.go
index df30235f5f..b2064b6eb3 100644
--- a/pkg/sentry/mm/address_space.go
+++ b/pkg/sentry/mm/address_space.go
@@ -18,6 +18,7 @@ import (
 	"fmt"
 
 	"gvisor.dev/gvisor/pkg/context"
+	"gvisor.dev/gvisor/pkg/errors/linuxerr"
 	"gvisor.dev/gvisor/pkg/hostarch"
 	"gvisor.dev/gvisor/pkg/sentry/memmap"
 	"gvisor.dev/gvisor/pkg/sentry/pgalloc"
@@ -172,7 +173,7 @@ func (mm *MemoryManager) Deactivate() {
 //   - ar.Length() != 0.
 //   - ar must be page-aligned.
 //   - pseg == mm.pmas.LowerBoundSegment(ar.Start).
-func (mm *MemoryManager) mapASLocked(pseg pmaIterator, ar hostarch.AddrRange, platformEffect memmap.MMapPlatformEffect) error {
+func (mm *MemoryManager) mapASLocked(ctx context.Context, pseg pmaIterator, ar hostarch.AddrRange, platformEffect memmap.MMapPlatformEffect) error {
 	// By default, map entire pmas at a time, under the assumption that there
 	// is no cost to mapping more of a pma than necessary.
 	mapAR := hostarch.AddrRange{0, ^hostarch.Addr(hostarch.PageSize - 1)}
@@ -217,8 +218,33 @@ func (mm *MemoryManager) mapASLocked(pseg pmaIterator, ar hostarch.AddrRange, pl
 			perms.Write = false
 		}
 		if perms.Any() { // MapFile precondition
-			if err := mm.as.MapFile(pmaMapAR.Start, pma.file, pseg.fileRangeOf(pmaMapAR), perms, platformEffect == memmap.PlatformEffectCommit); err != nil {
-				return err
+			// If the length of the mapping exceeds singleMapThreshold, call
+			// AddressSpace.MapFile() on singleMapThreshold-aligned chunks so
+			// we can check ctx.Killed() reasonably frequently.
+			const singleMapThreshold = 1 << 30
+			if pmaMapAR.Length() <= singleMapThreshold {
+				if err := mm.as.MapFile(pmaMapAR.Start, pma.file, pseg.fileRangeOf(pmaMapAR), perms, platformEffect == memmap.PlatformEffectCommit); err != nil {
+					return err
+				}
+				if ctx.Killed() {
+					return linuxerr.EINTR
+				}
+			} else {
+				windowStart := pmaMapAR.Start &^ (singleMapThreshold - 1)
+				for {
+					windowAR := hostarch.AddrRange{windowStart, windowStart + singleMapThreshold}
+					thisMapAR := pmaMapAR.Intersect(windowAR)
+					if err := mm.as.MapFile(thisMapAR.Start, pma.file, pseg.fileRangeOf(thisMapAR), perms, platformEffect == memmap.PlatformEffectCommit); err != nil {
+						return err
+					}
+					if ctx.Killed() {
+						return linuxerr.EINTR
+					}
+					windowStart = windowAR.End
+					if windowStart >= pmaMapAR.End {
+						break
+					}
+				}
 			}
 		}
 		pseg = pseg.NextSegment()
diff --git a/pkg/sentry/mm/io.go b/pkg/sentry/mm/io.go
index 8c2c233279..ffa8f7d8b8 100644
--- a/pkg/sentry/mm/io.go
+++ b/pkg/sentry/mm/io.go
@@ -514,7 +514,7 @@ func (mm *MemoryManager) handleASIOFault(ctx context.Context, addr hostarch.Addr
 	// anymore.
 	mm.activeMu.DowngradeLock()
 
-	err = mm.mapASLocked(pseg, ar, memmap.PlatformEffectDefault)
+	err = mm.mapASLocked(ctx, pseg, ar, memmap.PlatformEffectDefault)
 	mm.activeMu.RUnlock()
 	return translateIOError(ctx, err)
 }
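The windowing arithmetic in mapASLocked() is worth seeing in isolation: windowStart is aligned down to a singleMapThreshold boundary, so the first chunk may be shorter than the threshold, and Intersect() clips the window to the pma's range at both ends. A standalone sketch with plain uint64s in place of hostarch.AddrRange; chunkRange is illustrative only, assumes start < end and a power-of-two threshold, and uses the min/max builtins from Go 1.21:

    package main

    import "fmt"

    // chunkRange splits [start, end) into pieces that never cross a
    // threshold-aligned boundary. Intersecting each aligned window with
    // [start, end) clips both ends, like pmaMapAR.Intersect(windowAR)
    // in the hunk above.
    func chunkRange(start, end, threshold uint64) [][2]uint64 {
        var chunks [][2]uint64
        windowStart := start &^ (threshold - 1) // align down
        for {
            windowEnd := windowStart + threshold
            chunks = append(chunks, [2]uint64{max(start, windowStart), min(end, windowEnd)})
            windowStart = windowEnd
            if windowStart >= end {
                return chunks
            }
        }
    }

    func main() {
        // A range from 0.5 GiB to 2.5 GiB splits into 512 MiB, 1 GiB,
        // and 512 MiB chunks; each piece stays inside one aligned
        // window, so no single MapFile call covers more than 1 GiB.
        const GiB = 1 << 30
        for _, c := range chunkRange(GiB/2, 5*GiB/2, GiB) {
            fmt.Printf("[%#x, %#x) len=%d MiB\n", c[0], c[1], (c[1]-c[0])>>20)
        }
    }

In the real loop, ctx.Killed() is checked after each chunk is mapped, bounding the time a killed task can spend in mapASLocked() to roughly one 1 GiB MapFile call.
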
diff --git a/pkg/sentry/mm/syscalls.go b/pkg/sentry/mm/syscalls.go
index 29b7616554..079d8a333c 100644
--- a/pkg/sentry/mm/syscalls.go
+++ b/pkg/sentry/mm/syscalls.go
@@ -66,7 +66,7 @@ func (mm *MemoryManager) HandleUserFault(ctx context.Context, addr hostarch.Addr
 	mm.activeMu.DowngradeLock()
 
 	// Map the faulted page into the active AddressSpace.
-	err = mm.mapASLocked(pseg, ar, memmap.PlatformEffectDefault)
+	err = mm.mapASLocked(ctx, pseg, ar, memmap.PlatformEffectDefault)
 	mm.activeMu.RUnlock()
 	return err
 }
@@ -201,7 +201,7 @@ func (mm *MemoryManager) populateVMA(ctx context.Context, vseg vmaIterator, ar h
 	// Downgrade to a read-lock on activeMu since we don't need to mutate pmas
 	// anymore.
 	mm.activeMu.DowngradeLock()
-	err = mm.mapASLocked(pseg, ar, platformEffect)
+	err = mm.mapASLocked(ctx, pseg, ar, platformEffect)
 	mm.activeMu.RUnlock()
 	return err
 }
@@ -250,7 +250,7 @@ func (mm *MemoryManager) populateVMAAndUnlock(ctx context.Context, vseg vmaItera
 
 	// As above, errors are silently ignored.
 	mm.activeMu.DowngradeLock()
-	mm.mapASLocked(pseg, ar, platformEffect)
+	mm.mapASLocked(ctx, pseg, ar, platformEffect)
 	mm.activeMu.RUnlock()
 }
 
@@ -930,7 +930,7 @@ func (mm *MemoryManager) MLock(ctx context.Context, addr hostarch.Addr, length u
 			mm.mappingMu.RUnlock()
 			if mm.as != nil {
 				mm.activeMu.DowngradeLock()
-				err := mm.mapASLocked(mm.pmas.LowerBoundSegment(ar.Start), ar, memmap.PlatformEffectCommit)
+				err := mm.mapASLocked(ctx, mm.pmas.LowerBoundSegment(ar.Start), ar, memmap.PlatformEffectCommit)
 				mm.activeMu.RUnlock()
 				if err != nil {
 					return err
@@ -1014,7 +1014,7 @@ func (mm *MemoryManager) MLockAll(ctx context.Context, opts MLockAllOpts) error
 		mm.mappingMu.RUnlock()
 		if mm.as != nil {
 			mm.activeMu.DowngradeLock()
-			mm.mapASLocked(mm.pmas.FirstSegment(), mm.applicationAddrRange(), memmap.PlatformEffectCommit)
+			mm.mapASLocked(ctx, mm.pmas.FirstSegment(), mm.applicationAddrRange(), memmap.PlatformEffectCommit)
 			mm.activeMu.RUnlock()
 		} else {
 			mm.activeMu.Unlock()
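One compatibility note on the interface change: adding Killed() to context.Blocker obligates every implementation. In-tree, NoTask and Task are covered above, and any implementation that embeds NoTask picks up the default for free. A sketch of that pattern; loggingBlocker is hypothetical, not part of this change:

    package example

    import (
        "log"

        "gvisor.dev/gvisor/pkg/context"
    )

    type loggingBlocker struct {
        context.NoTask // supplies Killed() (always false), Block, etc.
    }

    // Killed can be overridden where a real notion of a fatal signal
    // exists; here it just instruments the inherited default.
    func (b *loggingBlocker) Killed() bool {
        killed := b.NoTask.Killed()
        log.Printf("Killed() polled: %t", killed)
        return killed
    }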