From b4e4c80ce58a267149456d992e39dc2e88eefadc Mon Sep 17 00:00:00 2001 From: Etienne Perot Date: Mon, 9 Dec 2024 19:00:15 -0800 Subject: [PATCH] Internal change (diffbased). PiperOrigin-RevId: 704502884 --- .../benchmarks/httpbench/httpbench.go | 8 +- test/kubernetes/benchmarks/nginx.go | 6 +- test/kubernetes/benchmarks/postgresql.go | 2 +- test/kubernetes/benchmarks/redis.go | 4 +- test/kubernetes/benchmarks/stablediffusion.go | 4 +- test/kubernetes/benchmarks/wordpress.go | 4 +- test/kubernetes/testcluster/BUILD | 2 + test/kubernetes/testcluster/client.go | 190 ++++++++++++++++++ test/kubernetes/testcluster/objects.go | 100 ++------- test/kubernetes/testcluster/testcluster.go | 106 +++++++--- 10 files changed, 304 insertions(+), 122 deletions(-) create mode 100644 test/kubernetes/testcluster/client.go diff --git a/test/kubernetes/benchmarks/httpbench/httpbench.go b/test/kubernetes/benchmarks/httpbench/httpbench.go index 757e4ef463..a698f08cc2 100644 --- a/test/kubernetes/benchmarks/httpbench/httpbench.go +++ b/test/kubernetes/benchmarks/httpbench/httpbench.go @@ -317,17 +317,19 @@ func getMeasurements(data string, onlyReport []MetricType, wantPercentiles []int return false } var metricValues []benchmetric.MetricValue - var totalRequests int + totalRequests := 0 + totalRequestsFound := false for _, line := range strings.Split(data, "\n") { if match := wrk2TotalRequestsRe.FindStringSubmatch(line); match != nil { gotRequests, err := strconv.ParseInt(strings.ReplaceAll(match[1], ",", ""), 10, 64) if err != nil { return 0, nil, fmt.Errorf("failed to parse %q from line %q: %v", match[1], line, err) } - if totalRequests != 0 { + if totalRequestsFound { return 0, nil, fmt.Errorf("found multiple lines matching 'total requests' regex: %d vs %d (%q)", totalRequests, gotRequests, line) } totalRequests = int(gotRequests) + totalRequestsFound = true continue } if match := wrk2LatencyPercentileRE.FindStringSubmatch(line); match != nil { @@ -375,7 +377,7 @@ func getMeasurements(data string, onlyReport []MetricType, wantPercentiles []int continue } } - if totalRequests == 0 { + if !totalRequestsFound { return 0, nil, fmt.Errorf("could not find total requests in output: %q", data) } return totalRequests, metricValues, nil diff --git a/test/kubernetes/benchmarks/nginx.go b/test/kubernetes/benchmarks/nginx.go index 31bc3af7c6..fb5deaae2f 100644 --- a/test/kubernetes/benchmarks/nginx.go +++ b/test/kubernetes/benchmarks/nginx.go @@ -48,9 +48,9 @@ var ( // The test expects that it contains the files to be served at /local, // and will serve files out of `nginxServingDir`. nginxCommand = []string{"nginx", "-c", "/etc/nginx/nginx.conf"} - nginxDocKibibytes = []int{1, 10, 100, 10240} - threads = []int{1, 8, 64, 1000} - targetQPS = []int{1, 8, 64, httpbench.InfiniteQPS} + nginxDocKibibytes = []int{1, 10240} + threads = []int{1, 8, 1000} + targetQPS = []int{1, 64, httpbench.InfiniteQPS} wantPercentiles = []int{50, 95, 99} ) diff --git a/test/kubernetes/benchmarks/postgresql.go b/test/kubernetes/benchmarks/postgresql.go index 7b3d0ec441..d6b127f2e0 100644 --- a/test/kubernetes/benchmarks/postgresql.go +++ b/test/kubernetes/benchmarks/postgresql.go @@ -46,7 +46,7 @@ const ( ) var ( - numConnections = []int{1, 2, 6, 16, 32, 64} + numConnections = []int{1, 2, 12, 64} ) // BenchmarkPostgresPGBench runs a PostgreSQL pgbench test. 
diff --git a/test/kubernetes/benchmarks/redis.go b/test/kubernetes/benchmarks/redis.go index cfe0a8302c..73ec1e9370 100644 --- a/test/kubernetes/benchmarks/redis.go +++ b/test/kubernetes/benchmarks/redis.go @@ -49,9 +49,9 @@ const ( ) var ( - numConnections = []int{1, 2, 4, 8, 16, 32} + numConnections = []int{1, 4, 32} latencyPercentiles = []int{50, 95, 99} - operations = []string{"SET", "GET", "MSET", "LPUSH", "LRANGE_500"} + operations = []string{"GET", "MSET", "LRANGE_500"} ) // BenchmarkRedis runs the Redis performance benchmark using redis-benchmark. diff --git a/test/kubernetes/benchmarks/stablediffusion.go b/test/kubernetes/benchmarks/stablediffusion.go index 5bfe239af1..a6fdc6ab54 100644 --- a/test/kubernetes/benchmarks/stablediffusion.go +++ b/test/kubernetes/benchmarks/stablediffusion.go @@ -34,7 +34,7 @@ import ( const ( // Container image for Stable Diffusion XL. - stableDiffusionImage = k8s.ImageRepoPrefix + "gpu/stable-diffusion-xl" + stableDiffusionImage = k8s.ImageRepoPrefix + "gpu/stable-diffusion-xl:latest" ) // kubernetesPodRunner implements `stablediffusion.ContainerRunner`. @@ -171,7 +171,7 @@ func RunStableDiffusionXL(ctx context.Context, t *testing.T, k8sCtx k8sctx.Kuber t.Skipf("refiner failed in previous benchmark; skipping benchmark with refiner") } } - testCtx, testCancel := context.WithTimeout(ctx, 15*time.Minute) + testCtx, testCancel := context.WithTimeout(ctx, 50*time.Minute) defer testCancel() prompt := &stablediffusion.XLPrompt{ Query: test.query, diff --git a/test/kubernetes/benchmarks/wordpress.go b/test/kubernetes/benchmarks/wordpress.go index 781514300d..aacde41aa4 100644 --- a/test/kubernetes/benchmarks/wordpress.go +++ b/test/kubernetes/benchmarks/wordpress.go @@ -52,8 +52,8 @@ const ( ) var ( - threads = []int{1, 8, 64, 1000} - targetQPS = []int{1, 8, 64, httpbench.InfiniteQPS} + threads = []int{1, 8, 1000} + targetQPS = []int{1, 64, httpbench.InfiniteQPS} wantPercentiles = []int{50, 95, 99} ) diff --git a/test/kubernetes/testcluster/BUILD b/test/kubernetes/testcluster/BUILD index 3f38de140f..200c20a6c2 100644 --- a/test/kubernetes/testcluster/BUILD +++ b/test/kubernetes/testcluster/BUILD @@ -8,6 +8,7 @@ package( go_library( name = "testcluster", srcs = [ + "client.go", "objects.go", "testcluster.go", ], @@ -16,6 +17,7 @@ go_library( ], deps = [ "//pkg/log", + "//pkg/rand", "//pkg/sync", "//test/kubernetes:test_range_config_go_proto", "@io_k8s_api//apps/v1:go_default_library", diff --git a/test/kubernetes/testcluster/client.go b/test/kubernetes/testcluster/client.go new file mode 100644 index 0000000000..c6ec491139 --- /dev/null +++ b/test/kubernetes/testcluster/client.go @@ -0,0 +1,190 @@ +// Copyright 2024 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package testcluster + +import ( + "context" + "encoding/hex" + "fmt" + "io" + "time" + + "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/rand" + "k8s.io/client-go/kubernetes" +) + +// KubernetesReq is a function that performs a request with a Kubernetes +// client. +type KubernetesReq func(context.Context, kubernetes.Interface) error + +// KubernetesClient is an interface that wraps Kubernetes requests. +type KubernetesClient interface { + // Do performs a request with a Kubernetes client. + Do(context.Context, KubernetesReq) error +} + +// simpleClient is a KubernetesClient that wraps a simple Kubernetes client. +// The `Do` function simply calls the function with the given `client`. +type simpleClient struct { + client kubernetes.Interface +} + +// Do implements `KubernetesClient.Do`. +func (sc *simpleClient) Do(ctx context.Context, fn KubernetesReq) error { + return fn(ctx, sc.client) +} + +// retryableClient is a KubernetesClient that can retry requests by creating +// *new instances* of Kubernetes clients, rather than just retrying requests. +type retryableClient struct { + // client is a Kubernetes client factory, used to create new instances of + // Kubernetes clients and to determine whether a request should be retried. + client UnstableClient + + // clientCh is a channel used to share Kubernetes clients between multiple + // requests. + clientCh chan kubernetes.Interface +} + +// UnstableClient is a Kubernetes client factory that can create new instances +// of Kubernetes clients and determine whether a request should be retried. +type UnstableClient interface { + // Client creates a new instance of a Kubernetes client. + // This function may also block (in a context-respecting manner) + // in order to implement backoff between Kubernetes client creation + // attempts. + Client(context.Context) (kubernetes.Interface, error) + + // RetryError returns whether the given error should be retried. + // numAttempt is the number of attempts made so far. + // This function may also block (in a context-respecting manner) + // in order to implement backoff between request retries. + RetryError(ctx context.Context, err error, numAttempt int) bool +} + +// NewRetryableClient creates a new retryable Kubernetes client. +// It takes an `UnstableClient` as input, which is used to create new +// instances of Kubernetes clients as needed, and to determine whether +// a request should be retried. +// This can be safely used concurrently, in which case additional +// Kubernetes clients will be created as needed, and reused when +// possible (but never garbage-collected, unless they start emitting +// retriable errors). +// It will immediately create an initial Kubernetes client from the +// `UnstableClient` as the initial client to use. +func NewRetryableClient(ctx context.Context, client UnstableClient) (KubernetesClient, error) { + initialClient, err := client.Client(ctx) + if err != nil { + return nil, fmt.Errorf("cannot get initial client: %w", err) + } + clientCh := make(chan kubernetes.Interface, 128) + clientCh <- initialClient + return &retryableClient{client: client, clientCh: clientCh}, nil +} + +// getClient returns a Kubernetes client. +// It will either return the client from the clientCh, or create a new one +// if none are available. 
+func (rc *retryableClient) getClient(ctx context.Context) (kubernetes.Interface, error) {
+	select {
+	case client := <-rc.clientCh:
+		return client, nil
+	default:
+		client, err := rc.client.Client(ctx)
+		if err != nil {
+			return nil, fmt.Errorf("cannot get client: %w", err)
+		}
+		return client, nil
+	}
+}
+
+// putClient puts a Kubernetes client back into the `clientCh`.
+func (rc *retryableClient) putClient(client kubernetes.Interface) {
+	select {
+	case rc.clientCh <- client:
+	default:
+		// If full, just spawn a goroutine to put it back when possible.
+		go func() { rc.clientCh <- client }()
+	}
+}
+
+// Do implements `KubernetesClient.Do`.
+// It retries the request if the error is retryable.
+func (rc *retryableClient) Do(ctx context.Context, fn KubernetesReq) error {
+	client, err := rc.getClient(ctx)
+	if err != nil {
+		return fmt.Errorf("cannot get client: %w", err)
+	}
+	if err = fn(ctx, client); err == nil || !rc.client.RetryError(ctx, err, 0) { // Happy path.
+		rc.putClient(client)
+		return err
+	}
+
+	// We generate a random ID here to distinguish between multiple retryable
+	// operations in the logs.
+	var operationIDBytes [8]byte
+	if _, err := io.ReadFull(rand.Reader, operationIDBytes[:]); err != nil {
+		return fmt.Errorf("cannot read random bytes: %w", err)
+	}
+	operationID := hex.EncodeToString(operationIDBytes[:])
+
+	logger := log.BasicRateLimitedLogger(30 * time.Second)
+	deadline, hasDeadline := ctx.Deadline()
+	if hasDeadline {
+		logger.Infof("Retryable operation [%s] @ %s failed on initial attempt with retryable error (%v); retrying until %v...", operationID, time.Now().Format(time.TimeOnly), err, deadline)
+	} else {
+		logger.Infof("Retryable operation [%s] @ %s failed on initial attempt with retryable error (%v); retrying...", operationID, time.Now().Format(time.TimeOnly), err)
+	}
+	lastErr := err
+	numAttempt := 1
+	for ctx.Err() == nil {
+		numAttempt++
+		client, err := rc.getClient(ctx)
+		if err != nil {
+			return fmt.Errorf("cannot get client: %w", err)
+		}
+		if err = fn(ctx, client); err == nil || !rc.client.RetryError(ctx, err, numAttempt) {
+			// We don't use `logger` here because we want to make sure it is logged
+			// so that the logs reflect that the operation succeeded upon a retry.
+			// Otherwise the logs can be confusing because it may seem that we are
+			// still in the retry loop.
+			if err == nil {
+				log.Infof("Retryable operation [%s] @ %s succeeded on attempt %d.", operationID, time.Now().Format(time.TimeOnly), numAttempt)
+			} else {
+				log.Infof("Retryable operation [%s] @ %s attempt %d returned non-retryable error: %v.", operationID, time.Now().Format(time.TimeOnly), numAttempt, err)
+			}
+			rc.putClient(client)
+			return err
+		}
+		logger.Infof("Retryable operation [%s] @ %s failed on attempt %d (retryable error: %v); will retry again...", operationID, time.Now().Format(time.TimeOnly), numAttempt, err)
+		lastErr = err
+	}
+	log.Infof("Retryable operation [%s] @ %s failed after %d attempts with retryable error (%v) but context was cancelled (%v); bailing out.", operationID, time.Now().Format(time.TimeOnly), numAttempt, lastErr, ctx.Err())
+	return lastErr
+}
+
+// request wraps a function that takes a KubernetesClient and returns a value of
+// type T. It is useful for functions that return more than just an error,
+// e.g. lookup functions that return pod info or other Kubernetes resources. 
+func request[T any](ctx context.Context, client KubernetesClient, fn func(context.Context, kubernetes.Interface) (T, error)) (T, error) { + var result T + err := client.Do(ctx, func(ctx context.Context, client kubernetes.Interface) error { + var err error + result, err = fn(ctx, client) + return err + }) + return result, err +} diff --git a/test/kubernetes/testcluster/objects.go b/test/kubernetes/testcluster/objects.go index 8ebb673d30..07a4e15f43 100644 --- a/test/kubernetes/testcluster/objects.go +++ b/test/kubernetes/testcluster/objects.go @@ -19,7 +19,6 @@ import ( "errors" "fmt" "reflect" - "strconv" cspb "google.golang.org/genproto/googleapis/container/v1" "google.golang.org/protobuf/proto" @@ -181,16 +180,14 @@ type RuntimeType string // List of known runtime types. const ( - RuntimeTypeGVisor = RuntimeType("gvisor") - RuntimeTypeUnsandboxed = RuntimeType("runc") - RuntimeTypeGVisorNvidia = RuntimeType("gvisor-nvidia") - RuntimeTypeGVisorTPU = RuntimeType("gvisor-tpu") - RuntimeTypeUnsandboxedNvidia = RuntimeType("runc-nvidia") - RuntimeTypeUnsandboxedTPU = RuntimeType("runc-tpu") + RuntimeTypeGVisor = RuntimeType("gvisor") + RuntimeTypeUnsandboxed = RuntimeType("runc") + RuntimeTypeGVisorTPU = RuntimeType("gvisor-tpu") + RuntimeTypeUnsandboxedTPU = RuntimeType("runc-tpu") ) // ApplyNodepool modifies the nodepool to configure it to use the runtime. -func (t RuntimeType) ApplyNodepool(nodepool *cspb.NodePool, accelType AcceleratorType, accelShape string, accelRes string) { +func (t RuntimeType) ApplyNodepool(nodepool *cspb.NodePool) { if nodepool.GetConfig().GetLabels() == nil { nodepool.GetConfig().Labels = map[string]string{} } @@ -204,81 +201,27 @@ func (t RuntimeType) ApplyNodepool(nodepool *cspb.NodePool, accelType Accelerato case RuntimeTypeUnsandboxed: nodepool.GetConfig().Labels[NodepoolRuntimeKey] = string(RuntimeTypeUnsandboxed) // Do nothing. 
- case RuntimeTypeGVisorNvidia: - nodepool.Config.SandboxConfig = &cspb.SandboxConfig{ - Type: cspb.SandboxConfig_GVISOR, - } - accelCount, err := strconv.Atoi(accelShape) - if err != nil { - panic(fmt.Sprintf("GPU count must be a valid number, got %v", accelShape)) - } - if accelCount == 0 { - panic("GPU count needs to be >=1") - } - nodepool.Config.MachineType = DefaultNvidiaMachineType - nodepool.Config.Accelerators = []*cspb.AcceleratorConfig{ - { - AcceleratorType: string(accelType), - AcceleratorCount: int64(accelCount), - }, - } - nodepool.Config.Labels[NodepoolRuntimeKey] = string(RuntimeTypeGVisorNvidia) - nodepool.Config.Labels[NodepoolNumAcceleratorsKey] = strconv.Itoa(accelCount) case RuntimeTypeGVisorTPU: - nodepool.Config.MachineType = TPUAcceleratorMachineTypeMap[accelType] - if err := setNodePlacementPolicyCompact(nodepool, accelShape); err != nil { - panic(fmt.Sprintf("failed to set node placement policy: %v", err)) - } nodepool.Config.Labels[gvisorNodepoolKey] = gvisorRuntimeClass nodepool.Config.Labels[NodepoolRuntimeKey] = string(RuntimeTypeGVisorTPU) - nodepool.Config.Labels[NodepoolTPUTopologyKey] = accelShape nodepool.Config.Taints = append(nodepool.Config.Taints, &cspb.NodeTaint{ Key: gvisorNodepoolKey, Value: gvisorRuntimeClass, Effect: cspb.NodeTaint_NO_SCHEDULE, }) - case RuntimeTypeUnsandboxedNvidia: - accelCount, err := strconv.Atoi(accelShape) - if err != nil { - panic(fmt.Sprintf("GPU count must be a valid number, got %v", accelShape)) - } - if accelCount == 0 { - panic("GPU count needs to be >=1") - } - nodepool.Config.MachineType = DefaultNvidiaMachineType - nodepool.Config.Accelerators = []*cspb.AcceleratorConfig{ - { - AcceleratorType: string(accelType), - AcceleratorCount: int64(accelCount), - }, - } - nodepool.Config.Labels[NodepoolRuntimeKey] = string(RuntimeTypeUnsandboxedNvidia) - nodepool.Config.Labels[NodepoolNumAcceleratorsKey] = strconv.Itoa(accelCount) case RuntimeTypeUnsandboxedTPU: - nodepool.Config.MachineType = TPUAcceleratorMachineTypeMap[accelType] - if err := setNodePlacementPolicyCompact(nodepool, accelShape); err != nil { - panic(fmt.Sprintf("failed to set node placement policy: %v", err)) - } nodepool.Config.Labels[NodepoolRuntimeKey] = string(RuntimeTypeUnsandboxedTPU) - nodepool.Config.Labels[NodepoolTPUTopologyKey] = accelShape default: panic(fmt.Sprintf("unsupported runtime %q", t)) } - if accelRes != "" { - nodepool.Config.ReservationAffinity = &cspb.ReservationAffinity{ - ConsumeReservationType: cspb.ReservationAffinity_SPECIFIC_RESERVATION, - Key: "compute.googleapis.com/reservation-name", - Values: []string{accelRes}, - } - } } -// setNodePlacementPolicyCompact sets the node placement policy to COMPACT +// SetNodePlacementPolicyCompact sets the node placement policy to COMPACT // and with the given TPU topology. // This is done by reflection because the NodePool_PlacementPolicy proto // message isn't available in the latest exported version of the genproto API. // This is only used for TPU nodepools so not critical for most benchmarks. 
-func setNodePlacementPolicyCompact(nodepool *cspb.NodePool, tpuTopology string) error { +func SetNodePlacementPolicyCompact(nodepool *cspb.NodePool, tpuTopology string) error { placementPolicyField := reflect.ValueOf(nodepool).Elem().FieldByName("PlacementPolicy") if !placementPolicyField.IsValid() { return errors.New("nodepool does not have a PlacementPolicy field") @@ -305,7 +248,15 @@ func (t RuntimeType) ApplyPodSpec(podSpec *v13.PodSpec) { case RuntimeTypeGVisor: podSpec.RuntimeClassName = proto.String(gvisorRuntimeClass) podSpec.NodeSelector[NodepoolRuntimeKey] = string(RuntimeTypeGVisor) + podSpec.Tolerations = append(podSpec.Tolerations, v13.Toleration{ + Key: "nvidia.com/gpu", + Operator: v13.TolerationOpExists, + }) case RuntimeTypeUnsandboxed: + podSpec.Tolerations = append(podSpec.Tolerations, v13.Toleration{ + Key: "nvidia.com/gpu", + Operator: v13.TolerationOpExists, + }) // Allow the pod to schedule on gVisor nodes as well. // This enables the use of `--test-nodepool-runtime=runc` to run // unsandboxed benchmarks on gVisor test clusters. @@ -315,13 +266,6 @@ func (t RuntimeType) ApplyPodSpec(podSpec *v13.PodSpec) { Operator: v13.TolerationOpEqual, Value: gvisorRuntimeClass, }) - case RuntimeTypeGVisorNvidia: - podSpec.RuntimeClassName = proto.String(gvisorRuntimeClass) - podSpec.NodeSelector[NodepoolRuntimeKey] = string(RuntimeTypeGVisorNvidia) - podSpec.Tolerations = append(podSpec.Tolerations, v13.Toleration{ - Key: "nvidia.com/gpu", - Operator: v13.TolerationOpExists, - }) case RuntimeTypeGVisorTPU: podSpec.RuntimeClassName = proto.String(gvisorRuntimeClass) podSpec.NodeSelector[NodepoolRuntimeKey] = string(RuntimeTypeGVisorTPU) @@ -329,20 +273,6 @@ func (t RuntimeType) ApplyPodSpec(podSpec *v13.PodSpec) { Key: "google.com/tpu", Operator: v13.TolerationOpExists, }) - case RuntimeTypeUnsandboxedNvidia: - podSpec.Tolerations = append(podSpec.Tolerations, v13.Toleration{ - Key: "nvidia.com/gpu", - Operator: v13.TolerationOpExists, - }) - // Allow the pod to schedule on gVisor nodes as well. - // This enables the use of `--test-nodepool-runtime=runc-nvidia` to run - // unsandboxed benchmarks on gVisor test clusters. - podSpec.Tolerations = append(podSpec.Tolerations, v13.Toleration{ - Effect: v13.TaintEffectNoSchedule, - Key: gvisorNodepoolKey, - Operator: v13.TolerationOpEqual, - Value: gvisorRuntimeClass, - }) case RuntimeTypeUnsandboxedTPU: podSpec.Tolerations = append(podSpec.Tolerations, v13.Toleration{ Key: "google.com/tpu", diff --git a/test/kubernetes/testcluster/testcluster.go b/test/kubernetes/testcluster/testcluster.go index e476b1a109..123a97b107 100644 --- a/test/kubernetes/testcluster/testcluster.go +++ b/test/kubernetes/testcluster/testcluster.go @@ -140,7 +140,8 @@ const ( // TestCluster wraps clusters with their individual ClientSets so that helper methods can be called. type TestCluster struct { clusterName string - client kubernetes.Interface + + client KubernetesClient // testNodepoolRuntimeOverride, if set, overrides the runtime used for pods // running on the test nodepool. If unset, the test nodepool's default @@ -209,6 +210,12 @@ func NewTestClusterFromProto(ctx context.Context, cluster *testpb.Cluster) (*Tes // NewTestClusterFromClient returns a new TestCluster client with a given client. 
func NewTestClusterFromClient(clusterName string, client kubernetes.Interface) *TestCluster { + return NewTestClusterFromKubernetesClient(clusterName, &simpleClient{client}) +} + +// NewTestClusterFromKubernetesClient returns a new TestCluster client with a +// given KubernetesClient. +func NewTestClusterFromKubernetesClient(clusterName string, client KubernetesClient) *TestCluster { return &TestCluster{ clusterName: clusterName, client: client, @@ -248,17 +255,24 @@ func (t *TestCluster) OverrideTestNodepoolRuntime(testRuntime RuntimeType) { // createNamespace creates a namespace. func (t *TestCluster) createNamespace(ctx context.Context, namespace *v13.Namespace) (*v13.Namespace, error) { - return t.client.CoreV1().Namespaces().Create(ctx, namespace, v1.CreateOptions{}) + return request(ctx, t.client, func(ctx context.Context, client kubernetes.Interface) (*v13.Namespace, error) { + return client.CoreV1().Namespaces().Create(ctx, namespace, v1.CreateOptions{}) + }) } // getNamespace returns the given namespace in the cluster if it exists. func (t *TestCluster) getNamespace(ctx context.Context, namespaceName string) (*v13.Namespace, error) { - return t.client.CoreV1().Namespaces().Get(ctx, namespaceName, v1.GetOptions{}) + return request(ctx, t.client, func(ctx context.Context, client kubernetes.Interface) (*v13.Namespace, error) { + return client.CoreV1().Namespaces().Get(ctx, namespaceName, v1.GetOptions{}) + }) } // deleteNamespace is a helper method to delete a namespace. func (t *TestCluster) deleteNamespace(ctx context.Context, namespaceName string) error { - if err := t.client.CoreV1().Namespaces().Delete(ctx, namespaceName, v1.DeleteOptions{}); err != nil { + err := t.client.Do(ctx, func(ctx context.Context, client kubernetes.Interface) error { + return client.CoreV1().Namespaces().Delete(ctx, namespaceName, v1.DeleteOptions{}) + }) + if err != nil { return err } // Wait for the namespace to disappear or for the context to expire. @@ -282,7 +296,9 @@ func (t *TestCluster) getNodePool(ctx context.Context, nodepoolType NodePoolType t.nodepoolsMu.Lock() defer t.nodepoolsMu.Unlock() if t.nodepools == nil { - nodes, err := t.client.CoreV1().Nodes().List(ctx, v1.ListOptions{}) + nodes, err := request(ctx, t.client, func(ctx context.Context, client kubernetes.Interface) (*v13.NodeList, error) { + return client.CoreV1().Nodes().List(ctx, v1.ListOptions{}) + }) if err != nil { return nil, fmt.Errorf("cannot list nodes: %w", err) } @@ -363,7 +379,7 @@ func (t *TestCluster) HasGVisorTestRuntime(ctx context.Context) (bool, error) { if err != nil { return false, err } - return testNodePool.runtime == RuntimeTypeGVisor || testNodePool.runtime == RuntimeTypeGVisorNvidia, nil + return testNodePool.runtime == RuntimeTypeGVisor || testNodePool.runtime == RuntimeTypeGVisorTPU, nil } // CreatePod is a helper to create a pod. @@ -371,22 +387,31 @@ func (t *TestCluster) CreatePod(ctx context.Context, pod *v13.Pod) (*v13.Pod, er if pod.GetObjectMeta().GetNamespace() == "" { pod.SetNamespace(NamespaceDefault) } - return t.client.CoreV1().Pods(pod.GetNamespace()).Create(ctx, pod, v1.CreateOptions{}) + return request(ctx, t.client, func(ctx context.Context, client kubernetes.Interface) (*v13.Pod, error) { + return client.CoreV1().Pods(pod.GetNamespace()).Create(ctx, pod, v1.CreateOptions{}) + }) } // GetPod is a helper method to Get a pod's metadata. 
func (t *TestCluster) GetPod(ctx context.Context, pod *v13.Pod) (*v13.Pod, error) { - return t.client.CoreV1().Pods(pod.GetNamespace()).Get(ctx, pod.GetName(), v1.GetOptions{}) + return request(ctx, t.client, func(ctx context.Context, client kubernetes.Interface) (*v13.Pod, error) { + return client.CoreV1().Pods(pod.GetNamespace()).Get(ctx, pod.GetName(), v1.GetOptions{}) + }) } // ListPods is a helper method to List pods in a cluster. func (t *TestCluster) ListPods(ctx context.Context, namespace string) (*v13.PodList, error) { - return t.client.CoreV1().Pods(namespace).List(ctx, v1.ListOptions{}) + return request(ctx, t.client, func(ctx context.Context, client kubernetes.Interface) (*v13.PodList, error) { + return client.CoreV1().Pods(namespace).List(ctx, v1.ListOptions{}) + }) } // DeletePod is a helper method to delete a pod. func (t *TestCluster) DeletePod(ctx context.Context, pod *v13.Pod) error { - if err := t.client.CoreV1().Pods(pod.GetNamespace()).Delete(ctx, pod.GetName(), v1.DeleteOptions{}); err != nil { + err := t.client.Do(ctx, func(ctx context.Context, client kubernetes.Interface) error { + return client.CoreV1().Pods(pod.GetNamespace()).Delete(ctx, pod.GetName(), v1.DeleteOptions{}) + }) + if err != nil { return err } // Wait for the pod to disappear or for the context to expire. @@ -406,7 +431,9 @@ func (t *TestCluster) DeletePod(ctx context.Context, pod *v13.Pod) error { // GetLogReader gets an io.ReadCloser from which logs can be read. It is the caller's // responsibility to close it. func (t *TestCluster) GetLogReader(ctx context.Context, pod *v13.Pod, opts v13.PodLogOptions) (io.ReadCloser, error) { - return t.client.CoreV1().Pods(pod.GetNamespace()).GetLogs(pod.GetName(), &opts).Stream(ctx) + return request(ctx, t.client, func(ctx context.Context, client kubernetes.Interface) (io.ReadCloser, error) { + return client.CoreV1().Pods(pod.GetNamespace()).GetLogs(pod.GetName(), &opts).Stream(ctx) + }) } // ReadPodLogs reads logs from a pod. @@ -602,22 +629,36 @@ func (t *TestCluster) ContainerDurationSecondsByName(ctx context.Context, pod *v // CreateService is a helper method to create a service in a cluster. func (t *TestCluster) CreateService(ctx context.Context, service *v13.Service) (*v13.Service, error) { - return t.client.CoreV1().Services(service.GetNamespace()).Create(ctx, service, v1.CreateOptions{}) + return request(ctx, t.client, func(ctx context.Context, client kubernetes.Interface) (*v13.Service, error) { + return client.CoreV1().Services(service.GetNamespace()).Create(ctx, service, v1.CreateOptions{}) + }) +} + +// GetService is a helper method to get a service in a cluster. +func (t *TestCluster) GetService(ctx context.Context, service *v13.Service) (*v13.Service, error) { + return request(ctx, t.client, func(ctx context.Context, client kubernetes.Interface) (*v13.Service, error) { + return client.CoreV1().Services(service.GetNamespace()).Get(ctx, service.GetName(), v1.GetOptions{}) + }) } // ListServices is a helper method to List services in a cluster. func (t *TestCluster) ListServices(ctx context.Context, namespace string) (*v13.ServiceList, error) { - return t.client.CoreV1().Services(namespace).List(ctx, v1.ListOptions{}) + return request(ctx, t.client, func(ctx context.Context, client kubernetes.Interface) (*v13.ServiceList, error) { + return client.CoreV1().Services(namespace).List(ctx, v1.ListOptions{}) + }) } // DeleteService is a helper to delete a given service. 
func (t *TestCluster) DeleteService(ctx context.Context, service *v13.Service) error { - if err := t.client.CoreV1().Services(service.GetNamespace()).Delete(ctx, service.GetName(), v1.DeleteOptions{}); err != nil { + err := t.client.Do(ctx, func(ctx context.Context, client kubernetes.Interface) error { + return client.CoreV1().Services(service.GetNamespace()).Delete(ctx, service.GetName(), v1.DeleteOptions{}) + }) + if err != nil { return err } // Wait for the service to disappear or for the context to expire. for ctx.Err() == nil { - if _, err := t.client.CoreV1().Services(service.GetNamespace()).Get(ctx, service.GetName(), v1.GetOptions{}); err != nil { + if _, err := t.GetService(ctx, service); err != nil { return nil } select { @@ -639,7 +680,7 @@ func (t *TestCluster) WaitForServiceReady(ctx context.Context, service *v13.Serv case <-ctx.Done(): return fmt.Errorf("context expired waiting for service %q: %w (last: %v)", service.GetName(), ctx.Err(), lastService) case <-pollCh.C: - s, err := t.client.CoreV1().Services(service.GetNamespace()).Get(ctx, service.GetName(), v1.GetOptions{}) + s, err := t.GetService(ctx, service) if err != nil { return fmt.Errorf("cannot look up service %q: %w", service.GetName(), err) } @@ -662,12 +703,16 @@ func (t *TestCluster) CreatePersistentVolume(ctx context.Context, volume *v13.Pe if volume.GetObjectMeta().GetNamespace() == "" { volume.SetNamespace(NamespaceDefault) } - return t.client.CoreV1().PersistentVolumeClaims(volume.GetNamespace()).Create(ctx, volume, v1.CreateOptions{}) + return request(ctx, t.client, func(ctx context.Context, client kubernetes.Interface) (*v13.PersistentVolumeClaim, error) { + return client.CoreV1().PersistentVolumeClaims(volume.GetNamespace()).Create(ctx, volume, v1.CreateOptions{}) + }) } // DeletePersistentVolume deletes a persistent volume. func (t *TestCluster) DeletePersistentVolume(ctx context.Context, volume *v13.PersistentVolumeClaim) error { - return t.client.CoreV1().PersistentVolumeClaims(volume.GetNamespace()).Delete(ctx, volume.GetName(), v1.DeleteOptions{}) + return t.client.Do(ctx, func(ctx context.Context, client kubernetes.Interface) error { + return client.CoreV1().PersistentVolumeClaims(volume.GetNamespace()).Delete(ctx, volume.GetName(), v1.DeleteOptions{}) + }) } // CreateDaemonset creates a daemonset with default options. @@ -675,12 +720,23 @@ func (t *TestCluster) CreateDaemonset(ctx context.Context, ds *appsv1.DaemonSet) if ds.GetObjectMeta().GetNamespace() == "" { ds.SetNamespace(NamespaceDefault) } - return t.client.AppsV1().DaemonSets(ds.GetNamespace()).Create(ctx, ds, v1.CreateOptions{}) + return request(ctx, t.client, func(ctx context.Context, client kubernetes.Interface) (*appsv1.DaemonSet, error) { + return client.AppsV1().DaemonSets(ds.GetNamespace()).Create(ctx, ds, v1.CreateOptions{}) + }) +} + +// GetDaemonset gets a daemonset. +func (t *TestCluster) GetDaemonset(ctx context.Context, ds *appsv1.DaemonSet) (*appsv1.DaemonSet, error) { + return request(ctx, t.client, func(ctx context.Context, client kubernetes.Interface) (*appsv1.DaemonSet, error) { + return client.AppsV1().DaemonSets(ds.GetNamespace()).Get(ctx, ds.GetName(), v1.GetOptions{}) + }) } // DeleteDaemonset deletes a daemonset from this cluster. 
func (t *TestCluster) DeleteDaemonset(ctx context.Context, ds *appsv1.DaemonSet) error { - return t.client.AppsV1().DaemonSets(ds.GetNamespace()).Delete(ctx, ds.GetName(), v1.DeleteOptions{}) + return t.client.Do(ctx, func(ctx context.Context, client kubernetes.Interface) error { + return client.AppsV1().DaemonSets(ds.GetNamespace()).Delete(ctx, ds.GetName(), v1.DeleteOptions{}) + }) } // GetPodsInDaemonSet returns the list of pods of the given DaemonSet. @@ -689,7 +745,9 @@ func (t *TestCluster) GetPodsInDaemonSet(ctx context.Context, ds *appsv1.DaemonS if appLabel, found := ds.Spec.Template.Labels[k8sApp]; found { listOptions.LabelSelector = fmt.Sprintf("%s=%s", k8sApp, appLabel) } - pods, err := t.client.CoreV1().Pods(ds.ObjectMeta.Namespace).List(ctx, listOptions) + pods, err := request(ctx, t.client, func(ctx context.Context, client kubernetes.Interface) (*v13.PodList, error) { + return client.CoreV1().Pods(ds.ObjectMeta.Namespace).List(ctx, listOptions) + }) if err != nil { return nil, err } @@ -709,7 +767,7 @@ func (t *TestCluster) WaitForDaemonset(ctx context.Context, ds *appsv1.DaemonSet defer pollCh.Stop() // Poll-based loop to wait for the DaemonSet to be ready. for { - d, err := t.client.AppsV1().DaemonSets(ds.GetNamespace()).Get(ctx, ds.GetName(), v1.GetOptions{}) + d, err := t.GetDaemonset(ctx, ds) if err != nil { return fmt.Errorf("failed to get daemonset %q: %v", ds.GetName(), err) } @@ -778,7 +836,7 @@ func (t *TestCluster) StreamDaemonSetLogs(ctx context.Context, ds *appsv1.Daemon if _, seen := nodesSeen[pod.Spec.NodeName]; seen { continue // Node already seen. } - logReader, err := t.client.CoreV1().Pods(pod.GetNamespace()).GetLogs(pod.GetName(), &opts).Stream(ctx) + logReader, err := t.GetLogReader(ctx, &pod, opts) if err != nil { // This can happen if the container hasn't run yet, for example // because other init containers that run earlier are still executing. @@ -813,7 +871,7 @@ Outer: } break Outer case <-timeTicker.C: - d, err := t.client.AppsV1().DaemonSets(ds.GetNamespace()).Get(ctx, ds.GetName(), v1.GetOptions{}) + d, err := t.GetDaemonset(ctx, ds) if err != nil { loopError = fmt.Errorf("failed to get DaemonSet: %v", err) break Outer
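
For context on how the new client plumbing added in client.go fits together, here is a minimal usage sketch (not part of this change): it assumes a hypothetical kubeconfigUnstableClient implementation of UnstableClient, an arbitrary retry budget and backoff, and a placeholder kubeconfig path.

// Usage sketch only; the UnstableClient implementation, retry budget,
// backoff policy, and kubeconfig path below are hypothetical.
package example

import (
	"context"
	"time"

	"gvisor.dev/gvisor/test/kubernetes/testcluster"
	"k8s.io/client-go/kubernetes"
	"k8s.io/client-go/tools/clientcmd"
)

// kubeconfigUnstableClient builds a fresh Kubernetes client from a kubeconfig
// file on demand and allows a bounded number of retries per request.
type kubeconfigUnstableClient struct {
	kubeconfigPath string
}

// Client implements testcluster.UnstableClient.Client by constructing a new
// client-go clientset from the kubeconfig file.
func (u *kubeconfigUnstableClient) Client(ctx context.Context) (kubernetes.Interface, error) {
	cfg, err := clientcmd.BuildConfigFromFlags("", u.kubeconfigPath)
	if err != nil {
		return nil, err
	}
	return kubernetes.NewForConfig(cfg)
}

// RetryError implements testcluster.UnstableClient.RetryError. This sketch
// retries every error up to 5 attempts with linear backoff; a real
// implementation would typically inspect err for transient conditions.
func (u *kubeconfigUnstableClient) RetryError(ctx context.Context, err error, numAttempt int) bool {
	if numAttempt >= 5 {
		return false
	}
	select {
	case <-ctx.Done():
		return false
	case <-time.After(time.Duration(numAttempt+1) * time.Second):
		return true
	}
}

// listDefaultPods wires the retryable client into a TestCluster and performs
// a request that is transparently retried on errors deemed retryable.
func listDefaultPods(ctx context.Context, kubeconfigPath, clusterName string) error {
	k8sClient, err := testcluster.NewRetryableClient(ctx, &kubeconfigUnstableClient{kubeconfigPath: kubeconfigPath})
	if err != nil {
		return err
	}
	cluster := testcluster.NewTestClusterFromKubernetesClient(clusterName, k8sClient)
	_, err = cluster.ListPods(ctx, "default")
	return err
}

Because retryableClient pools clients in a channel and only replaces them when UnstableClient.RetryError reports a failed attempt, concurrent callers reuse healthy clients while broken ones are discarded and recreated as needed.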