From 8d8b99cc9a4429c187ca14786496d90d4cc784eb Mon Sep 17 00:00:00 2001 From: "Anas H. Sulaiman" Date: Tue, 13 Jun 2023 11:30:03 -0400 Subject: [PATCH] casng, byte streaming implementation (#462) Byte streaming is used to upload large files that do not fit in a batching request. This implementation avoids using the chunker package by leveraging the streaming nature of IO buffering. Furthermore, error handling is more robust in this implementation since all errors are handled while ensuring a graceful cancelation when necessary. --- go/pkg/casng/BUILD.bazel | 1 + go/pkg/casng/batching.go | 198 +++++++++++++- go/pkg/casng/batching_write_bytes_test.go | 301 ++++++++++++++++++++++ go/pkg/casng/throttler.go | 39 +++ go/pkg/casng/uploader.go | 16 +- 5 files changed, 544 insertions(+), 11 deletions(-) create mode 100644 go/pkg/casng/batching_write_bytes_test.go create mode 100644 go/pkg/casng/throttler.go diff --git a/go/pkg/casng/BUILD.bazel b/go/pkg/casng/BUILD.bazel index d2771f469..ea6f0e2fd 100644 --- a/go/pkg/casng/BUILD.bazel +++ b/go/pkg/casng/BUILD.bazel @@ -5,6 +5,7 @@ go_library( srcs = [ "batching.go", "config.go", + "throttler.go", "uploader.go", ], importpath = "github.com/bazelbuild/remote-apis-sdks/go/pkg/casng", diff --git a/go/pkg/casng/batching.go b/go/pkg/casng/batching.go index 9e155c915..9b7736a7a 100644 --- a/go/pkg/casng/batching.go +++ b/go/pkg/casng/batching.go @@ -2,13 +2,19 @@ package casng import ( "context" + "fmt" "io" + "sync" + "time" "github.com/bazelbuild/remote-apis-sdks/go/pkg/contextmd" "github.com/bazelbuild/remote-apis-sdks/go/pkg/digest" "github.com/bazelbuild/remote-apis-sdks/go/pkg/errors" + "github.com/bazelbuild/remote-apis-sdks/go/pkg/retry" repb "github.com/bazelbuild/remote-apis/build/bazel/remote/execution/v2" log "github.com/golang/glog" + "github.com/klauspost/compress/zstd" + bspb "google.golang.org/genproto/googleapis/bytestream" ) // MissingBlobs queries the CAS for digests and returns a slice of the missing ones. @@ -57,7 +63,7 @@ func (u *BatchingUploader) MissingBlobs(ctx context.Context, digests []digest.Di req := &repb.FindMissingBlobsRequest{InstanceName: u.instanceName} for _, batch := range batches { req.BlobDigests = batch - errRes = u.withRetry(ctx, u.queryRPCCfg.RetryPredicate, u.queryRPCCfg.RetryPolicy, func() error { + errRes = retry.WithPolicy(ctx, u.queryRPCCfg.RetryPredicate, u.queryRPCCfg.RetryPolicy, func() error { ctx, ctxCancel := context.WithTimeout(ctx, u.queryRPCCfg.Timeout) defer ctxCancel() res, errRes = u.cas.FindMissingBlobs(ctx, req) @@ -89,15 +95,195 @@ func (u *BatchingUploader) MissingBlobs(ctx context.Context, digests []digest.Di // ctx is used to make and cancel remote calls. // This method does not use the uploader's context which means it is safe to call even after that context is cancelled. // -// size is used to toggle compression as well as report some stats. It must reflect the actual number of bytes r has to give. +// Compression is turned on based the resource name. +// size is used to report some stats. It must reflect the actual number of bytes r has to give. // The server is notified to finalize the resource name and subsequent writes may not succeed. // The errors returned are either from the context, ErrGRPC, ErrIO, or ErrCompression. More errors may be wrapped inside. // If an error was returned, the returned stats may indicate that all the bytes were sent, but that does not guarantee that the server committed all of them. 
-func (u *BatchingUploader) WriteBytes(ctx context.Context, name string, r io.Reader, size int64, offset int64) (Stats, error) { - panic("not yet implemented") +func (u *BatchingUploader) WriteBytes(ctx context.Context, name string, r io.Reader, size, offset int64) (Stats, error) { + if !u.streamThrottle.acquire(ctx) { + return Stats{}, ctx.Err() + } + defer u.streamThrottle.release() + return u.writeBytes(ctx, name, r, size, offset, true) } // WriteBytesPartial is the same as WriteBytes, but does not notify the server to finalize the resource name. -func (u *BatchingUploader) WriteBytesPartial(ctx context.Context, name string, r io.Reader, size int64, offset int64) (Stats, error) { - panic("not yet implemented") +func (u *BatchingUploader) WriteBytesPartial(ctx context.Context, name string, r io.Reader, size, offset int64) (Stats, error) { + if !u.streamThrottle.acquire(ctx) { + return Stats{}, ctx.Err() + } + defer u.streamThrottle.release() + return u.writeBytes(ctx, name, r, size, offset, false) +} + +func (u *uploader) writeBytes(ctx context.Context, name string, r io.Reader, size, offset int64, finish bool) (Stats, error) { + contextmd.Infof(ctx, log.Level(1), "[casng] upload.write_bytes: name=%s, size=%d, offset=%d, finish=%t", name, size, offset, finish) + defer contextmd.Infof(ctx, log.Level(1), "[casng] upload.write_bytes.done: name=%s, size=%d, offset=%d, finish=%t", name, size, offset, finish) + if log.V(3) { + startTime := time.Now() + defer func() { + log.Infof("[casng] upload.write_bytes.duration: start=%d, end=%d, name=%s, size=%d, chunk_size=%d", startTime.UnixNano(), time.Now().UnixNano(), name, size, u.ioCfg.BufferSize) + }() + } + + var stats Stats + // Read raw bytes if compression is disabled. + src := r + + // If compression is enabled, plug in the encoder via a pipe. + var errEnc error + var nRawBytes int64 // Track the actual number of the consumed raw bytes. + var encWg sync.WaitGroup + var withCompression bool // Used later to ensure the pipe is closed. + if IsCompressedWriteResourceName(name) { + contextmd.Infof(ctx, log.Level(1), "[casng] upload.write_bytes.compressing: name=%s, size=%d", name, size) + withCompression = true + pr, pw := io.Pipe() + // Closing pr always returns a nil error, but also sends ErrClosedPipe to pw. + defer pr.Close() + src = pr // Read compressed bytes instead of raw bytes. + + enc := u.zstdEncoders.Get().(*zstd.Encoder) + defer u.zstdEncoders.Put(enc) + // (Re)initialize the encoder with this writer. + enc.Reset(pw) + // Get it going. + encWg.Add(1) + go func() { + defer encWg.Done() + // Closing pw always returns a nil error, but also sends an EOF to pr. + defer pw.Close() + + // The encoder will theoretically read continuously. However, pw will block it + // while pr is not reading from the other side. + // In other words, the chunk size of the encoder's output is controlled by the reader. + nRawBytes, errEnc = enc.ReadFrom(r) + // Closing the encoder is necessary to flush remaining bytes. + errEnc = errors.Join(enc.Close(), errEnc) + if errors.Is(errEnc, io.ErrClosedPipe) { + // pr was closed first, which means the actual error is on that end. + errEnc = nil + } + }() + } + + ctx, ctxCancel := context.WithCancel(ctx) + defer ctxCancel() + + stream, errStream := u.byteStream.Write(ctx) + if errStream != nil { + return stats, errors.Join(ErrGRPC, errStream) + } + + // buf slice is never resliced which makes it safe to use a pointer-like type. 
+ buf := u.buffers.Get().([]byte) + defer u.buffers.Put(buf) + + cacheHit := false + var err error + req := &bspb.WriteRequest{ + ResourceName: name, + WriteOffset: offset, + } + for { + n, errRead := src.Read(buf) + if errRead != nil && errRead != io.EOF { + err = errors.Join(ErrIO, errRead, err) + break + } + + n64 := int64(n) + stats.LogicalBytesMoved += n64 // This may be adjusted later to exclude compression. See below. + stats.EffectiveBytesMoved += n64 + + req.Data = buf[:n] + req.FinishWrite = finish && errRead == io.EOF + errStream := retry.WithPolicy(ctx, u.streamRPCCfg.RetryPredicate, u.streamRPCCfg.RetryPolicy, func() error { + timer := time.NewTimer(u.streamRPCCfg.Timeout) + // Ensure the timer goroutine terminates if Send does not timeout. + success := make(chan struct{}) + defer close(success) + go func() { + select { + case <-timer.C: + ctxCancel() // Cancel the stream to allow Send to return. + case <-success: + } + }() + stats.TotalBytesMoved += n64 + return stream.Send(req) + }) + if errStream != nil && errStream != io.EOF { + err = errors.Join(ErrGRPC, errStream, err) + break + } + + // The server says the content for the specified resource already exists. + if errStream == io.EOF { + cacheHit = true + break + } + + req.WriteOffset += n64 + + // The reader is done (interrupted or completed). + if errRead == io.EOF { + break + } + } + + // In case of a cache hit or an error, the pipe must be closed to terminate the encoder's goroutine + // which would have otherwise terminated after draining the reader. + if srcCloser, ok := src.(io.Closer); ok && withCompression { + if errClose := srcCloser.Close(); errClose != nil { + err = errors.Join(ErrIO, errClose, err) + } + } + + // This theoretically will block until the encoder's goroutine has returned, which is the happy path. + // If the reader failed without the encoder's knowledge, closing the pipe will trigger the encoder to terminate, which is done above. + // In any case, waiting here is necessary because the encoder's goroutine currently owns errEnc and nRawBytes. + encWg.Wait() + if errEnc != nil { + err = errors.Join(ErrCompression, errEnc, err) + } + + // Capture stats before processing errors. + stats.BytesRequested = size + if nRawBytes > 0 { + // Compression was turned on. + // nRawBytes may be smaller than compressed bytes (additional headers without effective compression). + stats.LogicalBytesMoved = nRawBytes + } + if cacheHit { + stats.LogicalBytesCached = size + } + stats.LogicalBytesStreamed = stats.LogicalBytesMoved + stats.LogicalBytesBatched = 0 + stats.InputFileCount = 0 + stats.InputDirCount = 0 + stats.InputSymlinkCount = 0 + if cacheHit { + stats.CacheHitCount = 1 + } else { + stats.CacheMissCount = 1 + } + stats.DigestCount = 0 + stats.BatchedCount = 0 + if err == nil { + stats.StreamedCount = 1 + } + + res, errClose := stream.CloseAndRecv() + if errClose != nil { + return stats, errors.Join(ErrGRPC, errClose, err) + } + + // CommittedSize is based on the uncompressed size of the blob. 
+ if !cacheHit && res.CommittedSize != size { + err = errors.Join(ErrGRPC, fmt.Errorf("committed size mismatch: got %d, want %d", res.CommittedSize, size), err) + } + + return stats, err } diff --git a/go/pkg/casng/batching_write_bytes_test.go b/go/pkg/casng/batching_write_bytes_test.go new file mode 100644 index 000000000..5c5094190 --- /dev/null +++ b/go/pkg/casng/batching_write_bytes_test.go @@ -0,0 +1,301 @@ +package casng_test + +import ( + "bytes" + "context" + "fmt" + "io" + "strings" + "testing" + + "github.com/bazelbuild/remote-apis-sdks/go/pkg/casng" + "github.com/bazelbuild/remote-apis-sdks/go/pkg/errors" + "github.com/bazelbuild/remote-apis-sdks/go/pkg/retry" + "github.com/google/go-cmp/cmp" + bspb "google.golang.org/genproto/googleapis/bytestream" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +func TestUpload_WriteBytes(t *testing.T) { + errWrite := fmt.Errorf("write error") + errClose := fmt.Errorf("close error") + + tests := []struct { + name string + bs *fakeByteStreamClient + b []byte + offset int64 + finish bool + wantErr error + wantStats casng.Stats + retryPolicy *retry.BackoffPolicy + }{ + { + name: "no_compression", + bs: &fakeByteStreamClient{ + write: func(_ context.Context, _ ...grpc.CallOption) (bspb.ByteStream_WriteClient, error) { + bytesSent := int64(0) + return &fakeByteStream_WriteClient{ + send: func(wr *bspb.WriteRequest) error { + bytesSent += int64(len(wr.Data)) + return nil + }, + closeAndRecv: func() (*bspb.WriteResponse, error) { + return &bspb.WriteResponse{CommittedSize: bytesSent}, nil + }, + }, nil + }, + }, + b: []byte("abs"), + wantErr: nil, + wantStats: casng.Stats{ + BytesRequested: 3, + EffectiveBytesMoved: 3, + TotalBytesMoved: 3, + LogicalBytesMoved: 3, + LogicalBytesStreamed: 3, + CacheMissCount: 1, + StreamedCount: 1, + }, + }, + { + name: "compression", + bs: &fakeByteStreamClient{ + write: func(_ context.Context, _ ...grpc.CallOption) (bspb.ByteStream_WriteClient, error) { + return &fakeByteStream_WriteClient{ + send: func(wr *bspb.WriteRequest) error { + return nil + }, + closeAndRecv: func() (*bspb.WriteResponse, error) { + return &bspb.WriteResponse{CommittedSize: 3500}, nil + }, + }, nil + }, + }, + b: []byte(strings.Repeat("abcdefg", 500)), + wantErr: nil, + wantStats: casng.Stats{ + BytesRequested: 3500, + EffectiveBytesMoved: 29, + TotalBytesMoved: 29, + LogicalBytesMoved: 3500, + LogicalBytesStreamed: 3500, + CacheMissCount: 1, + StreamedCount: 1, + }, + }, + { + name: "write_call_error", + bs: &fakeByteStreamClient{ + write: func(ctx context.Context, opts ...grpc.CallOption) (bspb.ByteStream_WriteClient, error) { + return nil, errWrite + }, + }, + b: []byte("abc"), + wantErr: errWrite, + wantStats: casng.Stats{}, + }, + { + name: "cache_hit", + bs: &fakeByteStreamClient{ + write: func(ctx context.Context, opts ...grpc.CallOption) (bspb.ByteStream_WriteClient, error) { + return &fakeByteStream_WriteClient{ + send: func(wr *bspb.WriteRequest) error { + return io.EOF + }, + closeAndRecv: func() (*bspb.WriteResponse, error) { + return &bspb.WriteResponse{}, nil + }, + }, nil + }, + }, + b: []byte("abc"), + wantErr: nil, + wantStats: casng.Stats{ + BytesRequested: 3, + EffectiveBytesMoved: 2, // matches buffer size + TotalBytesMoved: 2, + LogicalBytesMoved: 2, + LogicalBytesStreamed: 2, + CacheHitCount: 1, + LogicalBytesCached: 3, + StreamedCount: 1, + }, + }, + { + name: "send_error", + bs: &fakeByteStreamClient{ + write: func(ctx context.Context, opts ...grpc.CallOption) 
(bspb.ByteStream_WriteClient, error) { + return &fakeByteStream_WriteClient{ + send: func(wr *bspb.WriteRequest) error { + return errWrite + }, + closeAndRecv: func() (*bspb.WriteResponse, error) { + return &bspb.WriteResponse{}, nil + }, + }, nil + }, + }, + b: []byte("abc"), + wantErr: casng.ErrGRPC, + wantStats: casng.Stats{ + BytesRequested: 3, + EffectiveBytesMoved: 2, // matches buffer size + TotalBytesMoved: 2, + LogicalBytesMoved: 2, + LogicalBytesStreamed: 2, + CacheMissCount: 1, + StreamedCount: 0, + }, + }, + { + name: "send_retry_timeout", + bs: &fakeByteStreamClient{ + write: func(ctx context.Context, opts ...grpc.CallOption) (bspb.ByteStream_WriteClient, error) { + return &fakeByteStream_WriteClient{ + send: func(wr *bspb.WriteRequest) error { + return status.Error(codes.DeadlineExceeded, "error") + }, + closeAndRecv: func() (*bspb.WriteResponse, error) { + return &bspb.WriteResponse{}, nil + }, + }, nil + }, + }, + b: []byte("abc"), + wantErr: casng.ErrGRPC, + wantStats: casng.Stats{ + BytesRequested: 3, + EffectiveBytesMoved: 2, // matches one buffer size + TotalBytesMoved: 4, // matches two buffer sizes + LogicalBytesMoved: 2, + LogicalBytesStreamed: 2, + CacheMissCount: 1, + StreamedCount: 0, + }, + retryPolicy: &retryTwice, + }, + { + name: "stream_close_error", + bs: &fakeByteStreamClient{ + write: func(ctx context.Context, opts ...grpc.CallOption) (bspb.ByteStream_WriteClient, error) { + return &fakeByteStream_WriteClient{ + send: func(wr *bspb.WriteRequest) error { + return nil + }, + closeAndRecv: func() (*bspb.WriteResponse, error) { + return nil, errClose + }, + }, nil + }, + }, + b: []byte("abc"), + wantErr: casng.ErrGRPC, + wantStats: casng.Stats{ + BytesRequested: 3, + EffectiveBytesMoved: 3, + TotalBytesMoved: 3, + LogicalBytesMoved: 3, + LogicalBytesStreamed: 3, + CacheMissCount: 1, + StreamedCount: 1, + }, + }, + { + name: "arbitrary_offset", + bs: &fakeByteStreamClient{ + write: func(ctx context.Context, opts ...grpc.CallOption) (bspb.ByteStream_WriteClient, error) { + return &fakeByteStream_WriteClient{ + send: func(wr *bspb.WriteRequest) error { + if wr.WriteOffset < 5 { + return fmt.Errorf("mismatched offset: want 5, got %d", wr.WriteOffset) + } + return nil + }, + closeAndRecv: func() (*bspb.WriteResponse, error) { + return &bspb.WriteResponse{CommittedSize: 3}, nil + }, + }, nil + }, + }, + b: []byte("abc"), + offset: 5, + wantStats: casng.Stats{ + BytesRequested: 3, + EffectiveBytesMoved: 3, + TotalBytesMoved: 3, + LogicalBytesMoved: 3, + LogicalBytesStreamed: 3, + CacheMissCount: 1, + StreamedCount: 1, + }, + }, + { + name: "finish_write", + bs: &fakeByteStreamClient{ + write: func(ctx context.Context, opts ...grpc.CallOption) (bspb.ByteStream_WriteClient, error) { + return &fakeByteStream_WriteClient{ + send: func(wr *bspb.WriteRequest) error { + if len(wr.Data) == 0 && !wr.FinishWrite { + return fmt.Errorf("finish write was not set") + } + return nil + }, + closeAndRecv: func() (*bspb.WriteResponse, error) { + return &bspb.WriteResponse{CommittedSize: 3}, nil + }, + }, nil + }, + }, + b: []byte("abc"), + finish: true, + wantStats: casng.Stats{ + BytesRequested: 3, + EffectiveBytesMoved: 3, + TotalBytesMoved: 3, + LogicalBytesMoved: 3, + LogicalBytesStreamed: 3, + CacheMissCount: 1, + StreamedCount: 1, + }, + }, + } + + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + testRpcCfg := defaultRPCCfg + if test.retryPolicy != nil { + testRpcCfg.RetryPolicy = *test.retryPolicy + } + u, err := 
casng.NewBatchingUploader(context.Background(), &fakeCAS{}, test.bs, "", testRpcCfg, testRpcCfg, testRpcCfg, defaultIOCfg) + if err != nil { + t.Fatalf("error creating batching uploader: %v", err) + } + var stats casng.Stats + var name string + if len(test.b) >= int(defaultIOCfg.CompressionSizeThreshold) { + name = casng.MakeCompressedWriteResourceName("instance", "hash", 0) + } else { + name = casng.MakeWriteResourceName("instance", "hash", 0) + } + if test.finish { + stats, err = u.WriteBytes(context.Background(), name, bytes.NewReader(test.b), int64(len(test.b)), test.offset) + } else { + stats, err = u.WriteBytesPartial(context.Background(), name, bytes.NewReader(test.b), int64(len(test.b)), test.offset) + } + if test.wantErr == nil && err != nil { + t.Errorf("WriteBytes failed: %v", err) + } + if test.wantErr != nil && !errors.Is(err, test.wantErr) { + t.Errorf("error mismatch: want %v, got %v", test.wantErr, err) + } + if diff := cmp.Diff(test.wantStats, stats); diff != "" { + t.Errorf("stats mismatch, (-want +got): %s", diff) + } + }) + } +} diff --git a/go/pkg/casng/throttler.go b/go/pkg/casng/throttler.go new file mode 100644 index 000000000..68e43f37e --- /dev/null +++ b/go/pkg/casng/throttler.go @@ -0,0 +1,39 @@ +package casng + +import ( + "context" +) + +// throttler provides a simple semaphore interface to limit in-flight goroutines. +type throttler struct { + ch chan struct{} +} + +// acquire blocks until a token can be acquired from the pool. +// +// Returns false if ctx expires before a token is available. Otherwise returns true. +func (t *throttler) acquire(ctx context.Context) bool { + for { + select { + case t.ch <- struct{}{}: + return true + case <-ctx.Done(): + return false + } + } +} + +// release returns a token to the pool. Must be called after acquire. Otherwise, it will block until acquire is called. +func (t *throttler) release() { + <-t.ch +} + +// len returns the number of acquired tokens. +func (t *throttler) len() int { + return len(t.ch) +} + +// newThrottler creates a new instance that allows up to n tokens to be acquired. +func newThrottler(n int64) *throttler { + return &throttler{ch: make(chan struct{}, n)} +} diff --git a/go/pkg/casng/uploader.go b/go/pkg/casng/uploader.go index 835e70fa4..aa77e3297 100644 --- a/go/pkg/casng/uploader.go +++ b/go/pkg/casng/uploader.go @@ -79,10 +79,10 @@ import ( "context" "errors" "fmt" + "strings" "sync" "github.com/bazelbuild/remote-apis-sdks/go/pkg/digest" - "github.com/bazelbuild/remote-apis-sdks/go/pkg/retry" repb "github.com/bazelbuild/remote-apis/build/bazel/remote/execution/v2" log "github.com/golang/glog" "github.com/klauspost/compress/zstd" @@ -121,6 +121,11 @@ func MakeCompressedWriteResourceName(instanceName, hash string, size int64) stri return fmt.Sprintf("%s/uploads/%s/compressed-blobs/zstd/%s/%d", instanceName, uuid.New(), hash, size) } +// IsCompressedWriteResourceName returns true if the name was generated with MakeCompressedWriteResourceName. +func IsCompressedWriteResourceName(name string) bool { + return strings.Contains(name, "compressed-blobs/zstd") +} + // BatchingUplodaer provides a blocking interface to query and upload to the CAS. type BatchingUploader struct { *uploader @@ -141,6 +146,9 @@ type uploader struct { batchRPCCfg GRPCConfig streamRPCCfg GRPCConfig + // gRPC throttling controls. + streamThrottle *throttler // Controls concurrent calls to the byte streaming API. + // IO controls. 
ioCfg IOConfig buffers sync.Pool @@ -240,6 +248,8 @@ func newUploader( batchRPCCfg: uploadCfg, streamRPCCfg: streamCfg, + streamThrottle: newThrottler(int64(streamCfg.ConcurrentCallsLimit)), + ioCfg: ioCfg, buffers: sync.Pool{ New: func() any { @@ -264,7 +274,3 @@ func newUploader( return u, nil } - -func (u *uploader) withRetry(ctx context.Context, predicate retry.ShouldRetry, policy retry.BackoffPolicy, fn func() error) error { - return retry.WithPolicy(ctx, predicate, policy, fn) -}
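
Usage sketch: a minimal example of how a caller might pick a write resource name and stream a blob with the new WriteBytes API, mirroring batching_write_bytes_test.go above. It assumes an already-constructed *casng.BatchingUploader and the IOConfig that was passed to casng.NewBatchingUploader; the streamBlob helper and its package name are illustrative only, not part of this change. Compression is selected purely by the resource name, so the caller chooses a name by comparing the blob size against IOConfig.CompressionSizeThreshold, as the test does.

package casngexample

import (
	"bytes"
	"context"

	"github.com/bazelbuild/remote-apis-sdks/go/pkg/casng"
	"github.com/bazelbuild/remote-apis-sdks/go/pkg/digest"
)

// streamBlob uploads blob via the byte streaming API. The resource name alone
// determines whether the payload is zstd-compressed on the wire.
func streamBlob(ctx context.Context, u *casng.BatchingUploader, ioCfg casng.IOConfig, instance string, blob []byte) (casng.Stats, error) {
	d := digest.NewFromBlob(blob)
	var name string
	if d.Size >= int64(ioCfg.CompressionSizeThreshold) {
		name = casng.MakeCompressedWriteResourceName(instance, d.Hash, d.Size)
	} else {
		name = casng.MakeWriteResourceName(instance, d.Hash, d.Size)
	}
	// size must be the raw (uncompressed) byte count; offset 0 starts a fresh write.
	// WriteBytes asks the server to finalize the resource; use WriteBytesPartial to
	// keep it open for a subsequent write at a later offset.
	return u.WriteBytes(ctx, name, bytes.NewReader(blob), d.Size, 0)
}

Note that WriteBytes and WriteBytesPartial both acquire a token from the uploader's streamThrottle before streaming, so at most ConcurrentCallsLimit such calls are in flight at once; additional callers block until a slot frees or their context expires.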