Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 114 additions & 0 deletions cmd/gvisor-gpu-ckpt/cuda.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
//go:build linux

package main

/*
#cgo LDFLAGS: -ldl
#include <dlfcn.h>
#include <stdlib.h>
#include <string.h>

// Function pointer types for the driver entry points resolved via dlsym.
// cuCheckpointFn matches the cuCheckpointProcess* signatures: each takes
// (int pid, void* args) and returns int (CUresult).
typedef int (*cuCheckpointFn)(int, void*);
typedef int (*cuInitFn)(unsigned int);

// Returned by call_checkpoint_fn when the requested entry point has not
// been resolved. It is negative so it cannot be confused with a driver
// result code (CUDA_SUCCESS is 0 and driver error codes are positive).
#define CKPT_NOT_LOADED (-1)

static void* libcuda_handle = NULL;
// Set to 1 only once every symbol is resolved and cuInit has succeeded.
static int libcuda_ready = 0;
static cuInitFn fn_init = NULL;
static cuCheckpointFn fn_lock = NULL;
static cuCheckpointFn fn_checkpoint = NULL;
static cuCheckpointFn fn_restore = NULL;
static cuCheckpointFn fn_unlock = NULL;

// load_libcuda opens libcuda.so.1, resolves cuInit and all checkpoint
// symbols, and initializes CUDA.
// Returns 0 on success, -1 if dlopen fails, -2 if any dlsym fails,
// -3 if cuInit fails.
//
// Success is tracked with libcuda_ready rather than with the library
// handle: the handle becomes non-NULL as soon as dlopen succeeds, so using
// it as the "already loaded" marker would make a retry after a dlsym or
// cuInit failure report success with unusable function pointers.
static int load_libcuda(void) {
    if (libcuda_ready) return 0;

    // Reuse the handle from an earlier partially failed attempt, if any.
    if (libcuda_handle == NULL) {
        libcuda_handle = dlopen("libcuda.so.1", RTLD_NOW);
        if (libcuda_handle == NULL) return -1;
    }

    fn_init = (cuInitFn)dlsym(libcuda_handle, "cuInit");
    fn_lock = (cuCheckpointFn)dlsym(libcuda_handle, "cuCheckpointProcessLock");
    fn_checkpoint = (cuCheckpointFn)dlsym(libcuda_handle, "cuCheckpointProcessCheckpoint");
    fn_restore = (cuCheckpointFn)dlsym(libcuda_handle, "cuCheckpointProcessRestore");
    fn_unlock = (cuCheckpointFn)dlsym(libcuda_handle, "cuCheckpointProcessUnlock");

    if (!fn_init || !fn_lock || !fn_checkpoint || !fn_restore || !fn_unlock) return -2;

    if (fn_init(0) != 0) return -3;

    libcuda_ready = 1;
    return 0;
}

// call_checkpoint_fn invokes a cuCheckpointProcess* function with a
// zeroed 64-byte args buffer (matches NVIDIA's API expectation).
// Returns CKPT_NOT_LOADED instead of jumping through a NULL pointer when
// load_libcuda has not completed successfully.
static int call_checkpoint_fn(cuCheckpointFn fn, int pid) {
    // Eight 64-bit words: the same 64 bytes as before, but 8-byte aligned
    // in case the driver reads 64-bit fields from the buffer.
    unsigned long long args[8];

    if (!libcuda_ready || fn == NULL) return CKPT_NOT_LOADED;

    memset(args, 0, sizeof(args));
    return fn(pid, args);
}

static int cuda_lock(int pid) { return call_checkpoint_fn(fn_lock, pid); }
static int cuda_checkpoint(int pid) { return call_checkpoint_fn(fn_checkpoint, pid); }
static int cuda_restore(int pid) { return call_checkpoint_fn(fn_restore, pid); }
static int cuda_unlock(int pid) { return call_checkpoint_fn(fn_unlock, pid); }
*/
import "C"
import (
"fmt"
"unsafe"
)

// loadLibcuda loads libcuda.so.1, resolves the checkpoint entry points and
// initializes CUDA by calling the C helper load_libcuda. It returns nil on
// success and an error describing which stage failed otherwise (dlopen,
// dlsym or cuInit, per the helper's return code).
func loadLibcuda() error {
	switch rc := C.load_libcuda(); rc {
	case 0:
		return nil
	case -1:
		// dlerror() returns NULL when no error is pending for the caller.
		// It is invoked here as a separate cgo call from the dlopen, and
		// cgo does not guarantee both calls run on the same OS thread, so
		// fall back to a placeholder rather than printing an empty detail.
		detail := "no detail available from dlerror"
		if msg := C.dlerror(); msg != nil {
			detail = C.GoString(msg)
		}
		return fmt.Errorf("dlopen(libcuda.so.1) failed: %s", detail)
	case -2:
		return fmt.Errorf("dlsym failed: one or more cuCheckpointProcess* symbols not found (driver 570+ required)")
	case -3:
		return fmt.Errorf("cuInit(0) failed")
	default:
		return fmt.Errorf("load_libcuda returned unknown error: %d", rc)
	}
}

// checkpointLock calls cuCheckpointProcessLock for pid through the cgo
// wrapper and converts a non-zero result code into an error that carries
// both the PID and the code.
func checkpointLock(pid int) error {
	if code := C.cuda_lock(C.int(pid)); code != 0 {
		return fmt.Errorf("cuCheckpointProcessLock(pid=%d) failed: rc=%d", pid, code)
	}
	return nil
}

// checkpointCheckpoint calls cuCheckpointProcessCheckpoint for pid through
// the cgo wrapper and converts a non-zero result code into an error that
// carries both the PID and the code.
func checkpointCheckpoint(pid int) error {
	if code := C.cuda_checkpoint(C.int(pid)); code != 0 {
		return fmt.Errorf("cuCheckpointProcessCheckpoint(pid=%d) failed: rc=%d", pid, code)
	}
	return nil
}

// checkpointRestore calls cuCheckpointProcessRestore for pid through the
// cgo wrapper and converts a non-zero result code into an error that
// carries both the PID and the code.
func checkpointRestore(pid int) error {
	if code := C.cuda_restore(C.int(pid)); code != 0 {
		return fmt.Errorf("cuCheckpointProcessRestore(pid=%d) failed: rc=%d", pid, code)
	}
	return nil
}

// checkpointUnlock calls cuCheckpointProcessUnlock for pid through the cgo
// wrapper and converts a non-zero result code into an error that carries
// both the PID and the code.
func checkpointUnlock(pid int) error {
	if code := C.cuda_unlock(C.int(pid)); code != 0 {
		return fmt.Errorf("cuCheckpointProcessUnlock(pid=%d) failed: rc=%d", pid, code)
	}
	return nil
}

// Keep the "unsafe" import referenced: nothing else in this file uses the
// package, and Go rejects unused imports.
// NOTE(review): cgo itself does not appear to require this import here;
// confirm whether the import and this placeholder can both be dropped.
var _ = unsafe.Pointer(nil)
119 changes: 119 additions & 0 deletions cmd/gvisor-gpu-ckpt/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
//go:build linux

// Binary gvisor-gpu-ckpt is a SaveRestoreExec binary for gVisor that
// handles GPU checkpoint/restore via NVIDIA's cuCheckpointProcess* API.
//
// This binary runs INSIDE the gVisor sandbox. It targets the container's
// init process (PID 1) because cuda-checkpoint ioctls must go through
// nvproxy to reach the real NVIDIA driver. Calling from the host with
// the sentry PID fails because the sentry creates GPU contexts via raw
// ioctls without libcuda.so initialization.
//
// gVisor invokes this binary with the GVISOR_SAVE_RESTORE_AUTO_EXEC_MODE
// env var set to "save", "restore", or "resume".
//
// Usage with gVisor:
//
// runsc --save-restore-exec-argv=/path/to/gvisor-gpu-ckpt checkpoint <container-id>
package main

import (
"fmt"
"os"
"strconv"
"strings"
)

func main() {
mode := os.Getenv("GVISOR_SAVE_RESTORE_AUTO_EXEC_MODE")
if mode == "" {
fmt.Fprintln(os.Stderr, "gvisor-gpu-ckpt: GVISOR_SAVE_RESTORE_AUTO_EXEC_MODE not set")
os.Exit(1)
}

pid, err := getTargetPID()
if err != nil {
fmt.Fprintf(os.Stderr, "gvisor-gpu-ckpt: failed to determine target PID: %v\n", err)
os.Exit(1)
}

if err := loadLibcuda(); err != nil {
// No CUDA driver available — if there are no GPU contexts,
// this is expected and not an error.
fmt.Fprintf(os.Stderr, "gvisor-gpu-ckpt: %v (no GPU contexts to checkpoint)\n", err)
os.Exit(0)
}

switch mode {
case "save":
if err := doSave(pid); err != nil {
fmt.Fprintf(os.Stderr, "gvisor-gpu-ckpt: save failed: %v\n", err)
os.Exit(1)
}
case "restore", "resume":
if err := doRestore(pid); err != nil {
fmt.Fprintf(os.Stderr, "gvisor-gpu-ckpt: %s failed: %v\n", mode, err)
os.Exit(1)
}
default:
fmt.Fprintf(os.Stderr, "gvisor-gpu-ckpt: unknown mode %q\n", mode)
os.Exit(1)
}
}

// doSave locks the CUDA contexts of pid and then checkpoints its GPU state,
// logging each step to stderr.
//
// If the lock call reports that the process has no CUDA contexts (see
// isNoContextError), there is nothing to checkpoint and nil is returned.
// If the checkpoint step fails, an unlock is attempted before returning the
// checkpoint error. On success the contexts are left locked; doRestore
// performs the matching restore and unlock.
func doSave(pid int) error {
	fmt.Fprintf(os.Stderr, "gvisor-gpu-ckpt: locking GPU contexts for PID %d\n", pid)
	if err := checkpointLock(pid); err != nil {
		if isNoContextError(err) {
			fmt.Fprintf(os.Stderr, "gvisor-gpu-ckpt: no CUDA contexts for PID %d, nothing to checkpoint\n", pid)
			return nil
		}
		return fmt.Errorf("lock: %w", err)
	}

	fmt.Fprintf(os.Stderr, "gvisor-gpu-ckpt: checkpointing GPU state for PID %d\n", pid)
	if err := checkpointCheckpoint(pid); err != nil {
		// Best-effort unlock so the process is not left locked. The
		// checkpoint error is what the caller needs, but a failed unlock
		// is still reported so a stuck lock is visible in the logs.
		if unlockErr := checkpointUnlock(pid); unlockErr != nil {
			fmt.Fprintf(os.Stderr, "gvisor-gpu-ckpt: unlock after failed checkpoint also failed for PID %d: %v\n", pid, unlockErr)
		}
		return fmt.Errorf("checkpoint: %w", err)
	}

	fmt.Fprintf(os.Stderr, "gvisor-gpu-ckpt: GPU checkpoint complete for PID %d\n", pid)
	return nil
}

// doRestore restores the GPU state of pid and then unlocks its CUDA
// contexts, logging each step to stderr. The first failing step aborts the
// sequence and is returned wrapped with the step name ("restore" or
// "unlock").
func doRestore(pid int) error {
	fmt.Fprintf(os.Stderr, "gvisor-gpu-ckpt: restoring GPU state for PID %d\n", pid)
	restoreErr := checkpointRestore(pid)
	if restoreErr != nil {
		return fmt.Errorf("restore: %w", restoreErr)
	}

	fmt.Fprintf(os.Stderr, "gvisor-gpu-ckpt: unlocking GPU contexts for PID %d\n", pid)
	unlockErr := checkpointUnlock(pid)
	if unlockErr != nil {
		return fmt.Errorf("unlock: %w", unlockErr)
	}

	fmt.Fprintf(os.Stderr, "gvisor-gpu-ckpt: GPU restore complete for PID %d\n", pid)
	return nil
}

// isNoContextError returns true if the error indicates the target PID
// has no CUDA contexts (rc=3 = CUDA_ERROR_NOT_INITIALIZED).
//
// The result code is recognised from the "rc=<code>" text embedded in the
// error message. A match only counts when "rc=3" is not followed by another
// digit, so codes that merely start with 3 (rc=30, rc=304, ...) are not
// mistaken for rc=3. A nil error is never a no-context error.
func isNoContextError(err error) bool {
	if err == nil {
		return false
	}
	const marker = "rc=3"
	msg := err.Error()
	for {
		i := strings.Index(msg, marker)
		if i < 0 {
			return false
		}
		rest := msg[i+len(marker):]
		if rest == "" || rest[0] < '0' || rest[0] > '9' {
			return true
		}
		// This occurrence was a longer number; keep scanning after it.
		msg = rest
	}
}

// getTargetPID returns the PID to pass to cuCheckpointProcess* calls.
// Inside the gVisor sandbox, this should be the container's init process
// (PID 1), since all GPU contexts route through nvproxy to the sentry.
// Override with GVISOR_CHECKPOINT_PID env var if needed.
//
// The override must parse as a positive decimal integer; anything else
// (non-numeric, zero, negative) yields an error and a zero PID. An unset or
// empty override selects the default of 1.
func getTargetPID() (int, error) {
	s := os.Getenv("GVISOR_CHECKPOINT_PID")
	if s == "" {
		return 1, nil
	}
	pid, err := strconv.Atoi(s)
	if err != nil {
		return 0, fmt.Errorf("invalid GVISOR_CHECKPOINT_PID %q: %w", s, err)
	}
	// Process IDs are strictly positive; reject 0 and negative values here
	// instead of handing them to the driver.
	if pid <= 0 {
		return 0, fmt.Errorf("invalid GVISOR_CHECKPOINT_PID %q: must be a positive integer", s)
	}
	return pid, nil
}
10 changes: 7 additions & 3 deletions pkg/sentry/devices/nvproxy/object.go
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,8 @@ func (o *rmAllocObject) Release(ctx context.Context) func() {
// miscObject is an objectImpl tracking a driver object allocated by something
// other than an invocation of NV_ESC_RM_ALLOC, whose class is not represented
// by a more specific type.
//
// +stateify savable
type miscObject struct {
object
}
Expand Down Expand Up @@ -409,14 +411,16 @@ func (c *rootClient) getObject(ctx context.Context, h nvgpu.Handle) *object {
}

// osDescMem is an objectImpl tracking a NV01_MEMORY_SYSTEM_OS_DESCRIPTOR.
//
// +stateify savable
type osDescMem struct {
object
pinnedRanges []mm.PinnedRange
pinnedRanges []mm.PinnedRange `state:"nosave"`

// If m is non-zero, it is the start address of a mapping of length len
// that should be unmapped when the osDescMem is released.
m uintptr
len uintptr
m uintptr `state:"nosave"`
len uintptr `state:"nosave"`
}

// Release implements objectImpl.Release.
Expand Down
106 changes: 94 additions & 12 deletions pkg/sentry/devices/nvproxy/save_restore_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,32 +19,114 @@ package nvproxy

import (
goContext "context"
"fmt"
"path/filepath"

"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/devutil"
"gvisor.dev/gvisor/pkg/fdnotifier"
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/waiter"
)

// restoreContext is the context value built by newRestoreContext for use
// while restoring nvproxy file descriptions. It is assembled purely by
// embedding: context.NoTask, a log.Logger, and the standard-library
// context it wraps.
type restoreContext struct {
	context.NoTask
	log.Logger
	goContext.Context
}

// newRestoreContext wraps the standard-library context ctx in a
// restoreContext whose logger is the one returned by log.Log().
func newRestoreContext(ctx goContext.Context) context.Context {
	rc := &restoreContext{
		Logger:  log.Log(),
		Context: ctx,
	}
	return rc
}

func (nvp *nvproxy) beforeSaveImpl() {
nvp.clientsMu.RLock()
defer nvp.clientsMu.RUnlock()
if len(nvp.clients) != 0 {
panic("can't save with live nvproxy clients")
}
// GPU state is preserved by SaveRestoreExec; nvproxy state is
// serialized by stateify.
}

// afterLoadImpl is the nvproxy-level restore hook; it intentionally does
// nothing.
func (nvp *nvproxy) afterLoadImpl(goContext.Context) {
	// no-op: frontendFDs map is restored by stateify,
	// ABI is rebuilt in save_restore.go:afterLoad().
}

func (fd *frontendFD) beforeSaveImpl() {
panic("nvproxy.frontendFD is not saveable")
// hostFD value is serialized but stale on restore; afterLoadImpl
// replaces it with a freshly opened FD.
}

func (fd *frontendFD) afterLoadImpl(goContext.Context) {
panic("nvproxy.frontendFD is not restorable")
// afterLoadImpl re-establishes the host side of a restored frontendFD: it
// reopens the underlying device (through the device gofer when the nvproxy
// is configured to use one, otherwise directly under /dev), stores the new
// descriptor in fd.hostFD, re-initializes the internal waiter entry,
// registers the descriptor with fdnotifier and points the memmap file at
// it. Any failure panics.
func (fd *frontendFD) afterLoadImpl(goCtx goContext.Context) {
	restoreCtx := newRestoreContext(goCtx)
	name := fd.dev.basename()

	// Obtain a fresh host descriptor for the device.
	if !fd.dev.nvp.useDevGofer {
		hostPath := filepath.Join("/dev", name)
		reopened, openErr := unix.Openat(-1, hostPath, unix.O_RDWR|unix.O_NOFOLLOW, 0)
		if openErr != nil {
			panic(fmt.Sprintf("nvproxy: failed to reopen host %s on restore: %v", hostPath, openErr))
		}
		fd.hostFD = int32(reopened)
	} else {
		clients := devutil.GoferClientProviderFromContext(restoreCtx)
		if clients == nil {
			panic(fmt.Sprintf("nvproxy: device gofer client provider not available for %s on restore", name))
		}
		gofer := clients.GetDevGoferClient(fd.containerName)
		if gofer == nil {
			panic(fmt.Sprintf("nvproxy: device gofer client for container %q not available on restore", fd.containerName))
		}
		reopened, openErr := gofer.OpenAt(restoreCtx, name, unix.O_RDWR)
		if openErr != nil {
			panic(fmt.Sprintf("nvproxy: failed to reopen device %s via gofer on restore: %v", name, openErr))
		}
		fd.hostFD = int32(reopened)
	}

	// The eventListener callback is not serialized, so set it up again.
	// The entry itself is NOT re-registered: it is already in the queue
	// from checkpoint state.
	fd.internalEntry.Init(fd, waiter.AllEvents)
	notifyErr := fdnotifier.AddFD(fd.hostFD, &fd.internalQueue)
	if notifyErr != nil {
		panic(fmt.Sprintf("nvproxy: fdnotifier.AddFD failed on restore: %v", notifyErr))
	}

	fd.memmapFile.SetFD(int(fd.hostFD))
}

func (fd *uvmFD) beforeSaveImpl() {
panic("nvproxy.uvmFD is not saveable")
// hostFD value is serialized but stale on restore; afterLoadImpl
// replaces it with a freshly opened FD.
}

func (fd *uvmFD) afterLoadImpl(goContext.Context) {
panic("nvproxy.uvmFD is not restorable")
// afterLoadImpl re-establishes the host side of a restored uvmFD: it
// reopens nvidia-uvm (through the device gofer when the nvproxy is
// configured to use one, otherwise directly at /dev/nvidia-uvm), stores the
// new descriptor in fd.hostFD, registers it with fdnotifier and
// reconfigures the memmap file. Any failure panics.
func (fd *uvmFD) afterLoadImpl(goCtx goContext.Context) {
	restoreCtx := newRestoreContext(goCtx)

	// Obtain a fresh host descriptor for the device.
	if !fd.dev.nvp.useDevGofer {
		reopened, openErr := unix.Openat(-1, "/dev/nvidia-uvm", unix.O_RDWR|unix.O_NOFOLLOW, 0)
		if openErr != nil {
			panic(fmt.Sprintf("nvproxy: failed to reopen host /dev/nvidia-uvm on restore: %v", openErr))
		}
		fd.hostFD = int32(reopened)
	} else {
		clients := devutil.GoferClientProviderFromContext(restoreCtx)
		if clients == nil {
			panic("nvproxy: device gofer client provider not available for nvidia-uvm on restore")
		}
		gofer := clients.GetDevGoferClient(fd.containerName)
		if gofer == nil {
			panic(fmt.Sprintf("nvproxy: device gofer client for container %q not available on restore", fd.containerName))
		}
		reopened, openErr := gofer.OpenAt(restoreCtx, "nvidia-uvm", unix.O_RDWR)
		if openErr != nil {
			panic(fmt.Sprintf("nvproxy: failed to reopen nvidia-uvm via gofer on restore: %v", openErr))
		}
		fd.hostFD = int32(reopened)
	}

	notifyErr := fdnotifier.AddFD(fd.hostFD, &fd.queue)
	if notifyErr != nil {
		panic(fmt.Sprintf("nvproxy: fdnotifier.AddFD failed on restore: %v", notifyErr))
	}

	fd.memmapFile.SetFD(int(fd.hostFD))
	fd.memmapFile.RequireAddrEqualsFileOffset()
}
Loading