diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..9b38a92 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +# Binary generated by `go build`. +/kajiya diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..98b88e4 --- /dev/null +++ b/LICENSE @@ -0,0 +1,27 @@ +// Copyright 2023 The Chromium Authors. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
diff --git a/README.md b/README.md new file mode 100644 index 0000000..6626e7b --- /dev/null +++ b/README.md @@ -0,0 +1,39 @@ +# 🔥 鍛冶屋 (Kajiya) + +Kajiya is an RBE-compatible REAPI backend implementation used as a testing +server during development of Chromium's new build tooling. It is not meant +for production use, but can be very useful for local testing of any remote +execution related code. + +## How to use + +```shell +$ go build && ./kajiya + +# Build Bazel using kajiya as the backend. +$ bazel build --remote_executor=grpc://localhost:50051 //src:bazel + +# Build Chromium with autoninja + reclient using kajiya as the backend. +$ gn gen out/default --args="use_remoteexec=true" +$ env \ + RBE_automatic_auth=false \ + RBE_service="localhost:50051" \ + RBE_service_no_security=true \ + RBE_service_no_auth=true \ + RBE_compression_threshold=-1 \ + autoninja -C out/default -j $(nproc) chrome +``` + +## Features + +Kajiya can act as an REAPI remote cache and/or remote executor. By default, both +services are provided, but you can also run an executor without a cache, or a +cache without an executor: + +```shell +# Remote execution without caching +$ ./kajiya -cache=false + +# Remote caching without execution (clients must upload action results) +$ ./kajiya -execution=false +``` diff --git a/actioncache/local.go b/actioncache/local.go new file mode 100644 index 0000000..b4bdb73 --- /dev/null +++ b/actioncache/local.go @@ -0,0 +1,104 @@ +// Copyright 2023 The Chromium Authors +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package actioncache + +import ( + "fmt" + "os" + "path/filepath" + + "github.com/bazelbuild/remote-apis-sdks/go/pkg/digest" + remote "github.com/bazelbuild/remote-apis/build/bazel/remote/execution/v2" + "google.golang.org/protobuf/proto" +) + +// ActionCache is a simple action cache implementation that stores ActionResults on the local disk. 
+type ActionCache struct {
+	// dataDir is the root directory; entries live in dataDir/<first two hash chars>/<hash>.
+	dataDir string
+}
+
+// New creates a new local ActionCache. The data directory is created if it does not exist.
+// It returns an error if dataDir is empty or if any directory cannot be created.
+func New(dataDir string) (*ActionCache, error) {
+	if dataDir == "" {
+		return nil, fmt.Errorf("data directory must be specified")
+	}
+
+	if err := os.MkdirAll(dataDir, 0755); err != nil {
+		return nil, err
+	}
+
+	// Create subdirectories {00, 01, ..., ff} for sharding by hash prefix.
+	for i := 0; i <= 255; i++ {
+		err := os.Mkdir(filepath.Join(dataDir, fmt.Sprintf("%02x", i)), 0755)
+		if err != nil {
+			// A shard directory left over from a previous run is fine.
+			if os.IsExist(err) {
+				continue
+			}
+			return nil, err
+		}
+	}
+
+	return &ActionCache{
+		dataDir: dataDir,
+	}, nil
+}
+
+// path returns the path to the file with digest d in the action cache.
+// It slices the first two characters of d.Hash, so the hash must have at least two characters.
+func (c *ActionCache) path(d digest.Digest) string {
+	return filepath.Join(c.dataDir, d.Hash[:2], d.Hash)
+}
+
+// Get returns the cached ActionResult for the given digest.
+// The error from os.ReadFile is returned unchanged, so callers can use
+// os.IsNotExist to detect a cache miss.
+func (c *ActionCache) Get(actionDigest digest.Digest) (*remote.ActionResult, error) {
+	p := c.path(actionDigest)
+
+	// Read the action result for the requested action into a byte slice.
+	buf, err := os.ReadFile(p)
+	if err != nil {
+		return nil, err
+	}
+
+	// Unmarshal it into an ActionResult message and return it to the client.
+	actionResult := &remote.ActionResult{}
+	if err := proto.Unmarshal(buf, actionResult); err != nil {
+		return nil, err
+	}
+
+	return actionResult, nil
+}
+
+// Put stores the given ActionResult for the given digest.
+// The result is first written to a temporary file in the data directory and then
+// renamed to its final location. If any step fails, the temporary file is removed
+// again so that failed writes do not accumulate in the data directory.
+func (c *ActionCache) Put(actionDigest digest.Digest, ar *remote.ActionResult) error {
+	// Marshal the action result.
+	actionResultRaw, err := proto.Marshal(ar)
+	if err != nil {
+		return err
+	}
+
+	// Store the action result in our action cache.
+	f, err := os.CreateTemp(c.dataDir, "tmp_")
+	if err != nil {
+		return err
+	}
+	if _, err := f.Write(actionResultRaw); err != nil {
+		f.Close()
+		_ = os.Remove(f.Name()) // best effort, the write error is what matters
+		return err
+	}
+	if err := f.Close(); err != nil {
+		_ = os.Remove(f.Name()) // best effort, the close error is what matters
+		return err
+	}
+	if err := os.Rename(f.Name(), c.path(actionDigest)); err != nil {
+		// TODO: It's possible that on Windows we cannot rename the file to the destination because it already exists.
+		// In that case, we should check if the file is identical to the one we're trying to write, and if so, ignore the error.
+		_ = os.Remove(f.Name()) // best effort, the rename error is what matters
+		return err
+	}
+
+	return nil
+}
+
+// Remove deletes the cached ActionResult for the given digest.
+// It is not implemented yet and always returns an error.
+func (c *ActionCache) Remove(d digest.Digest) error {
+	return fmt.Errorf("not implemented yet")
+}
diff --git a/actioncache/service.go b/actioncache/service.go
new file mode 100644
index 0000000..eb394ab
--- /dev/null
+++ b/actioncache/service.go
@@ -0,0 +1,164 @@
+// Copyright 2023 The Chromium Authors
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+// Package actioncache implements the REAPI ActionCache service.
+package actioncache
+
+import (
+	"context"
+	"fmt"
+	"log"
+	"os"
+
+	"github.com/bazelbuild/remote-apis-sdks/go/pkg/digest"
+	remote "github.com/bazelbuild/remote-apis/build/bazel/remote/execution/v2"
+	"google.golang.org/grpc"
+	"google.golang.org/grpc/codes"
+	"google.golang.org/grpc/status"
+
+	"github.com/philwo/kajiya/blobstore"
+)
+
+// Service implements the REAPI ActionCache service.
+type Service struct {
+	remote.UnimplementedActionCacheServer
+
+	// The ActionCache to use for storing ActionResults.
+	ac *ActionCache
+
+	// The blobstore.ContentAddressableStorage to use for reading blobs.
+	cas *blobstore.ContentAddressableStorage
+}
+
+// Register creates and registers a new Service with the given gRPC server.
+// It returns an error if ac or cas is nil.
+func Register(s *grpc.Server, ac *ActionCache, cas *blobstore.ContentAddressableStorage) error {
+	service, err := NewService(ac, cas)
+	if err != nil {
+		return err
+	}
+	remote.RegisterActionCacheServer(s, service)
+	return nil
+}
+
+// NewService creates a new Service.
+// Both ac and cas must be non-nil; otherwise an error is returned.
+func NewService(ac *ActionCache, cas *blobstore.ContentAddressableStorage) (Service, error) {
+	if ac == nil {
+		return Service{}, fmt.Errorf("ac must be set")
+	}
+
+	if cas == nil {
+		return Service{}, fmt.Errorf("cas must be set")
+	}
+
+	return Service{
+		ac:  ac,
+		cas: cas,
+	}, nil
+}
+
+// GetActionResult returns the ActionResult for a given action digest.
+// It delegates to getActionResult and logs whether the lookup was a hit, a miss or an error.
+func (s Service) GetActionResult(ctx context.Context, request *remote.GetActionResultRequest) (*remote.ActionResult, error) {
+	response, err := s.getActionResult(request)
+	if err != nil {
+		if status.Code(err) == codes.NotFound {
+			log.Printf("⚠️ GetActionResult(%v) => Cache miss", request.ActionDigest)
+		} else {
+			log.Printf("🚨 GetActionResult(%v) => Error: %v", request.ActionDigest, err)
+		}
+	} else {
+		log.Printf("🎉 GetActionResult(%v) => Cache hit", request.ActionDigest)
+	}
+	return response, err
+}
+
+// getActionResult validates the request and looks up the action result in the action cache.
+// A missing entry is reported as NotFound, any other lookup failure as Internal.
+func (s Service) getActionResult(request *remote.GetActionResultRequest) (*remote.ActionResult, error) {
+	// If the client explicitly specifies a DigestFunction, ensure that it's SHA256.
+	if request.DigestFunction != remote.DigestFunction_UNKNOWN && request.DigestFunction != remote.DigestFunction_SHA256 {
+		return nil, status.Errorf(codes.InvalidArgument, "hash function %q is not supported", request.DigestFunction.String())
+	}
+
+	actionDigest, err := digest.NewFromProto(request.ActionDigest)
+	if err != nil {
+		return nil, status.Error(codes.InvalidArgument, err.Error())
+	}
+
+	actionResult, err := s.ac.Get(actionDigest)
+	if err != nil {
+		if os.IsNotExist(err) {
+			return nil, status.Errorf(codes.NotFound, "action digest %s not found in cache", actionDigest)
+		}
+		return nil, status.Error(codes.Internal, err.Error())
+	}
+
+	return actionResult, nil
+}
+
+// UpdateActionResult stores an ActionResult for a given action digest on disk.
+// It delegates to updateActionResult and logs the outcome.
+func (s Service) UpdateActionResult(ctx context.Context, request *remote.UpdateActionResultRequest) (*remote.ActionResult, error) {
+	response, err := s.updateActionResult(request)
+	if err != nil {
+		log.Printf("🚨 UpdateActionResult(%v) => Error: %v", request.ActionDigest, err)
+	} else {
+		log.Printf("✅ UpdateActionResult(%v) => OK", request.ActionDigest)
+	}
+	return response, err
+}
+
+// updateActionResult validates the request, checks that the action and its
+// stdout / stderr blobs are present in the CAS and then stores the action result.
+// Malformed requests are rejected with InvalidArgument and missing blobs with NotFound.
+func (s Service) updateActionResult(request *remote.UpdateActionResultRequest) (*remote.ActionResult, error) {
+	// If the client explicitly specifies a DigestFunction, ensure that it's SHA256.
+	if request.DigestFunction != remote.DigestFunction_UNKNOWN && request.DigestFunction != remote.DigestFunction_SHA256 {
+		return nil, status.Errorf(codes.InvalidArgument, "hash function %q is not supported", request.DigestFunction.String())
+	}
+
+	// The field accesses below dereference the action result, so a request
+	// without one must be rejected here instead of panicking.
+	if request.ActionResult == nil {
+		return nil, status.Error(codes.InvalidArgument, "action_result must be set")
+	}
+
+	// Check that the client didn't send inline stdout / stderr data.
+	if request.ActionResult.StdoutRaw != nil {
+		return nil, status.Error(codes.InvalidArgument, "client should not populate stdout_raw during upload")
+	}
+	if request.ActionResult.StderrRaw != nil {
+		return nil, status.Error(codes.InvalidArgument, "client should not populate stderr_raw during upload")
+	}
+
+	// Check that the action digest is valid.
+	actionDigest, err := digest.NewFromProto(request.ActionDigest)
+	if err != nil {
+		return nil, status.Error(codes.InvalidArgument, err.Error())
+	}
+
+	// Check that the action is present in our CAS.
+	if _, err := s.cas.Stat(actionDigest); err != nil {
+		return nil, status.Errorf(codes.NotFound, "action digest %s not found in CAS", actionDigest)
+	}
+
+	// If the action result contains a stdout digest, check that it is present in our CAS.
+	if request.ActionResult.StdoutDigest != nil {
+		stdoutDigest, err := digest.NewFromProto(request.ActionResult.StdoutDigest)
+		if err != nil {
+			return nil, status.Error(codes.InvalidArgument, err.Error())
+		}
+		if _, err := s.cas.Stat(stdoutDigest); err != nil {
+			return nil, status.Errorf(codes.NotFound, "stdout digest %s not found in CAS", stdoutDigest)
+		}
+	}
+
+	// Same for stderr.
+	if request.ActionResult.StderrDigest != nil {
+		stderrDigest, err := digest.NewFromProto(request.ActionResult.StderrDigest)
+		if err != nil {
+			return nil, status.Error(codes.InvalidArgument, err.Error())
+		}
+		if _, err := s.cas.Stat(stderrDigest); err != nil {
+			return nil, status.Errorf(codes.NotFound, "stderr digest %s not found in CAS", stderrDigest)
+		}
+	}
+
+	// TODO: Check that all the output files are present in our CAS.
+
+	// Store the action result.
+	if err := s.ac.Put(actionDigest, request.ActionResult); err != nil {
+		return nil, status.Error(codes.Internal, err.Error())
+	}
+
+	// Return the action result.
+	return request.ActionResult, nil
+}
diff --git a/blobstore/fastcopy.go b/blobstore/fastcopy.go
new file mode 100644
index 0000000..5427771
--- /dev/null
+++ b/blobstore/fastcopy.go
@@ -0,0 +1,16 @@
+// Copyright 2023 The Chromium Authors
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+//go:build !darwin
+
+package blobstore
+
+import "os"
+
+// fastCopy copies a file from source to destination using a hard link.
+// This is usually the best we can do, unless the operating system supports
+// copy-on-write semantics for files (e.g. macOS with APFS).
+func fastCopy(source, destination string) error {
+	return os.Link(source, destination)
+}
diff --git a/blobstore/fastcopy_darwin.go b/blobstore/fastcopy_darwin.go
new file mode 100644
index 0000000..3ca7b9f
--- /dev/null
+++ b/blobstore/fastcopy_darwin.go
@@ -0,0 +1,18 @@
+// Copyright 2023 The Chromium Authors
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+//go:build darwin
+
+package blobstore
+
+import (
+	"golang.org/x/sys/unix"
+)
+
+// fastCopy copies a file from source to destination using a clonefile syscall.
+// This is nicer than using a hard link, because it means that even if the file
+// is accidentally modified, the copy will still have the original contents.
+func fastCopy(source, destination string) error {
+	return unix.Clonefile(source, destination, unix.CLONE_NOFOLLOW)
+}
diff --git a/blobstore/io.go b/blobstore/io.go
new file mode 100644
index 0000000..46a2f95
--- /dev/null
+++ b/blobstore/io.go
@@ -0,0 +1,23 @@
+// Copyright 2023 The Chromium Authors
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+package blobstore
+
+import (
+	"io"
+)
+
+// LimitedReadCloser is an io.ReadCloser that limits the number of bytes that can be read.
+// Reads go through the embedded io.LimitedReader, Close goes to the embedded io.Closer.
+type LimitedReadCloser struct {
+	*io.LimitedReader
+	io.Closer
+}
+
+// LimitReadCloser wraps an io.ReadCloser in a LimitedReadCloser that returns
+// at most limit bytes from r and closes r when it is closed.
+func LimitReadCloser(r io.ReadCloser, limit int64) io.ReadCloser {
+	return &LimitedReadCloser{
+		LimitedReader: &io.LimitedReader{R: r, N: limit},
+		Closer:        r,
+	}
+}
diff --git a/blobstore/local.go b/blobstore/local.go
new file mode 100644
index 0000000..508ccb5
--- /dev/null
+++ b/blobstore/local.go
@@ -0,0 +1,200 @@
+// Copyright 2023 The Chromium Authors
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+package blobstore
+
+import (
+	"fmt"
+	"io"
+	"io/fs"
+	"log"
+	"os"
+	"path/filepath"
+
+	"github.com/bazelbuild/remote-apis-sdks/go/pkg/digest"
+)
+
+// ContentAddressableStorage is a simple CAS implementation that stores files on the local disk.
+type ContentAddressableStorage struct {
+	// dataDir is the root directory; blobs live in dataDir/<first two hash chars>/<hash>.
+	dataDir string
+}
+
+// New creates a new local CAS. The data directory is created if it does not exist.
+// It returns an error if dataDir is empty, if a directory cannot be created,
+// or if the empty blob cannot be stored.
+func New(dataDir string) (*ContentAddressableStorage, error) {
+	if dataDir == "" {
+		return nil, fmt.Errorf("data directory must be specified")
+	}
+
+	if err := os.MkdirAll(dataDir, 0755); err != nil {
+		return nil, err
+	}
+
+	// Create subdirectories {00, 01, ..., ff} for sharding by hash prefix.
+	for i := 0; i <= 255; i++ {
+		err := os.Mkdir(filepath.Join(dataDir, fmt.Sprintf("%02x", i)), 0755)
+		if err != nil {
+			// A shard directory left over from a previous run is fine.
+			if os.IsExist(err) {
+				continue
+			}
+			return nil, err
+		}
+	}
+
+	cas := &ContentAddressableStorage{
+		dataDir: dataDir,
+	}
+
+	// Ensure that we have the "empty blob" present in the CAS.
+	// Clients will usually not upload it, but just assume that it's always available.
+	// A faster way would be to special case the empty digest in the CAS implementation,
+	// but this is simpler and more robust.
+	d, err := cas.Put(nil)
+	if err != nil {
+		return nil, err
+	}
+	if d != digest.Empty {
+		return nil, fmt.Errorf("empty blob did not have expected hash: got %s, wanted %s", d, digest.Empty)
+	}
+
+	return cas, nil
+}
+
+// path returns the path to the file with digest d in the CAS.
+// It slices the first two characters of d.Hash, so the hash must have at least two characters.
+func (c *ContentAddressableStorage) path(d digest.Digest) string {
+	return filepath.Join(c.dataDir, d.Hash[:2], d.Hash)
+}
+
+// Stat returns os.FileInfo for the requested digest if it exists.
+// A file whose size differs from d.Size is treated as missing and reported as fs.ErrNotExist.
+func (c *ContentAddressableStorage) Stat(d digest.Digest) (os.FileInfo, error) {
+	p := c.path(d)
+
+	fi, err := os.Lstat(p)
+	if err != nil {
+		return nil, err
+	}
+
+	if fi.Size() != d.Size {
+		log.Printf("actual file size %d does not match requested size of digest %s", fi.Size(), d.String())
+		return nil, fs.ErrNotExist
+	}
+
+	return fi, nil
+}
+
+// Open returns an io.ReadCloser for the requested digest if it exists.
+// The returned ReadCloser is limited to the given offset and limit.
+// The offset must be non-negative and no larger than the file size.
+// A limit of 0 means no limit, and a limit that's larger than the file size is truncated to the file size.
+// A file whose size differs from d.Size is reported as fs.ErrNotExist, an invalid offset as fs.ErrInvalid.
+func (c *ContentAddressableStorage) Open(d digest.Digest, offset int64, limit int64) (io.ReadCloser, error) {
+	p := c.path(d)
+
+	f, err := os.Open(p)
+	if err != nil {
+		return nil, err
+	}
+
+	// Ensure that the file has the expected size.
+	size, err := f.Seek(0, io.SeekEnd)
+	if err != nil {
+		f.Close()
+		return nil, err
+	}
+
+	if size != d.Size {
+		// Log the size we actually found on disk (not the requested offset).
+		log.Printf("actual file size %d does not match requested size of digest %s", size, d.String())
+		f.Close()
+		return nil, fs.ErrNotExist
+	}
+
+	// Ensure that the offset is not negative and not larger than the file size.
+	if offset < 0 || offset > size {
+		f.Close()
+		return nil, fs.ErrInvalid
+	}
+
+	// Seek to the requested offset.
+	if _, err := f.Seek(offset, io.SeekStart); err != nil {
+		f.Close()
+		return nil, err
+	}
+
+	// Cap the limit to the file size, taking the offset into account.
+	if limit == 0 || limit > size-offset {
+		limit = size - offset
+	}
+
+	return LimitReadCloser(f, limit), nil
+}
+
+// Get reads a file for the given digest from disk and returns its contents.
+func (c *ContentAddressableStorage) Get(d digest.Digest) ([]byte, error) {
+	// Just call Open and read the whole file.
+	f, err := c.Open(d, 0, 0)
+	if err != nil {
+		return nil, err
+	}
+	defer f.Close() // error is safe to ignore, because we're just reading
+	return io.ReadAll(f)
+}
+
+// Put stores the given data in the CAS and returns its digest.
+// The digest is returned even when storing fails, together with the error.
+// The temporary file used for writing is removed again whenever it could not
+// be moved to its final location.
+func (c *ContentAddressableStorage) Put(data []byte) (digest.Digest, error) {
+	d := digest.NewFromBlob(data)
+	p := c.path(d)
+
+	// Check if the file already exists.
+	// This is a fast path that avoids writing the file if it already exists.
+	if _, err := os.Stat(p); err == nil {
+		return d, nil
+	}
+
+	// Write the file to a temporary location and then rename it.
+	// This ensures that we don't accidentally serve a partial file if the process is killed while writing.
+	// It also ensures that we don't serve a file that's still being written.
+	f, err := os.CreateTemp(c.dataDir, "tmp_")
+	if err != nil {
+		return d, err
+	}
+	if _, err := f.Write(data); err != nil {
+		f.Close()
+		_ = os.Remove(f.Name()) // best effort, the write error is what matters
+		return d, err
+	}
+	if err := f.Close(); err != nil {
+		_ = os.Remove(f.Name()) // best effort, the close error is what matters
+		return d, err
+	}
+	if err := os.Rename(f.Name(), p); err != nil {
+		// The temporary file was not moved, so don't leave it behind (best effort).
+		_ = os.Remove(f.Name())
+
+		// This might happen on Windows if the file already exists and we can't replace it.
+		// Because this is a CAS, we can assume that the file contains the same data that we wanted to write.
+		// So we can ignore this error.
+		if os.IsExist(err) {
+			return d, nil
+		}
+		return d, err
+	}
+
+	return d, nil
+}
+
+// Adopt moves a file from the given path into the CAS.
+// The digest is assumed to have been validated by the caller.
+// If the rename fails because the destination already exists, the source file
+// is removed (best effort) and nil is returned.
+func (c *ContentAddressableStorage) Adopt(d digest.Digest, path string) error {
+	err := os.Rename(path, c.path(d))
+	if err != nil {
+		if os.IsExist(err) {
+			// The file already exists, so we can ignore this error.
+			// This might happen on Windows if the file already exists,
+			// and we can't replace it.
+			// The source was not moved in this case, so remove it to keep the
+			// "move" contract and avoid leaving stale files behind.
+			_ = os.Remove(path)
+			return nil
+		}
+		return err
+	}
+	return nil
+}
+
+// LinkTo creates a link `path` pointing to the file with digest `d` in the CAS.
+// If the operating system supports cloning files via copy-on-write semantics,
+// the file is cloned instead of hard linked.
+func (c *ContentAddressableStorage) LinkTo(d digest.Digest, path string) error {
+	return fastCopy(c.path(d), path)
+}
diff --git a/blobstore/service.go b/blobstore/service.go
new file mode 100644
index 0000000..b924982
--- /dev/null
+++ b/blobstore/service.go
@@ -0,0 +1,585 @@
+// Copyright 2023 The Chromium Authors
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+// Package blobstore implements the REAPI ContentAddressableStorage and ByteStream services.
+package blobstore
+
+import (
+	"context"
+	"crypto/sha256"
+	"encoding/hex"
+	"errors"
+	"io"
+	"log"
+	"os"
+	"path/filepath"
+	"strconv"
+	"strings"
+
+	"github.com/bazelbuild/remote-apis-sdks/go/pkg/digest"
+	remote "github.com/bazelbuild/remote-apis/build/bazel/remote/execution/v2"
+	"github.com/google/uuid"
+	"google.golang.org/genproto/googleapis/bytestream"
+	"google.golang.org/grpc"
+	"google.golang.org/grpc/codes"
+	"google.golang.org/grpc/status"
+	"google.golang.org/protobuf/proto"
+)
+
+const (
+	// The maximum chunk size to write back to the client in Send calls.
+	maxChunkSize int64 = 2 * 1024 * 1024 // 2M
+)
+
+// Service implements the REAPI ContentAddressableStorage and ByteStream services.
+type Service struct {
+	remote.UnimplementedContentAddressableStorageServer
+	bytestream.UnimplementedByteStreamServer
+
+	// cas is the storage backend that blobs are read from and written to.
+	cas *ContentAddressableStorage
+	// uploadDir holds the temporary files of uploads that are still in progress.
+	uploadDir string
+}
+
+// Register creates and registers a new Service with the given gRPC server.
+// The dataDir is created if it does not exist.
+func Register(s *grpc.Server, cas *ContentAddressableStorage, dataDir string) error {
+	service, err := NewService(cas, dataDir)
+	if err != nil {
+		return err
+	}
+	bytestream.RegisterByteStreamServer(s, service)
+	remote.RegisterContentAddressableStorageServer(s, service)
+	return nil
+}
+
+// NewService creates a new Service.
+// Both cas and uploadDir must be set. The upload directory is created if it does not exist.
+func NewService(cas *ContentAddressableStorage, uploadDir string) (*Service, error) {
+	// Reject a missing CAS right away instead of failing with a nil pointer
+	// dereference on the first request.
+	if cas == nil {
+		return nil, errors.New("cas must be set")
+	}
+
+	if uploadDir == "" {
+		return nil, errors.New("uploadDir must be set")
+	}
+
+	// Ensure that our temporary upload directory exists.
+	if err := os.MkdirAll(uploadDir, 0755); err != nil {
+		return nil, err
+	}
+
+	return &Service{
+		cas:       cas,
+		uploadDir: uploadDir,
+	}, nil
+}
+
+// parseReadResource parses a ReadRequest.ResourceName and returns the validated Digest.
+// The resource name should be of the format: {instance_name}/blobs/{hash}/{size}
+// Any malformed name is reported as an InvalidArgument status error.
+func parseReadResource(name string) (d digest.Digest, err error) {
+	fields := strings.Split(name, "/")
+
+	// Strip any parts before "blobs", as they'll belong to an instance name.
+	for i := range fields {
+		if fields[i] == "blobs" {
+			fields = fields[i:]
+			break
+		}
+	}
+
+	if len(fields) != 3 || fields[0] != "blobs" {
+		return d, status.Errorf(codes.InvalidArgument, "invalid resource name, must match format {instance_name}/blobs/{hash}/{size}: %s", name)
+	}
+
+	hash := fields[1]
+	size, err := strconv.ParseInt(fields[2], 10, 64)
+	if err != nil {
+		return d, status.Errorf(codes.InvalidArgument, "invalid resource name, fourth component (size) must be an integer: %s", fields[2])
+	}
+	if size < 0 {
+		return d, status.Errorf(codes.InvalidArgument, "invalid resource name, fourth component (size) must be non-negative: %d", size)
+	}
+	d, err = digest.New(hash, size)
+	if err != nil {
+		return d, status.Errorf(codes.InvalidArgument, "invalid resource name, third component is not a valid digest: %s => %s", hash, err)
+	}
+
+	return d, nil
+}
+
+// parseWriteResource parses a WriteRequest.ResourceName and returns the validated Digest and upload ID.
+// The resource name must be of the form: {instance_name}/uploads/{uuid}/blobs/{hash}/{size}[/{optionalmetadata}]
+// The upload ID is returned in the string form produced by the uuid package.
+// Any malformed name is reported as an InvalidArgument status error.
+func parseWriteResource(name string) (d digest.Digest, u string, err error) {
+	fields := strings.Split(name, "/")
+
+	// Strip any parts before "uploads", as they'll belong to an instance name.
+	for i := range fields {
+		if fields[i] == "uploads" {
+			fields = fields[i:]
+			break
+		}
+	}
+
+	if len(fields) < 5 || fields[0] != "uploads" || fields[2] != "blobs" {
+		return d, u, status.Errorf(codes.InvalidArgument, "invalid resource name, must follow format {instance_name}/uploads/{uuid}/blobs/{hash}/{size}[/{optionalmetadata}]: %s", name)
+	}
+
+	// Named id (not uuid) so that the imported uuid package is not shadowed.
+	id, err := uuid.Parse(fields[1])
+	if err != nil {
+		return d, u, status.Errorf(codes.InvalidArgument, "invalid resource name, second component is not a UUID: %s", fields[1])
+	}
+	u = id.String()
+
+	hash := fields[3]
+
+	size, err := strconv.ParseInt(fields[4], 10, 64)
+	if err != nil {
+		return d, u, status.Errorf(codes.InvalidArgument, "invalid resource name, fifth component (size) must be an integer: %s", fields[4])
+	}
+	if size < 0 {
+		return d, u, status.Errorf(codes.InvalidArgument, "invalid resource name, fifth component (size) must be non-negative: %d", size)
+	}
+	d, err = digest.New(hash, size)
+	if err != nil {
+		return d, u, status.Errorf(codes.InvalidArgument, "invalid resource name, fourth component is not a valid digest: %s => %s", hash, err)
+	}
+
+	return d, u, nil
+}
+
+// Read implements the ByteStream.Read RPC.
+// It delegates to read and logs the outcome.
+func (s *Service) Read(request *bytestream.ReadRequest, server bytestream.ByteStream_ReadServer) error {
+	err := s.read(request, server)
+	if err != nil {
+		log.Printf("🚨 Read(%v) => Error: %v", request.ResourceName, err)
+	} else {
+		log.Printf("✅ Read(%v) => OK", request.ResourceName)
+	}
+	return err
+}
+
+// read streams the requested blob to the client in chunks of at most maxChunkSize bytes.
+// Invalid offsets are reported as OutOfRange, a negative limit as InvalidArgument,
+// a blob that is not in the CAS as NotFound and any other failure as Internal.
+func (s *Service) read(request *bytestream.ReadRequest, server bytestream.ByteStream_ReadServer) error {
+	d, err := parseReadResource(request.ResourceName)
+	if err != nil {
+		return err
+	}
+
+	// A `read_offset` that is negative or greater than the size of the resource
+	// will cause an `OUT_OF_RANGE` error.
+	if request.ReadOffset < 0 {
+		return status.Error(codes.OutOfRange, "offset is negative")
+	}
+	if request.ReadOffset > d.Size {
+		return status.Error(codes.OutOfRange, "offset is greater than the size of the file")
+	}
+
+	// A `read_limit` of zero means "no limit", a negative one is an error.
+	if request.ReadLimit < 0 {
+		return status.Error(codes.InvalidArgument, "limit is negative")
+	}
+
+	// Prepare a buffer to read the file into.
+	bufSize := maxChunkSize
+	if request.ReadLimit > 0 && request.ReadLimit < bufSize {
+		bufSize = request.ReadLimit
+	}
+	buf := make([]byte, bufSize)
+
+	// Open the file and seek to the offset.
+	f, err := s.cas.Open(d, request.ReadOffset, request.ReadLimit)
+	if err != nil {
+		// A blob that is missing from the CAS is not a server fault: report it as
+		// NOT_FOUND so that clients can tell it apart from a real error.
+		if errors.Is(err, os.ErrNotExist) {
+			return status.Errorf(codes.NotFound, "blob %s not found in CAS", d)
+		}
+		return status.Errorf(codes.Internal, "failed to open file: %v", err)
+	}
+	defer f.Close() // OK to ignore error here, since we're only reading.
+
+	// Send the requested data to the client in chunks.
+	for {
+		n, err := f.Read(buf)
+		// Send whatever was read before looking at err, because a Read may
+		// return data together with io.EOF.
+		if n > 0 {
+			if err := server.Send(&bytestream.ReadResponse{
+				Data: buf[:n],
+			}); err != nil {
+				return status.Errorf(codes.Internal, "failed to send data to client: %v", err)
+			}
+		}
+		if err == io.EOF {
+			break
+		}
+		if err != nil {
+			return status.Errorf(codes.Internal, "failed to read data from file: %v", err)
+		}
+	}
+
+	return nil
+}
+
+// Write implements the ByteStream.Write RPC.
+// It delegates to write and logs the outcome together with the resource name.
+func (s *Service) Write(server bytestream.ByteStream_WriteServer) error {
+	resourceName, err := s.write(server)
+	if err != nil {
+		log.Printf("🚨 Write(%v) => Error: %v", resourceName, err)
+	} else {
+		log.Printf("✅ Write(%v) => OK", resourceName)
+	}
+	return err
+}
+
+// write receives an upload from the client, stores it in a temporary file inside
+// the upload directory and moves that file into the CAS once the client has
+// finished writing and the received data matches the digest in the resource name.
+// It returns the resource name of the upload (empty if the client never sent one)
+// and a gRPC status error on failure. The temporary file is removed on every
+// path that does not end with the file being moved into the CAS.
+func (s *Service) write(server bytestream.ByteStream_WriteServer) (resource string, err error) {
+	expectedDigest := digest.Empty
+	ourHash := sha256.New()
+	var committedSize int64
+	finishedWriting := false
+	var tempFile *os.File
+	var tempPath string
+	// ownsTemp is true while this call has created tempPath and the file has not
+	// been moved into the CAS. It is tracked separately from tempFile, because
+	// tempFile is reset to nil once the client sets finish_write, while the file
+	// on disk still has to be deleted if the upload is rejected afterwards.
+	ownsTemp := false
+	defer func() {
+		if tempFile != nil {
+			if err := tempFile.Close(); err != nil {
+				log.Printf("could not close temporary file %q: %v", tempPath, err)
+			}
+		}
+		if ownsTemp {
+			if err := os.Remove(tempPath); err != nil && !os.IsNotExist(err) {
+				log.Printf("could not delete temporary file %q: %v", tempPath, err)
+			}
+		}
+	}()
+
+	for {
+		// Receive a request from the client.
+		request, err := server.Recv()
+		if err == io.EOF {
+			// If the client closed the connection without ever sending a request, return an error.
+			if resource == "" {
+				return resource, status.Error(codes.InvalidArgument, "no resource name provided")
+			}
+
+			// Check that the client set "finish_write" to true.
+			if !finishedWriting {
+				return resource, status.Error(codes.InvalidArgument, "upload finished without finish_write set")
+			}
+
+			// Check that the digests (= hash and size) match.
+			d := digest.Digest{Hash: hex.EncodeToString(ourHash.Sum(nil)), Size: committedSize}
+			if d != expectedDigest {
+				return resource, status.Errorf(codes.InvalidArgument, "computed digest %v did not match expected digest %v", d, expectedDigest)
+			}
+
+			// Move the temporary file to the CAS.
+			if err := s.cas.Adopt(expectedDigest, tempPath); err != nil {
+				return resource, status.Errorf(codes.Internal, "failed to move file into CAS: %v", err)
+			}
+			// The CAS has taken over the file, so it must not be deleted anymore.
+			ownsTemp = false
+
+			// Send the response to the client.
+			if err := server.SendAndClose(&bytestream.WriteResponse{
+				CommittedSize: committedSize,
+			}); err != nil {
+				return resource, status.Errorf(codes.Internal, "failed to send response to client: %v", err)
+			}
+
+			// Yay, we're done!
+			return resource, nil
+		} else if err != nil {
+			return resource, status.Errorf(codes.Internal, "failed to receive request from client: %v", err)
+		}
+
+		// If the resource name is empty, this is the first request from the client.
+		if resource == "" {
+			if request.ResourceName == "" {
+				return resource, status.Errorf(codes.InvalidArgument, "must set resource name on first request")
+			}
+			resource = request.ResourceName
+			// Named uploadID (not uuid) so that the imported uuid package is not shadowed.
+			var uploadID string
+			expectedDigest, uploadID, err = parseWriteResource(request.ResourceName)
+			if err != nil {
+				return resource, err
+			}
+			tempPath = filepath.Join(s.uploadDir, uploadID)
+		} else {
+			// Ensure that the resource name is either not set, or the same as the first request.
+			if request.ResourceName != "" && request.ResourceName != resource {
+				return resource, status.Errorf(codes.InvalidArgument, "resource name changed (%v => %v)", resource, request.ResourceName)
+			}
+		}
+
+		if finishedWriting {
+			return resource, status.Error(codes.InvalidArgument, "cannot write more data after finish_write was true")
+		}
+
+		// If the resource was uploaded concurrently and already exists in our CAS, immediately return success.
+		if _, err := s.cas.Stat(expectedDigest); err == nil {
+			return resource, server.SendAndClose(&bytestream.WriteResponse{
+				CommittedSize: expectedDigest.Size,
+			})
+		}
+
+		// Create the file for the pending upload if this is the first write.
+		if tempFile == nil && !finishedWriting {
+			// O_EXCL makes this fail if a file for the same upload ID already exists,
+			// in which case the file belongs to another upload and must not be touched.
+			tempFile, err = os.OpenFile(tempPath, os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0644)
+			if err != nil {
+				if os.IsExist(err) {
+					return resource, status.Errorf(codes.InvalidArgument, "upload with same uuid already in progress")
+				}
+				return resource, status.Errorf(codes.Internal, "could not create temporary file for upload: %v", err)
+			}
+			ownsTemp = true
+		}
+
+		// Append the received data to the temporary file and hash it.
+		if _, err := tempFile.Write(request.Data); err != nil {
+			return resource, status.Errorf(codes.Internal, "failed to write data to temporary file: %v", err)
+		}
+		ourHash.Write(request.Data)
+		committedSize += int64(len(request.Data))
+
+		// If the file is already larger than the expected size, something is wrong - return an error.
+		if committedSize > expectedDigest.Size {
+			return resource, status.Errorf(codes.InvalidArgument, "received %d bytes, more than expected %d", committedSize, expectedDigest.Size)
+		}
+
+		if request.FinishWrite {
+			finishedWriting = true
+			err = tempFile.Close()
+			if err != nil {
+				return resource, status.Errorf(codes.Internal, "could not close temporary file: %v", err)
+			}
+			// We set tempFile to `nil` *after* checking for an error to give the defer handler a last
+			// chance to clean things up...
+			tempFile = nil
+		}
+	}
+}
+
+// QueryWriteStatus implements the ByteStream.QueryWriteStatus RPC.
+func (s *Service) QueryWriteStatus(ctx context.Context, request *bytestream.QueryWriteStatusRequest) (*bytestream.QueryWriteStatusResponse, error) { + response, err := s.queryWriteStatus(request) + if err != nil { + log.Printf("🚨 QueryWriteStatus(%v) failed: %s", request.ResourceName, err) + } else { + log.Printf("✅ QueryWriteStatus(%v) succeeded", request.ResourceName) + } + return response, err +} + +func (s *Service) queryWriteStatus(request *bytestream.QueryWriteStatusRequest) (*bytestream.QueryWriteStatusResponse, error) { + d, _, err := parseWriteResource(request.ResourceName) + if err != nil { + return nil, err + } + + // Check if the file exists in the CAS, if yes, the upload is complete. + if _, err := s.cas.Stat(d); err == nil { + return &bytestream.QueryWriteStatusResponse{ + CommittedSize: d.Size, + Complete: true, + }, nil + } + + // We don't support resuming uploads yet, so just always return that we don't have any data. + return &bytestream.QueryWriteStatusResponse{ + CommittedSize: 0, + Complete: false, + }, nil +} + +// FindMissingBlobs implements the ContentAddressableStorage.FindMissingBlobs RPC. +func (s *Service) FindMissingBlobs(ctx context.Context, request *remote.FindMissingBlobsRequest) (*remote.FindMissingBlobsResponse, error) { + response, err := s.findMissingBlobs(request) + if err != nil { + log.Printf("🚨 FindMissingBlobs(%d blobs) => Error: %v", len(request.BlobDigests), err) + } else { + log.Printf("✅ FindMissingBlobs(%d blobs) => OK (%d missing)", len(request.BlobDigests), len(response.MissingBlobDigests)) + } + return response, err +} + +func (s *Service) findMissingBlobs(request *remote.FindMissingBlobsRequest) (*remote.FindMissingBlobsResponse, error) { + // If the client explicitly specifies a DigestFunction, ensure that it's SHA256. 
+ if request.DigestFunction != remote.DigestFunction_UNKNOWN && request.DigestFunction != remote.DigestFunction_SHA256 { + return nil, status.Errorf(codes.InvalidArgument, "hash function %q is not supported", request.DigestFunction.String()) + } + + // Make a list that stores the missing blobs. We set the capacity so that we never have to reallocate. + missing := make([]*remote.Digest, 0, len(request.BlobDigests)) + + // For each blob in the list, check if it exists in the CAS. If not, add it to the list of missing blobs. + for _, d := range request.BlobDigests { + dg, err := digest.NewFromProto(d) + if err != nil { + return nil, status.Errorf(codes.InvalidArgument, "invalid digest: %v", err) + } + if _, err := s.cas.Stat(dg); err != nil { + missing = append(missing, d) + } + } + + // Return the list of missing blobs to the client. + return &remote.FindMissingBlobsResponse{ + MissingBlobDigests: missing, + }, nil +} + +// BatchUpdateBlobs implements the ContentAddressableStorage.BatchUpdateBlobs RPC. +func (s *Service) BatchUpdateBlobs(ctx context.Context, request *remote.BatchUpdateBlobsRequest) (*remote.BatchUpdateBlobsResponse, error) { + response, err := s.batchUploadBlobs(request) + if err != nil { + log.Printf("🚨 BatchUpdateBlobs(%v blobs) => Error: %v", len(request.Requests), err) + } else { + log.Printf("✅ BatchUpdateBlobs(%v blobs) => OK", len(request.Requests)) + } + return response, err +} + +func (s *Service) batchUploadBlobs(request *remote.BatchUpdateBlobsRequest) (*remote.BatchUpdateBlobsResponse, error) { + // If the client explicitly specifies a DigestFunction, ensure that it's SHA256. + if request.DigestFunction != remote.DigestFunction_UNKNOWN && request.DigestFunction != remote.DigestFunction_SHA256 { + return nil, status.Errorf(codes.InvalidArgument, "hash function %q is not supported", request.DigestFunction.String()) + } + + // Prepare a response that we can fill in. 
+ response := &remote.BatchUpdateBlobsResponse{ + Responses: make([]*remote.BatchUpdateBlobsResponse_Response, 0, len(request.Requests)), + } + + // For each blob in the list, check if it exists in the CAS. If not, write it to the CAS. + for _, blob := range request.Requests { + // Ensure that the client didn't send compressed data. + if blob.Compressor != remote.Compressor_IDENTITY { + return nil, status.Errorf(codes.InvalidArgument, "compressed data is not supported") + } + + // Parse the digest. + expectedDigest, err := digest.NewFromProto(blob.Digest) + if err != nil { + return nil, status.Errorf(codes.InvalidArgument, "invalid digest: %v", err) + } + + // Store the blob in our CAS. + actualDigest, err := s.cas.Put(blob.Data) + if err != nil { + return nil, status.Errorf(codes.Internal, "could not store blob in CAS: %v", err) + } + + // Check that the calculated digest matches the data. + if actualDigest != expectedDigest { + return nil, status.Errorf(codes.InvalidArgument, "digest does not match data") + } + + // Add the response to the list. + response.Responses = append(response.Responses, &remote.BatchUpdateBlobsResponse_Response{ + Digest: blob.Digest, + Status: status.New(codes.OK, "").Proto(), + }) + } + + // Return the response to the client. + return response, nil +} + +func (s *Service) BatchReadBlobs(ctx context.Context, request *remote.BatchReadBlobsRequest) (*remote.BatchReadBlobsResponse, error) { + response, err := s.batchReadBlobs(request) + if err != nil { + log.Printf("🚨 BatchReadBlobs(%v blobs) => Error: %v", len(request.Digests), err) + } else { + log.Printf("✅ BatchReadBlobs(%v blobs) => OK", len(request.Digests)) + } + return response, err +} + +func (s *Service) batchReadBlobs(request *remote.BatchReadBlobsRequest) (*remote.BatchReadBlobsResponse, error) { + // If the client explicitly specifies a DigestFunction, ensure that it's SHA256. 
+ if request.DigestFunction != remote.DigestFunction_UNKNOWN && request.DigestFunction != remote.DigestFunction_SHA256 { + return nil, status.Errorf(codes.InvalidArgument, "hash function %q is not supported", request.DigestFunction.String()) + } + + // Prepare a response that we can fill in. + response := &remote.BatchReadBlobsResponse{ + Responses: make([]*remote.BatchReadBlobsResponse_Response, 0, len(request.Digests)), + } + + // For each blob in the list, check if it exists in the CAS. If yes, read it from the CAS. + for _, d := range request.Digests { + // Parse the digest. + dg, err := digest.NewFromProto(d) + if err != nil { + return nil, status.Errorf(codes.InvalidArgument, "invalid digest: %v", err) + } + + // Read the blob from the CAS. + data, err := s.cas.Get(dg) + if err != nil { + if os.IsNotExist(err) { + // The blob doesn't exist. Add a response with an appropriate status code. + response.Responses = append(response.Responses, &remote.BatchReadBlobsResponse_Response{ + Digest: d, + Status: status.New(codes.NotFound, "").Proto(), + }) + continue + } else { + return nil, status.Errorf(codes.Internal, "failed to read blob: %v", err) + } + } else { + // The blob exists. Add a response with the data. + response.Responses = append(response.Responses, &remote.BatchReadBlobsResponse_Response{ + Digest: d, + Data: data, + Status: status.New(codes.OK, "").Proto(), + }) + } + } + + // Return the response to the client. 
+ return response, nil +} + +func (s *Service) GetTree(request *remote.GetTreeRequest, treeServer remote.ContentAddressableStorage_GetTreeServer) error { + if err := s.getTree(request, treeServer); err != nil { + log.Printf("🚨 GetTree(%v) => Error: %v", request.RootDigest, err) + return err + } else { + log.Printf("✅ GetTree(%v) => OK", request.RootDigest) + } + return nil +} + +func (s *Service) getTree(request *remote.GetTreeRequest, treeServer remote.ContentAddressableStorage_GetTreeServer) error { + // If the client explicitly specifies a DigestFunction, ensure that it's SHA256. + if request.DigestFunction != remote.DigestFunction_UNKNOWN && request.DigestFunction != remote.DigestFunction_SHA256 { + return status.Errorf(codes.InvalidArgument, "hash function %q is not supported", request.DigestFunction.String()) + } + + // Prepare a response that we can fill in. + response := &remote.GetTreeResponse{ + Directories: make([]*remote.Directory, 0), + } + + // Create a queue of directories to process and add the root directory. + dirQueue := []*remote.DirectoryNode{ + { + Digest: request.RootDigest, + }, + } + + // Iteratively process the directories. + for len(dirQueue) > 0 { + // Take a directoryNode from the queue. + directoryNode := dirQueue[0] + dirQueue = dirQueue[1:] + + // Parse the digest. + d, err := digest.NewFromProto(directoryNode.Digest) + if err != nil { + return status.Errorf(codes.InvalidArgument, "invalid digest: %v", err) + } + + // Get the blob for the directory message from the CAS. + directoryBlob, err := s.cas.Get(d) + if err != nil { + return status.Errorf(codes.NotFound, "directory not found: %v", err) + } + + // Unmarshal the directory message. + var directory *remote.Directory + if err := proto.Unmarshal(directoryBlob, directory); err != nil { + return status.Errorf(codes.Internal, "failed to unmarshal directory: %v", err) + } + + // Add the directory to the response. 
+ response.Directories = append(response.Directories, directory) + + // Add all subdirectory nodes to the queue. + dirQueue = append(dirQueue, directory.Directories...) + } + + // TODO: Add support for pagination? + + // Send the tree to the client. + return treeServer.Send(response) +} diff --git a/capabilities/service.go b/capabilities/service.go new file mode 100644 index 0000000..7bc6eb1 --- /dev/null +++ b/capabilities/service.go @@ -0,0 +1,82 @@ +// Copyright 2023 The Chromium Authors +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// Package capabilities implements the REAPI Capabilities service. +package capabilities + +import ( + "context" + "log" + + remote "github.com/bazelbuild/remote-apis/build/bazel/remote/execution/v2" + "github.com/bazelbuild/remote-apis/build/bazel/semver" + "google.golang.org/grpc" +) + +// Service implements the REAPI Capabilities service. +type Service struct { + remote.UnimplementedCapabilitiesServer +} + +// Register creates and registers a new Service with the given gRPC server. +func Register(s *grpc.Server) { + remote.RegisterCapabilitiesServer(s, NewService()) +} + +// NewService creates a new Service. +func NewService() *Service { + return &Service{} +} + +// GetCapabilities returns the capabilities of the server. +func (s *Service) GetCapabilities(ctx context.Context, request *remote.GetCapabilitiesRequest) (*remote.ServerCapabilities, error) { + response, err := s.getCapabilities(request) + if err != nil { + log.Printf("⚠️ GetCapabilities(%v) => Error: %v", request, err) + } else { + log.Printf("✅ GetCapabilities(%v) => OK", request) + } + return response, err +} + +func (s *Service) getCapabilities(request *remote.GetCapabilitiesRequest) (*remote.ServerCapabilities, error) { + // Return the capabilities. 
+ return &remote.ServerCapabilities{ + CacheCapabilities: &remote.CacheCapabilities{ + DigestFunctions: []remote.DigestFunction_Value{ + remote.DigestFunction_SHA256, + }, + ActionCacheUpdateCapabilities: &remote.ActionCacheUpdateCapabilities{ + UpdateEnabled: true, + }, + CachePriorityCapabilities: &remote.PriorityCapabilities{ + Priorities: []*remote.PriorityCapabilities_PriorityRange{ + { + MinPriority: 0, + MaxPriority: 0, + }, + }, + }, + MaxBatchTotalSizeBytes: 0, // no limit. + SymlinkAbsolutePathStrategy: remote.SymlinkAbsolutePathStrategy_DISALLOWED, // Same as RBE. + }, + ExecutionCapabilities: &remote.ExecutionCapabilities{ + DigestFunction: remote.DigestFunction_SHA256, + DigestFunctions: []remote.DigestFunction_Value{ + remote.DigestFunction_SHA256, + }, + ExecEnabled: true, + ExecutionPriorityCapabilities: &remote.PriorityCapabilities{ + Priorities: []*remote.PriorityCapabilities_PriorityRange{ + { + MinPriority: 0, + MaxPriority: 0, + }, + }, + }, + }, + LowApiVersion: &semver.SemVer{Major: 2, Minor: 0}, + HighApiVersion: &semver.SemVer{Major: 2, Minor: 0}, // RBE does not support higher versions, so we don't either. + }, nil +} diff --git a/execution/local.go b/execution/local.go new file mode 100644 index 0000000..32fe3b7 --- /dev/null +++ b/execution/local.go @@ -0,0 +1,498 @@ +// Copyright 2023 The Chromium Authors +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +package execution + +import ( + "bytes" + "errors" + "fmt" + "log" + "os" + "os/exec" + "path/filepath" + + "github.com/bazelbuild/remote-apis-sdks/go/pkg/digest" + remote "github.com/bazelbuild/remote-apis/build/bazel/remote/execution/v2" + "google.golang.org/genproto/googleapis/rpc/errdetails" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/proto" + + "github.com/philwo/kajiya/blobstore" +) + +// TODO: Make this configurable via a flag. 
+const fastCopy = true + +type Executor struct { + execDir string + cas *blobstore.ContentAddressableStorage +} + +func New(execDir string, cas *blobstore.ContentAddressableStorage) (*Executor, error) { + if execDir == "" { + return nil, fmt.Errorf("execDir must be set") + } + + if cas == nil { + return nil, fmt.Errorf("cas must be set") + } + + // Create the data directory if it doesn't exist. + if err := os.MkdirAll(execDir, 0755); err != nil { + return nil, fmt.Errorf("failed to create directory %q: %v", execDir, err) + } + + return &Executor{ + execDir: execDir, + cas: cas, + }, nil +} + +// Execute executes the given action and returns the result. +func (e *Executor) Execute(action *remote.Action) (*remote.ActionResult, error) { + var missingBlobs []digest.Digest + + // Get the command from the CAS. + cmd, err := e.getCommand(action.CommandDigest) + if err != nil { + if os.IsNotExist(err) { + missingBlobs = append(missingBlobs, digest.NewFromProtoUnvalidated(action.CommandDigest)) + } else { + return nil, err + } + } + + // Get the input root from the CAS. + inputRoot, err := e.getDirectory(action.InputRootDigest) + if err != nil { + if os.IsNotExist(err) { + missingBlobs = append(missingBlobs, digest.NewFromProtoUnvalidated(action.InputRootDigest)) + } else { + return nil, err + } + } + + // Build an execution directory for the action. If the input root is nil, it means that its + // digest was not found in the CAS, so we skip this part and just return an error next. + var execDir string + if inputRoot != nil { + execDir, err = os.MkdirTemp(e.execDir, "*") + defer e.deleteExecutionDirectory(execDir) + if err != nil { + return nil, fmt.Errorf("failed to create execution directory: %v", err) + } + mb, err := e.materializeDirectory(execDir, inputRoot) + if err != nil { + return nil, fmt.Errorf("failed to materialize input root: %v", err) + } + missingBlobs = append(missingBlobs, mb...) 
+ } + + // If there were any missing blobs, we fail early and return the list to the client. + if len(missingBlobs) > 0 { + return nil, e.formatMissingBlobsError(missingBlobs) + } + + // If a working directory was specified, verify that it exists. + workDir := execDir + if cmd.WorkingDirectory != "" { + if !filepath.IsLocal(cmd.WorkingDirectory) { + return nil, fmt.Errorf("working directory %q points outside of input root", cmd.WorkingDirectory) + } + workDir = filepath.Join(execDir, cmd.WorkingDirectory) + if err := os.MkdirAll(workDir, 0755); err != nil { + return nil, fmt.Errorf("could not create working directory: %v", err) + } + } + + // Create the directories required by all output files and directories. + outputPaths, err := e.createOutputPaths(cmd, workDir) + if err != nil { + return nil, err + } + + // Execute the command. + actionResult, err := e.executeCommand(execDir, cmd) + if err != nil { + return nil, fmt.Errorf("failed to execute command: %v", err) + } + + // Save stdout and stderr to the CAS and update their digests in the action result. + if err := e.saveStdOutErr(actionResult); err != nil { + return nil, err + } + + // Go through all output files and directories and upload them to the CAS. + for _, outputPath := range outputPaths { + joinedPath := filepath.Join(workDir, outputPath) + fi, err := os.Stat(joinedPath) + if err != nil { + if os.IsNotExist(err) { + // Ignore non-existing output files. + continue + } + return nil, fmt.Errorf("failed to stat output path %q: %v", outputPath, err) + } + if fi.IsDir() { + // Upload the directory to the CAS. 
+ dirs, err := e.buildMerkleTree(joinedPath) + if err != nil { + return nil, fmt.Errorf("failed to build merkle tree for %q: %v", outputPath, err) + } + + tree := remote.Tree{ + Root: dirs[0], + } + if len(dirs) > 1 { + tree.Children = dirs[1:] + } + treeBytes, err := proto.Marshal(&tree) + if err != nil { + return nil, fmt.Errorf("failed to marshal tree: %v", err) + } + d, err := e.cas.Put(treeBytes) + if err != nil { + return nil, fmt.Errorf("failed to upload tree to CAS: %v", err) + } + + actionResult.OutputDirectories = append(actionResult.OutputDirectories, &remote.OutputDirectory{ + Path: outputPath, + TreeDigest: d.ToProto(), + IsTopologicallySorted: false, + }) + } else { + // Upload the file to the CAS. + d, err := digest.NewFromFile(joinedPath) + if err != nil { + return nil, fmt.Errorf("failed to compute digest of file %q: %v", outputPath, err) + } + if err := e.cas.Adopt(d, joinedPath); err != nil { + return nil, fmt.Errorf("failed to upload file %q to CAS: %v", outputPath, err) + } + + actionResult.OutputFiles = append(actionResult.OutputFiles, &remote.OutputFile{ + Path: outputPath, + Digest: d.ToProto(), + IsExecutable: fi.Mode()&0111 != 0, + }) + } + } + + return actionResult, nil +} + +// createOutputPaths creates the directories required by all output files and directories. +// It transforms and returns the list of output paths so that they're relative to our current working directory. +func (e *Executor) createOutputPaths(cmd *remote.Command, workDir string) (outputPaths []string, err error) { + if cmd.OutputPaths != nil { + // REAPI v2.1+ + outputPaths = cmd.OutputPaths + } else { + // REAPI v2.0 + outputPaths = make([]string, 0, len(cmd.OutputFiles)+len(cmd.OutputDirectories)) + outputPaths = append(outputPaths, cmd.OutputFiles...) + outputPaths = append(outputPaths, cmd.OutputDirectories...) 
+ } + for _, outputPath := range outputPaths { + // We need to create the parent directories of the output path, because the command + // may not create them itself. + if err := os.MkdirAll(filepath.Join(workDir, filepath.Dir(outputPath)), 0755); err != nil { + return nil, fmt.Errorf("failed to create parent directories for %q: %v", outputPath, err) + } + } + return outputPaths, nil +} + +// saveStdOutErr saves stdout and stderr to the CAS and returns the updated action result. +func (e *Executor) saveStdOutErr(actionResult *remote.ActionResult) error { + d, err := e.cas.Put(actionResult.StdoutRaw) + if err != nil { + return status.Errorf(codes.Internal, "failed to put stdout into CAS: %v", err) + } + actionResult.StdoutDigest = d.ToProto() + + d, err = e.cas.Put(actionResult.StderrRaw) + if err != nil { + return status.Errorf(codes.Internal, "failed to put stderr into CAS: %v", err) + } + actionResult.StderrDigest = d.ToProto() + + // Servers are not required to inline stdout and stderr, so we just set them to nil. + // The client can just fetch them from the CAS if it needs them. 
+ actionResult.StdoutRaw = nil + actionResult.StderrRaw = nil + + return nil +} + +func (e *Executor) getDirectory(d *remote.Digest) (*remote.Directory, error) { + dirDigest, err := digest.NewFromProto(d) + if err != nil { + return nil, fmt.Errorf("failed to parse directory digest: %v", err) + } + dirBytes, err := e.cas.Get(dirDigest) + if err != nil { + return nil, fmt.Errorf("failed to get directory from CAS: %v", err) + } + dir := &remote.Directory{} + if err := proto.Unmarshal(dirBytes, dir); err != nil { + return nil, fmt.Errorf("failed to unmarshal directory: %v", err) + } + return dir, nil +} + +func (e *Executor) getCommand(d *remote.Digest) (*remote.Command, error) { + cmdDigest, err := digest.NewFromProto(d) + if err != nil { + return nil, fmt.Errorf("failed to parse command digest: %v", err) + } + cmdBytes, err := e.cas.Get(cmdDigest) + if err != nil { + return nil, fmt.Errorf("failed to get command from CAS: %v", err) + } + cmd := &remote.Command{} + if err := proto.Unmarshal(cmdBytes, cmd); err != nil { + return nil, fmt.Errorf("failed to unmarshal command: %v", err) + } + return cmd, nil +} + +// materializeDirectory recursively materializes the given directory in the +// local filesystem. The directory itself is created at the given path, and +// all files and subdirectories are created under that path. +func (e *Executor) materializeDirectory(path string, d *remote.Directory) (missingBlobs []digest.Digest, err error) { + // First, materialize all the input files in the directory. + for _, fileNode := range d.Files { + filePath := filepath.Join(path, fileNode.Name) + err = e.materializeFile(filePath, fileNode) + if err != nil { + if os.IsNotExist(err) { + missingBlobs = append(missingBlobs, digest.NewFromProtoUnvalidated(fileNode.Digest)) + continue + } + return nil, fmt.Errorf("failed to materialize file: %v", err) + } + } + + // Next, materialize all the subdirectories. 
+ for _, sdNode := range d.Directories { + sdPath := filepath.Join(path, sdNode.Name) + err = os.Mkdir(sdPath, 0755) + if err != nil { + return nil, fmt.Errorf("failed to create subdirectory: %v", err) + } + + sd, err := e.getDirectory(sdNode.Digest) + if err != nil { + if os.IsNotExist(err) { + missingBlobs = append(missingBlobs, digest.NewFromProtoUnvalidated(sdNode.Digest)) + continue + } + return nil, fmt.Errorf("failed to get subdirectory: %v", err) + } + + sdMissingBlobs, err := e.materializeDirectory(sdPath, sd) + missingBlobs = append(missingBlobs, sdMissingBlobs...) + if err != nil { + return nil, fmt.Errorf("failed to materialize subdirectory: %v", err) + } + } + + // Finally, set the directory properties. We have to do this after the files + // have been materialized, because otherwise the mtime of the directory would + // be updated to the current time. + if d.NodeProperties != nil { + if d.NodeProperties.Mtime != nil { + time := d.NodeProperties.Mtime.AsTime() + if err := os.Chtimes(path, time, time); err != nil { + return nil, fmt.Errorf("failed to set mtime: %v", err) + } + } + + if d.NodeProperties.UnixMode != nil { + if err := os.Chmod(path, os.FileMode(d.NodeProperties.UnixMode.Value)); err != nil { + return nil, fmt.Errorf("failed to set mode: %v", err) + } + } + } + + return missingBlobs, nil +} + +// materializeFile downloads the given file from the CAS and writes it to the given path. +func (e *Executor) materializeFile(filePath string, fileNode *remote.FileNode) error { + fileDigest, err := digest.NewFromProto(fileNode.Digest) + if err != nil { + return fmt.Errorf("failed to parse file digest: %v", err) + } + + // Calculate the file permissions from all relevant fields. 
+ perm := os.FileMode(0644) + if fileNode.NodeProperties != nil && fileNode.NodeProperties.UnixMode != nil { + perm = os.FileMode(fileNode.NodeProperties.UnixMode.Value) + } + if fileNode.IsExecutable { + perm |= 0111 + } + + if fastCopy { + // Fast copy is enabled, so we just create a hard link to the file in the CAS. + err := e.cas.LinkTo(fileDigest, filePath) + if err != nil { + return fmt.Errorf("failed to link to file in CAS: %v", err) + } + + err = os.Chmod(filePath, perm) + if err != nil { + return fmt.Errorf("failed to set mode: %v", err) + } + } else { + fileBytes, err := e.cas.Get(fileDigest) + if err != nil { + return fmt.Errorf("failed to get file from CAS: %v", err) + } + + err = os.WriteFile(filePath, fileBytes, perm) + if err != nil { + return fmt.Errorf("failed to write file: %v", err) + } + } + + if fileNode.NodeProperties != nil && fileNode.NodeProperties.Mtime != nil { + time := fileNode.NodeProperties.Mtime.AsTime() + if err := os.Chtimes(filePath, time, time); err != nil { + return fmt.Errorf("failed to set mtime: %v", err) + } + } + + return nil +} + +// formatMissingBlobsError formats a list of missing blobs as a gRPC "FailedPrecondition" error +// as described in the Remote Execution API. +func (e *Executor) formatMissingBlobsError(blobs []digest.Digest) error { + violations := make([]*errdetails.PreconditionFailure_Violation, 0, len(blobs)) + for _, b := range blobs { + violations = append(violations, &errdetails.PreconditionFailure_Violation{ + Type: "MISSING", + Subject: fmt.Sprintf("blobs/%s/%d", b.Hash, b.Size), + }) + } + + s, err := status.New(codes.FailedPrecondition, "missing blobs").WithDetails(&errdetails.PreconditionFailure{ + Violations: violations, + }) + if err != nil { + return fmt.Errorf("failed to create status: %v", err) + } + + return s.Err() +} + +// executeCommand runs cmd in the input root execDir, which must already have been prepared by the caller. 
+// If we were able to execute the command, a valid ActionResult will be returned and error is nil. +// This includes the case where we ran the command, and it exited with an exit code != 0. +// However, if something went wrong during preparation or while spawning the process, an error is returned. +func (e *Executor) executeCommand(execDir string, cmd *remote.Command) (*remote.ActionResult, error) { + if cmd.Platform != nil { + for _, prop := range cmd.Platform.Properties { + if prop.Name == "container-image" { + // TODO: Implement containerized execution for actions that ask to run inside a given container image. + } + } + } + + c := exec.Command(cmd.Arguments[0], cmd.Arguments[1:]...) + c.Dir = filepath.Join(execDir, cmd.WorkingDirectory) + + for _, env := range cmd.EnvironmentVariables { + c.Env = append(c.Env, fmt.Sprintf("%s=%s", env.Name, env.Value)) + } + + var stdout, stderr bytes.Buffer + c.Stdout = &stdout + c.Stderr = &stderr + + if err := c.Run(); err != nil { + // ExitError just means that the command returned a non-zero exit code. + // In that case we just set the ExitCode in the ActionResult to it. + // However, other errors mean that something went wrong, and we need to + // return them to the caller. + if exitErr := (&exec.ExitError{}); !errors.As(err, &exitErr) { + return nil, err + } + } + + return &remote.ActionResult{ + ExitCode: int32(c.ProcessState.ExitCode()), + StdoutRaw: stdout.Bytes(), + StderrRaw: stderr.Bytes(), + }, nil +} + +// addDirectoryToTree recursively walks through the given directory and adds itself, all files and +// subdirectories to the given Tree. 
+func (e *Executor) buildMerkleTree(path string) (dirs []*remote.Directory, err error) { + dir := &remote.Directory{} + + dirEntries, err := os.ReadDir(path) + if err != nil { + return nil, fmt.Errorf("failed to read directory: %v", err) + } + + for _, dirEntry := range dirEntries { + if dirEntry.IsDir() { + subDirs, err := e.buildMerkleTree(filepath.Join(path, dirEntry.Name())) + if err != nil { + return nil, fmt.Errorf("failed to build merkle tree: %v", err) + } + d, err := digest.NewFromMessage(subDirs[0]) + if err != nil { + return nil, fmt.Errorf("failed to get digest: %v", err) + } + dir.Directories = append(dir.Directories, &remote.DirectoryNode{ + Name: dirEntry.Name(), + Digest: d.ToProto(), + }) + dirs = append(dirs, subDirs...) + } else { + d, err := digest.NewFromFile(filepath.Join(path, dirEntry.Name())) + if err != nil { + return nil, fmt.Errorf("failed to get digest: %v", err) + } + fi, err := dirEntry.Info() + if err != nil { + return nil, fmt.Errorf("failed to get file info: %v", err) + } + fileNode := &remote.FileNode{ + Name: dirEntry.Name(), + Digest: d.ToProto(), + IsExecutable: fi.Mode()&0111 != 0, + } + err = e.cas.Adopt(d, filepath.Join(path, dirEntry.Name())) + if err != nil { + return nil, fmt.Errorf("failed to move file into CAS: %v", err) + } + dir.Files = append(dir.Files, fileNode) + } + } + + dirBytes, err := proto.Marshal(dir) + if err != nil { + return nil, fmt.Errorf("failed to marshal directory: %v", err) + } + if _, err = e.cas.Put(dirBytes); err != nil { + return nil, err + } + + return append([]*remote.Directory{dir}, dirs...), nil +} + +func (e *Executor) deleteExecutionDirectory(dir string) { + if err := os.RemoveAll(dir); err != nil { + log.Printf("🚨 failed to remove execution directory: %v", err) + } +} diff --git a/execution/service.go b/execution/service.go new file mode 100644 index 0000000..d79f944 --- /dev/null +++ b/execution/service.go @@ -0,0 +1,200 @@ +// Copyright 2023 The Chromium Authors +// Use of this source 
code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+// Package execution implements the REAPI Execution service.
+package execution
+
+import (
+	"errors"
+	"fmt"
+	"log"
+	"os"
+	"time"
+
+	"cloud.google.com/go/longrunning/autogen/longrunningpb"
+	"github.com/bazelbuild/remote-apis-sdks/go/pkg/digest"
+	remote "github.com/bazelbuild/remote-apis/build/bazel/remote/execution/v2"
+	"google.golang.org/grpc"
+	"google.golang.org/grpc/codes"
+	"google.golang.org/grpc/status"
+	"google.golang.org/protobuf/proto"
+	"google.golang.org/protobuf/types/known/anypb"
+
+	"github.com/philwo/kajiya/actioncache"
+	"github.com/philwo/kajiya/blobstore"
+)
+
+// Service implements the REAPI Execution service.
+type Service struct {
+	remote.UnimplementedExecutionServer
+
+	executor    *Executor
+	actionCache *actioncache.ActionCache // may be nil, in which case results are not cached
+	cas         *blobstore.ContentAddressableStorage
+}
+
+// Register creates and registers a new Service with the given gRPC server.
+func Register(s *grpc.Server, executor *Executor, ac *actioncache.ActionCache, cas *blobstore.ContentAddressableStorage) error {
+	service, err := NewService(executor, ac, cas)
+	if err != nil {
+		return err
+	}
+	remote.RegisterExecutionServer(s, service)
+	return nil
+}
+
+// NewService creates a new Service. The executor and cas are required, while ac may be nil to run without an action cache.
+func NewService(executor *Executor, ac *actioncache.ActionCache, cas *blobstore.ContentAddressableStorage) (Service, error) {
+	if executor == nil {
+		return Service{}, fmt.Errorf("executor must be set")
+	}
+
+	if cas == nil {
+		return Service{}, fmt.Errorf("cas must be set")
+	}
+
+	return Service{
+		executor:    executor,
+		actionCache: ac,
+		cas:         cas,
+	}, nil
+}
+
+// Execute runs the requested action, sends the resulting operation to the client and logs the outcome.
+func (s Service) Execute(request *remote.ExecuteRequest, executeServer remote.Execution_ExecuteServer) error {
+	// Just for fun, measure how long the execution takes and log it.
+	start := time.Now()
+	err := s.execute(request, executeServer)
+	duration := time.Since(start)
+	if err != nil {
+		log.Printf("🚨 Execute(%v) => Error: %v", request.ActionDigest, err)
+	} else {
+		log.Printf("🎉 Execute(%v) => OK (%v)", request.ActionDigest, duration)
+	}
+	return err
+}
+
+// execute does the actual work for Execute: it validates the request, answers from the action cache if possible and otherwise runs the action.
+func (s Service) execute(request *remote.ExecuteRequest, executeServer remote.Execution_ExecuteServer) error {
+	// If the client explicitly specifies a DigestFunction, ensure that it's SHA256.
+	if request.DigestFunction != remote.DigestFunction_UNKNOWN && request.DigestFunction != remote.DigestFunction_SHA256 {
+		return status.Errorf(codes.InvalidArgument, "hash function %q is not supported", request.DigestFunction.String())
+	}
+
+	// Parse the action digest; a missing digest is rejected first, so that the parser never sees a nil proto.
+	if request.ActionDigest == nil {
+		return status.Error(codes.InvalidArgument, "action digest must be set")
+	}
+	actionDigest, err := digest.NewFromProto(request.ActionDigest)
+	if err != nil {
+		return status.Errorf(codes.InvalidArgument, "invalid action digest: %v", err)
+	}
+
+	// If we have an action cache, check if the action is already cached. A "not exist" error is just a cache miss.
+	if s.actionCache != nil && !request.SkipCacheLookup {
+		resp, err := s.checkActionCache(actionDigest)
+		if err != nil && !errors.Is(err, os.ErrNotExist) {
+			return status.Errorf(codes.Internal, "failed to check action cache: %v", err)
+		}
+		if resp != nil {
+			return executeServer.Send(resp)
+		}
+	}
+
+	// Fetch the Action from the CAS.
+	action, err := s.getAction(actionDigest)
+	if err != nil {
+		return err
+	}
+
+	// Execute the action.
+	actionResult, err := s.executor.Execute(action)
+	if err != nil {
+		return status.Errorf(codes.Internal, "failed to execute action: %v", err)
+	}
+
+	// Store the result in the action cache, unless the action opted out of caching or exited with a non-zero code.
+	if s.actionCache != nil && !action.DoNotCache && actionResult.ExitCode == 0 {
+		if err = s.actionCache.Put(actionDigest, actionResult); err != nil {
+			return status.Errorf(codes.Internal, "failed to put action into cache: %v", err)
+		}
+	}
+
+	// Send the result to the client.
+	op, err := s.wrapActionResult(actionDigest, actionResult, false)
+	if err != nil {
+		return err
+	}
+	if err = executeServer.Send(op); err != nil {
+		return status.Errorf(codes.Internal, "failed to send result to client: %v", err)
+	}
+
+	return nil
+}
+
+// checkActionCache looks up the result for d in the action cache and, on a hit, returns it wrapped as a cached, completed Operation.
+func (s Service) checkActionCache(d digest.Digest) (*longrunningpb.Operation, error) {
+	// Try to get the result from the cache.
+	actionResult, err := s.actionCache.Get(d)
+	if err != nil {
+		return nil, err
+	}
+
+	// Nice, cache hit! Let's wrap it up and send it to the client.
+	return s.wrapActionResult(d, actionResult, true)
+}
+
+// wrapActionResult packs r into a completed Operation named after d; cached is passed on as ExecuteResponse.CachedResult.
+func (s Service) wrapActionResult(d digest.Digest, r *remote.ActionResult, cached bool) (*longrunningpb.Operation, error) {
+	// Construct some metadata for the execution operation and wrap it in an Any.
+	md, err := anypb.New(&remote.ExecuteOperationMetadata{
+		Stage:        remote.ExecutionStage_COMPLETED,
+		ActionDigest: d.ToProto(),
+	})
+	if err != nil {
+		return nil, status.Errorf(codes.Internal, "failed to marshal metadata: %v", err)
+	}
+
+	// Put the action result into an Any-wrapped ExecuteResponse.
+	resp, err := anypb.New(&remote.ExecuteResponse{
+		Result:       r,
+		CachedResult: cached,
+	})
+	if err != nil {
+		return nil, status.Errorf(codes.Internal, "failed to marshal response: %v", err)
+	}
+
+	// Wrap all the protos in another proto and return it.
+	op := &longrunningpb.Operation{
+		Name:     d.String(),
+		Metadata: md,
+		Done:     true,
+		Result: &longrunningpb.Operation_Response{
+			Response: resp,
+		},
+	}
+	return op, nil
+}
+
+// getAction fetches the remote.Action with the given digest.Digest from our CAS.
+func (s Service) getAction(d digest.Digest) (*remote.Action, error) {
+	// Fetch the Action from the CAS.
+	actionBytes, err := s.cas.Get(d)
+	if err != nil {
+		return nil, status.Errorf(codes.Internal, "failed to get action from CAS: %v", err)
+	}
+
+	// Unmarshal the Action.
+	action := &remote.Action{}
+	if err = proto.Unmarshal(actionBytes, action); err != nil {
+		return nil, status.Errorf(codes.Internal, "failed to unmarshal action: %v", err)
+	}
+
+	return action, nil
+}
+
+// WaitExecution waits for the specified execution to complete.
+func (s Service) WaitExecution(request *remote.WaitExecutionRequest, executionServer remote.Execution_WaitExecutionServer) error {
+	return status.Error(codes.Unimplemented, "WaitExecution is not implemented")
+}
diff --git a/main.go b/main.go
new file mode 100644
index 0000000..839fc70
--- /dev/null
+++ b/main.go
@@ -0,0 +1,184 @@
+// Copyright 2023 The Chromium Authors
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file.
+
+// Kajiya is an RBE-compatible REAPI backend implementation used as a testing
+// server during development of Chromium's new build tooling. It is not meant
+// for production use, but can be very useful for local testing of any remote
+// execution related code.
+package main
+
+import (
+	"flag"
+	"log"
+	"net"
+	"os"
+	"os/signal"
+	"path/filepath"
+	"strings"
+	"syscall"
+
+	"google.golang.org/grpc"
+	"google.golang.org/grpc/reflection"
+
+	"github.com/philwo/kajiya/actioncache"
+	"github.com/philwo/kajiya/blobstore"
+	"github.com/philwo/kajiya/capabilities"
+	"github.com/philwo/kajiya/execution"
+)
+
+var (
+	dataDir         = flag.String("dir", getDefaultDataDir(), "the directory to store our data in")
+	listen          = flag.String("listen", "localhost:50051", "the address to listen on (e.g. localhost:50051 or unix:///tmp/kajiya.sock)")
+	enableCache     = flag.Bool("cache", true, "whether to enable the action cache service")
+	enableExecution = flag.Bool("execution", true, "whether to enable the execution service")
+)
+
+// getDefaultDataDir returns the "kajiya" directory below the user's cache directory,
+// or the empty string if the user's cache directory cannot be determined.
+func getDefaultDataDir() string {
+	cacheDir, err := os.UserCacheDir()
+	if err != nil {
+		return ""
+	}
+	return filepath.Join(cacheDir, "kajiya")
+}
+
+func main() {
+	flag.Parse()
+
+	// We need a data directory; the default is empty if no user cache directory could be determined.
+	if *dataDir == "" {
+		log.Fatalf("no data directory specified")
+	}
+
+	log.Printf("💾 using data directory: %v", *dataDir)
+
+	// Listen on the specified address.
+	network, addr := parseAddress(*listen)
+	listener, err := net.Listen(network, addr)
+	if err != nil {
+		log.Fatalf("failed to listen: %v", err)
+	}
+	log.Printf("🛜 listening on %v", listener.Addr())
+
+	// Create the gRPC server and register the services.
+	grpcServer, err := createServer(*dataDir)
+	if err != nil {
+		log.Fatalf("failed to create server: %v", err)
+	}
+
+	// Handle interrupts gracefully.
+	HandleInterrupt(func() {
+		grpcServer.GracefulStop()
+	})
+
+	// Start serving.
+	if err := grpcServer.Serve(listener); err != nil {
+		log.Fatalf("failed to serve: %v", err)
+	}
+}
+
+// parseAddress parses the listen address from the command line flag.
+// The address can be a TCP address (e.g. localhost:50051) or a Unix domain socket (e.g. unix:///tmp/kajiya.sock).
+// It returns the network ("tcp" or "unix") and the address with any "unix://" prefix removed.
+func parseAddress(addr string) (string, string) {
+	network := "tcp"
+	if strings.HasPrefix(addr, "unix://") {
+		network = "unix"
+		addr = strings.TrimPrefix(addr, "unix://")
+	}
+	return network, addr
+}
+
+// createServer creates a new gRPC server and registers the services.
+// All data is kept below dataDir; the action cache and execution services are only registered if enabled by their flags.
+func createServer(dataDir string) (*grpc.Server, error) {
+	s := grpc.NewServer()
+
+	capabilities.Register(s)
+	log.Printf("✅ capabilities service")
+
+	// Create a CAS backed by a local filesystem.
+	casDir := filepath.Join(dataDir, "cas")
+	cas, err := blobstore.New(casDir)
+	if err != nil {
+		return nil, err
+	}
+
+	// CAS service.
+	uploadDir := filepath.Join(casDir, "tmp")
+	err = blobstore.Register(s, cas, uploadDir)
+	if err != nil {
+		return nil, err
+	}
+	log.Printf("✅ content-addressable storage service")
+
+	// Action cache service.
+	var ac *actioncache.ActionCache
+	if *enableCache {
+		acDir := filepath.Join(dataDir, "ac")
+		ac, err = actioncache.New(acDir)
+		if err != nil {
+			return nil, err
+		}
+
+		err = actioncache.Register(s, ac, cas)
+		if err != nil {
+			return nil, err
+		}
+		log.Printf("✅ action cache service")
+	} else {
+		log.Printf("⚠️ action cache service disabled")
+	}
+
+	// Execution service.
+	if *enableExecution {
+		execDir := filepath.Join(dataDir, "exec")
+		executor, err := execution.New(execDir, cas)
+		if err != nil {
+			return nil, err
+		}
+
+		err = execution.Register(s, executor, ac, cas)
+		if err != nil {
+			return nil, err
+		}
+		log.Printf("✅ execution service")
+	} else {
+		log.Printf("⚠️ execution service disabled")
+	}
+
+	// Register the reflection service provided by gRPC.
+	reflection.Register(s)
+	log.Printf("✅ gRPC reflection service")
+
+	return s, nil
+}
+
+// HandleInterrupt calls 'fn' in a separate goroutine on SIGTERM or Ctrl+C.
+//
+// When SIGTERM or Ctrl+C comes for a second time, logs to stderr and kills
+// the process immediately via os.Exit(1).
+//
+// Returns a callback that can be used to remove the installed signal handlers.
+func HandleInterrupt(fn func()) (stopper func()) {
+	ch := make(chan os.Signal, 2)
+	signal.Notify(ch, os.Interrupt, syscall.SIGTERM)
+	go func() {
+		handled := false
+		for range ch {
+			if handled {
+				log.Printf("🚨 received second interrupt signal, exiting now")
+				os.Exit(1)
+			}
+			log.Printf("⚠️ received signal, attempting graceful shutdown")
+			handled = true
+			go fn()
+		}
+	}()
+	return func() {
+		signal.Stop(ch)
+		close(ch)
+	}
+}