diff --git a/.errcheck-excludes b/.errcheck-excludes new file mode 100644 index 00000000..077e44f1 --- /dev/null +++ b/.errcheck-excludes @@ -0,0 +1,5 @@ +// (*io.PipeWriter).CloseWithError "never overwrites the previous error if it +// exists and always returns nil". +// +// https://golang.org/pkg/io/#PipeWriter.CloseWithError +(*io.PipeWriter).CloseWithError diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 00000000..754ebe3d --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,38 @@ +name: CI + +on: + push: + branches: [ master ] + pull_request: + branches: [ master ] + +jobs: + build: + strategy: + fail-fast: false + matrix: + go: [1.13, 1.14, 1.15] + os: [ubuntu-latest, macos-latest] + name: Build & Test + runs-on: ${{ matrix.os }} + steps: + - name: Set up Go ${{ matrix.go }} + uses: actions/setup-go@v2 + with: + go-version: ${{ matrix.go }} + - name: Check out + uses: actions/checkout@v2 + - name: Test + run: go test -tags=unit ./... + golangci: + name: Lint + runs-on: ubuntu-latest + steps: + - name: Check out + uses: actions/checkout@v2 + - name: golangci-lint + uses: golangci/golangci-lint-action@v2 + with: + version: v1.29 + only-new-issues: true + args: --timeout=5m diff --git a/.golangci.yml b/.golangci.yml new file mode 100644 index 00000000..2ddfb50a --- /dev/null +++ b/.golangci.yml @@ -0,0 +1,95 @@ +run: + concurrency: 4 + timeout: 1m + + # exit code when at least one issue was found + issues-exit-code: 1 + + # include test files + tests: true + + build-tags: + + # which dirs to skip: issues from them won't be reported; + skip-dirs: + + # enables skipping of default directories: + # vendor$, third_party$, testdata$, examples$, Godeps$, builtin$ + skip-dirs-use-default: true + + # which files to skip: they will be analyzed, but issues from them won't be + # reported. 
+ skip-files: + + # disallow multiple parallel golangci-lint instances + allow-parallel-runners: false + +output: + # colored-line-number|line-number|json|tab|checkstyle|code-climate + format: colored-line-number + + # print lines of code with issue + print-issued-lines: true + + # print linter name in the end of issue text + print-linter-name: true + + # make issues output unique by line + uniq-by-line: true + +linters-settings: + errcheck: + # do not report about not checking errors in type assertions: `a := + # b.(MyStruct)` + check-type-assertions: false + + # do not report about assignment of errors to blank identifier: `num, _ := + # strconv.Atoi(numStr)` + check-blank: false + + # path to a file containing a list of functions to exclude from checking + # see https://github.com/kisielk/errcheck#excluding-functions for details + exclude: .errcheck-excludes + + govet: + # report about shadowed variables + check-shadowing: true + + # settings per analyzer + settings: + # run `go tool vet help` to see all analyzers + printf: + # run `go tool vet help printf` to see available settings for `printf` + # analyzer + funcs: + - (github.com/grailbio/base/log).Fatal + - (github.com/grailbio/base/log).Output + - (github.com/grailbio/base/log).Outputf + - (github.com/grailbio/base/log).Panic + - (github.com/grailbio/base/log).Panicf + - (github.com/grailbio/base/log).Print + - (github.com/grailbio/base/log).Printf + + unused: + # do not report unused exported identifiers + check-exported: false + + misspell: + locale: US + +linters: + disable-all: true + fast: false + enable: + - deadcode + - goimports + - gosimple + - govet + - errcheck + - ineffassign + - misspell + - staticcheck + - structcheck + - typecheck + - unused + - varcheck diff --git a/README.md b/README.md index 44e6fb2e..341a7fc9 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,10 @@ The grailbio/base project includes all the packages used by many other grailbio Go packages: +- API documentation: 
[godoc.org/github.com/grailbio/base](https://godoc.org/github.com/grailbio/base) +- Issue tracker: [github.com/grailbio/base/issues](https://github.com/grailbio/base/issues) +- [![CI](https://github.com/grailbio/base/workflows/CI/badge.svg)](https://github.com/grailbio/base/actions?query=workflow%3ACI) + - [recordio](https://godoc.org/github.com/grailbio/base/recordio): encrypted and compressed record oriented files with indexing support - [file](https://godoc.org/github.com/grailbio/base/file): unified file API for the local file system and S3 - [digest](https://godoc.org/github.com/grailbio/base/digest): common in-memory and serialized representation of digests @@ -12,4 +16,9 @@ other grailbio Go packages: - [syncqueue](https://godoc.org/github.com/grailbio/base/syncqueue): various flavors of producer-consumer queues - [unsafe](https://godoc.org/github.com/grailbio/base/unsafe): conversion from []byte to string, etc. - [compress/libdeflate](https://godoc.org/github.com/grailbio/base/compress/libdeflate): efficient block compression/decompression +- [bitset](https://godoc.org/github.com/grailbio/base/bitset): []uintptr bitset support - [simd](https://godoc.org/github.com/grailbio/base/simd): fast operations on []byte +- [tsv](https://godoc.org/github.com/grailbio/base/tsv): simple and efficient TSV writer +- [cloud/spotadvisor](https://godoc.org/github.com/grailbio/base/cloud/spotadvisor): provides an interface for fetching and utilizing AWS Spot Advisor data +- [cloud/spotfeed](https://godoc.org/github.com/grailbio/base/cloud/spotfeed): provides interfaces for interacting with the AWS spot data feed format for files hosted on S3 + \ No newline at end of file diff --git a/admit/admit.go b/admit/admit.go new file mode 100644 index 00000000..f92ba0ef --- /dev/null +++ b/admit/admit.go @@ -0,0 +1,288 @@ +// Copyright 2018 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. 
+ +// Package admit contains utilities for admission control. +package admit + +import ( + "context" + "expvar" + "sync" + + "github.com/grailbio/base/log" + "github.com/grailbio/base/retry" + "github.com/grailbio/base/sync/ctxsync" +) + +var ( + admitLimit = expvar.NewMap("admit.limit") + admitUsed = expvar.NewMap("admit.used") +) + +// Policy implements the low level details of an admission control +// policy. Users typically use a utility function such as admit.Do +// or admit.Retry. +type Policy interface { + // Acquire acquires a number of tokens from the admission controller. + // Returns on success, or if the context was canceled. + // Acquire can also return with an error if the number of requested tokens + // exceeds the upper limit of available tokens. + Acquire(ctx context.Context, need int) error + + // Release a number of tokens to the admission controller, + // reporting whether the request was within the capacity limits. + Release(tokens int, ok bool) +} + +// Do calls f after being admitted by the controller. f's bool return value is +// passed on to the underlying policy upon Release, and the error is simply +// returned back to the caller as a convenience. +// If policy is nil, then this will simply call f. +func Do(ctx context.Context, policy Policy, tokens int, f func() (bool, error)) error { + if policy == nil { + _, err := f() + return err + } + if err := policy.Acquire(ctx, tokens); err != nil { + return err + } + var ( + ok bool + err error + ) + defer func() { policy.Release(tokens, ok) }() + ok, err = f() + return err +} + +// CapacityStatus is the feedback provided by the user to Retry about the underlying resource being managed by Policy. +type CapacityStatus int + +const ( + // Within means that the underlying resource is within capacity. + Within CapacityStatus = iota + // OverNoRetry means that the underlying resource is over capacity but no retry is needed. 
+ // This is useful in situations where a request using the resource succeeded, but there are + // signs of congestion (for example, in the form of high latency). + OverNoRetry + // OverNeedRetry means that the underlying resource is over capacity and a retry is needed. + // This is useful in situations where requests failed due to the underlying resource hitting capacity limits. + OverNeedRetry +) + +// RetryPolicy combines an admission controller with a retry policy. +type RetryPolicy interface { + Policy + retry.Policy +} + +// Retry calls f after being admitted by the Policy (implied by the given RetryPolicy). +// If f returns Within, true is passed to the underlying policy upon Release and false otherwise. +// If f returns OverNeedRetry, f will be retried as per the RetryPolicy (and the error returned by f is ignored), +// and if f can no longer be retried, the error returned by retry.Policy will be returned. +func Retry(ctx context.Context, policy RetryPolicy, tokens int, f func() (CapacityStatus, error)) error { + var err error + for retries := 0; ; retries++ { + var c CapacityStatus + err = Do(ctx, policy, tokens, func() (bool, error) { + var err error // nolint:govet + c, err = f() + return c == Within, err + }) + // Retry as per retry policy if attempt failed due to over capacity. + if c != OverNeedRetry { + break + } + if err = retry.Wait(ctx, policy, retries); err != nil { + break + } + log.Debug.Printf("admit.Retry: %v, retries=%d", err, retries) + } + return err +} + +const defaultLimitChangeRate = 0.1 + +// Adjust changes the limit by factor. +func adjust(limit int, factor float32) int { + return int(float32(limit) * (1 + factor)) +} + +func min(x, y int) int { + if x < y { + return x + } + return y +} + +func max(x, y int) int { + if x > y { + return x + } + return y +} + +type controller struct { + // limit, used are the current limit and current used tokens respectively. 
+ limit, used int + // low, high define the range within which the limit can be adjusted. + low, high int + mu sync.Mutex + cond *ctxsync.Cond + limitVar, usedVar expvar.Int +} + +type controllerWithRetry struct { + *controller + retry.Policy +} + +func newController(start, limit int) *controller { + c := &controller{limit: start, used: 0, low: start, high: limit} + c.cond = ctxsync.NewCond(&c.mu) + return c +} + +// Controller returns a Policy which starts with a concurrency limit of 'start' +// and can grow upto a maximum of 'limit' as long as errors aren't observed. +// A controller is not fair: tokens are not granted in FIFO order; +// rather, waiters are picked randomly to be granted new tokens. +func Controller(start, limit int) Policy { + return newController(start, limit) +} + +// ControllerWithRetry returns a RetryPolicy which starts with a concurrency +// limit of 'start' and can grow upto a maximum of 'limit' if no errors are seen. +// A controller is not fair: tokens are not granted in FIFO order; +// rather, waiters are picked randomly to be granted new tokens. +func ControllerWithRetry(start, limit int, retryPolicy retry.Policy) RetryPolicy { + return controllerWithRetry{controller: newController(start, limit), Policy: retryPolicy} +} + +// EnableVarExport enables the export of relevant vars useful for debugging/monitoring. +func EnableVarExport(policy Policy, name string) { + switch c := policy.(type) { + case *controller: + admitLimit.Set(name, &c.limitVar) + admitUsed.Set(name, &c.usedVar) + case *aimd: + admitLimit.Set(name, &c.limitVar) + admitUsed.Set(name, &c.usedVar) + } +} + +// Acquire acquires a number of tokens from the admission controller. +// Returns on success, or if the context was canceled. 
+func (c *controller) Acquire(ctx context.Context, need int) error { + c.mu.Lock() + defer c.mu.Unlock() + for { + // TODO(swami): should allow an increase only when the last release was ok + lim := min(adjust(c.limit, defaultLimitChangeRate), c.high) + have := lim - c.used + if need <= have || (need > lim && c.used == 0) { + c.used += need + c.usedVar.Set(int64(c.used)) + return nil + } + if err := c.cond.Wait(ctx); err != nil { + return err + } + } +} + +// Release releases a number of tokens to the admission controller, +// reporting whether the request was within the capacity limits. +func (c *controller) Release(tokens int, ok bool) { + c.mu.Lock() + defer c.mu.Unlock() + if ok { + if c.used > c.limit { + c.limit = min(c.used, c.high) + } + } else { + c.limit = max(c.low, adjust(c.limit, -defaultLimitChangeRate)) + } + c.used -= tokens + + c.limitVar.Set(int64(c.limit)) + c.usedVar.Set(int64(c.used)) + c.cond.Broadcast() +} + +type aimd struct { + // limit, used are the current limit and current used tokens respectively. + limit, used int + // min is the minimum limit. + min int + // decfactor is the factor by which tokens are reduced upon congestion. + decfactor float32 + + mu sync.Mutex + cond *ctxsync.Cond + limitVar, usedVar expvar.Int +} + +type aimdWithRetry struct { + *aimd + retry.Policy +} + +func newAimd(min int, decfactor float32) *aimd { + c := &aimd{min: min, limit: min, decfactor: decfactor} + c.cond = ctxsync.NewCond(&c.mu) + return c +} + +// AIMD returns a Policy which uses the Additive increase/multiplicative decrease +// algorithm for computing the amount of the concurrency to allow. +// AIMD is not fair: tokens are not granted in FIFO order; +// rather, waiters are picked randomly to be granted new tokens. 
+func AIMD(min int, decfactor float32) Policy { + return newAimd(min, decfactor) +} + +// AIMDWithRetry returns a RetryPolicy which uses the Additive increase/multiplicative decrease +// algorithm for computing the amount of the concurrency to allow. +// AIMDWithRetry is not fair: tokens are not granted in FIFO order; +// rather, waiters are picked randomly to be granted new tokens. +func AIMDWithRetry(min int, decfactor float32, retryPolicy retry.Policy) RetryPolicy { + return aimdWithRetry{aimd: newAimd(min, decfactor), Policy: retryPolicy} +} + +// Acquire acquires a number of tokens from the admission controller. +// Returns on success, or if the context was canceled. +func (c *aimd) Acquire(ctx context.Context, need int) error { + c.mu.Lock() + defer c.mu.Unlock() + for { + have := c.limit - c.used + if need <= have || (need > c.limit && c.used == 0) { + c.used += need + c.usedVar.Set(int64(c.used)) + return nil + } + if err := c.cond.Wait(ctx); err != nil { + return err + } + } +} + +// Release releases a number of tokens to the admission controller, +// reporting whether the request was within the capacity limits. +func (c *aimd) Release(tokens int, ok bool) { + c.mu.Lock() + defer c.mu.Unlock() + switch { + case !ok: + c.limit = max(c.min, adjust(c.limit, -c.decfactor)) + case ok && c.used == c.limit: + c.limit += 1 + } + c.used -= tokens + + c.limitVar.Set(int64(c.limit)) + c.usedVar.Set(int64(c.used)) + c.cond.Broadcast() +} diff --git a/admit/admit_test.go b/admit/admit_test.go new file mode 100644 index 00000000..43801cb8 --- /dev/null +++ b/admit/admit_test.go @@ -0,0 +1,262 @@ +// Copyright 2018 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache-2.0 +// license that can be found in the LICENSE file. 
+ +package admit + +import ( + "context" + "expvar" + "fmt" + "math/rand" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/grailbio/base/retry" + "github.com/grailbio/base/traverse" +) + +func checkState(t *testing.T, p Policy, limit, used int) { + t.Helper() + var gotl, gotu int + switch c := p.(type) { + case *controller: + gotl = c.limit + gotu = c.used + case *aimd: + gotl = c.limit + gotu = c.used + } + if gotu != used { + t.Errorf("c.used: got %d, want %d", gotu, used) + } + if gotl != limit { + t.Errorf("c.limit: got %d, want %d", gotl, limit) + } +} + +func checkVars(t *testing.T, key, max, used string) { + t.Helper() + if want, got := max, admitLimit.Get(key).String(); got != want { + t.Errorf("admitLimit got %s, want %s", got, want) + } + if want, got := used, admitUsed.Get(key).String(); got != want { + t.Errorf("admitUsed got %s, want %s", got, want) + } +} + +func getKeys(m *expvar.Map) map[string]bool { + keys := make(map[string]bool) + m.Do(func(kv expvar.KeyValue) { + keys[kv.Key] = true + }) + return keys +} + +func TestController(t *testing.T) { + c := newController(10, 15) + // use up 5. + if err := c.Acquire(context.Background(), 5); err != nil { + t.Fatal(err) + } + checkState(t, c, 10, 5) + // can go upto 6. + if err := c.Acquire(context.Background(), 6); err != nil { + t.Fatal(err) + } + // release and report capacity error. + c.Release(5, false) + checkState(t, c, 10, 6) + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond) + // 6 still in use and limit should now be 10, so can't acquire 6. 
+ if want, got := context.DeadlineExceeded, c.Acquire(ctx, 6); got != want { + t.Fatalf("got %v, want %v", got, want) + } + cancel() + if want, got := 0, getKeys(admitLimit); len(got) != want { + t.Fatalf("admitLimit got %v, want len %d", got, want) + } + if want, got := 0, getKeys(admitUsed); len(got) != want { + t.Fatalf("admitUsed got %v, want len %d", got, want) + } + EnableVarExport(c, "test") + c.Release(6, true) + checkState(t, c, 10, 0) + checkVars(t, "test", "10", "0") + // limit is still 10, but since none are used, should accommodate larger request. + if err := c.Acquire(context.Background(), 18); err != nil { + t.Fatal(err) + } + checkState(t, c, 10, 18) + checkVars(t, "test", "10", "18") + c.Release(17, true) + checkState(t, c, 15, 1) + checkVars(t, "test", "15", "1") + ctx, cancel = context.WithTimeout(context.Background(), time.Millisecond) + // 1 still in use and max is 15, so shouldn't accommodate larger request. + if want, got := context.DeadlineExceeded, c.Acquire(ctx, 18); got != want { + t.Fatalf("got %v, want %v", got, want) + } + cancel() + checkState(t, c, 15, 1) + checkVars(t, "test", "15", "1") + c.Release(1, true) + checkState(t, c, 15, 0) + checkVars(t, "test", "15", "0") +} + +func TestControllerConcurrently(t *testing.T) { + testPolicy(t, ControllerWithRetry(100, 1000, nil)) +} + +func TestAIMD(t *testing.T) { + c := newAimd(10, 0.2) + // use up 5. + if err := c.Acquire(context.Background(), 5); err != nil { + t.Fatal(err) + } + checkState(t, c, 10, 5) + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond) + // 5 in use and limit should still be 10, so can't acquire 6. + if want, got := context.DeadlineExceeded, c.Acquire(ctx, 6); got != want { + t.Fatalf("got %v, want %v", got, want) + } + cancel() + // release successfully (within capacity). 
+ EnableVarExport(c, "aimd") + c.Release(5, true) + checkState(t, c, 10, 0) + checkVars(t, "aimd", "10", "0") + + for i := 0; i < 10; i++ { + if err := c.Acquire(context.Background(), 1); err != nil { + t.Fatal(err) + } + } + checkState(t, c, 10, 10) + checkVars(t, "aimd", "10", "10") + for i := 1; i <= 5; i++ { + c.Release(i, true) + if err := c.Acquire(context.Background(), i+1); err != nil { + t.Fatal(err) + } + } + checkState(t, c, 15, 15) + checkVars(t, "aimd", "15", "15") + + c.Release(1, false) + checkState(t, c, 12, 14) + checkVars(t, "aimd", "12", "14") + + c.Release(1, false) + checkState(t, c, 10, 13) + checkVars(t, "aimd", "10", "13") + + ctx, cancel = context.WithTimeout(context.Background(), time.Millisecond) + // 13 still in use and limit should now be 10, so can't acquire 1. + if want, got := context.DeadlineExceeded, c.Acquire(ctx, 1); got != want { + t.Fatalf("got %v, want %v", got, want) + } + cancel() +} + +func TestAIMDConcurrently(t *testing.T) { + testPolicy(t, AIMDWithRetry(100, 0.25, nil)) +} + +func testPolicy(t *testing.T, p Policy) { + const ( + N = 100 + T = 100 + ) + var pending int32 + var begin sync.WaitGroup + begin.Add(N) + err := traverse.Each(N, func(i int) error { + begin.Done() + n := rand.Intn(T/10) + 1 + if err := p.Acquire(context.Background(), n); err != nil { + return err + } + if m := atomic.AddInt32(&pending, int32(n)); m > T { + return fmt.Errorf("too many tokens: %d > %d", m, T) + } + atomic.AddInt32(&pending, -int32(n)) + p.Release(n, (i > 10 && i < 20) || (i > 70 && i < 80)) + return nil + }) + if err != nil { + t.Fatal(err) + } +} + +func TestDo(t *testing.T) { + c := newController(100, 10000) + // Must satisfy even 150 tokens since none are used. 
+ if err := Do(context.Background(), c, 150, func() (bool, error) { return true, nil }); err != nil { + t.Fatal(err) + } + checkState(t, c, 150, 0) + // controller has 150 tokens, use 10 and report capacity error + if want, got := error(nil), Do(context.Background(), c, 10, func() (bool, error) { return false, nil }); got != want { + t.Fatalf("got %v, want %v", got, want) + } + checkState(t, c, 135, 0) + // controller has 135 tokens, use up 35... + c.Acquire(context.Background(), 35) + checkState(t, c, 135, 35) + // can go upto 1.1*135 = 148, so should timeout for 114. + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + if want, got := context.DeadlineExceeded, Do(ctx, c, 114, func() (bool, error) { return true, nil }); got != want { + t.Fatalf("got %v, want %v", got, want) + } + checkState(t, c, 135, 35) + // can go upto 1.1*135 = 148, so should succeed for 113. + if want, got := error(nil), Do(context.Background(), c, 113, func() (bool, error) { return true, nil }); got != want { + t.Fatalf("got %v, want %v", got, want) + } + checkState(t, c, 148, 35) + // can go upto 1.1*148 = 162, so should go upto 127. + if err := Do(context.Background(), c, 127, func() (bool, error) { return true, nil }); err != nil { + t.Fatal(err) + } +} + +func TestRetry(t *testing.T) { + const ( + N = 1000 + ) + c := ControllerWithRetry(200, 1000, retry.MaxRetries(retry.Backoff(100*time.Millisecond, time.Minute, 1.5), 5)) + var begin sync.WaitGroup + begin.Add(N) + err := traverse.Each(N, func(i int) error { + begin.Done() + begin.Wait() + randFunc := func() (CapacityStatus, error) { + // Out of every three requests, one will (5% of the time) report over capacity with a need to retry, + // and another (also 5% of the time) will report over capacity with no need to retry. + switch i % 3 { + case 0: + time.Sleep(time.Millisecond * time.Duration(20+rand.Intn(50))) + if rand.Intn(100) < 5 { // 5% of the time. 
+ return OverNeedRetry, nil + } + case 1: + time.Sleep(time.Millisecond * time.Duration(20+rand.Intn(50))) + if rand.Intn(100) < 5 { // 5% of the time. + return OverNoRetry, nil + } + } + time.Sleep(time.Millisecond * time.Duration(5+rand.Intn(20))) + return Within, nil + } + n := rand.Intn(20) + 1 + return Retry(context.Background(), c, n, randFunc) + }) + if err != nil { + t.Fatal(err) + } +} diff --git a/backgroundcontext/backgroundcontext.go b/backgroundcontext/backgroundcontext.go new file mode 100644 index 00000000..8fbb8fd0 --- /dev/null +++ b/backgroundcontext/backgroundcontext.go @@ -0,0 +1,78 @@ +// Copyright 2019 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache-2.0 +// license that can be found in the LICENSE file. + +// Package backgroundcontext manages the singleton v23 context. This +// package is not for general use. It is only for packages that (1) +// need to access a background context, and (2) are used both as a +// binary and as a shared library (e.g., in R). +package backgroundcontext + +import ( + "context" + "sync/atomic" + "unsafe" + + v23context "v.io/v23/context" +) + +var ptr unsafe.Pointer + +// Set sets the singleton global context. It should be called at most once, +// usually immediately after a process starts. Calling Set() multiple times will +// cause a panic. Thread safe. +func Set(ctx *v23context.T) { + if !atomic.CompareAndSwapPointer(&ptr, nil, unsafe.Pointer(ctx)) { + panic("backgroundcontext.Set called twice") + } +} + +// T returns the background context set by Set. It panics if Set has not been +// called yet. Thread safe. +func T() *v23context.T { + p := atomic.LoadPointer(&ptr) + if p == nil { + panic("backgroundcontext.Set not yet called") + } + return (*v23context.T)(p) +} + +// Get returns a background context: if a v23 context has been set, it is returned; +// otherwise the standard Go background context is returned. 
+func Get() context.Context { + if p := (*v23context.T)(atomic.LoadPointer(&ptr)); p != nil { + return p + } + return context.Background() +} + +type wrapped struct { + context.Context + v23ctx *v23context.T +} + +func (w *wrapped) Value(key interface{}) interface{} { + val := w.Context.Value(key) + if val != nil { + return val + } + return w.v23ctx.Value(key) +} + +// Wrap wraps the provided context, composing it with the defined +// background context, if any. This allows contexts to be wrapped +// so that v23 stubs can get the background context's RPC client. +// +// Cancelations are not forwarded: it is assumed that the set background +// context is never canceled. +// +// BUG: this is very complicated; but it seems required to make +// vanadium play nicely with contexts defined outside of its +// universe. +func Wrap(ctx context.Context) context.Context { + v23ctx := (*v23context.T)(atomic.LoadPointer(&ptr)) + if v23ctx == nil { + return ctx + } + return &wrapped{ctx, v23ctx} +} diff --git a/backgroundcontext/backgroundcontext_test.go b/backgroundcontext/backgroundcontext_test.go new file mode 100644 index 00000000..fa31f52e --- /dev/null +++ b/backgroundcontext/backgroundcontext_test.go @@ -0,0 +1,32 @@ +// Copyright 2019 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. + +//+build !unit + +package backgroundcontext_test + +import ( + "context" + "testing" + + "github.com/grailbio/base/backgroundcontext" + "github.com/grailbio/base/vcontext" + v23 "v.io/v23" + v23context "v.io/v23/context" +) + +func TestWrap(t *testing.T) { + // This sets the backgroundcontext. 
+ _ = vcontext.Background() + + ctx, cancel := context.WithCancel(context.Background()) + bgctx := backgroundcontext.Wrap(ctx) + if v23.GetClient(v23context.FromGoContext(bgctx)) == nil { + t.Fatal("no v23 client returned") + } + cancel() + if got, want := bgctx.Err(), context.Canceled; got != want { + t.Errorf("got %v, want %v", got, want) + } +} diff --git a/bitset/bitset.go b/bitset/bitset.go new file mode 100644 index 00000000..fb2177e5 --- /dev/null +++ b/bitset/bitset.go @@ -0,0 +1,189 @@ +// Copyright 2022 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache-2.0 +// license that can be found in the LICENSE file. + +// This is similar to github.com/willf/bitset , but with some extraneous +// abstraction removed. See also simd/count_amd64.go. +// +// ([]byte <-> []uintptr adapters will be added when needed.) + +package bitset + +import ( + "math/bits" +) + +// BitsPerWord is the number of bits in a machine word. +const BitsPerWord = 64 + +// Log2BitsPerWord is log_2(BitsPerWord). +const Log2BitsPerWord = uint(6) + +// Set sets the given bit in a []uintptr bitset. +func Set(data []uintptr, bitIdx int) { + // Unsigned division by a power-of-2 constant compiles to a right-shift, + // while signed does not due to negative nastiness. + data[uint(bitIdx)/BitsPerWord] |= 1 << (uint(bitIdx) % BitsPerWord) +} + +// Clear clears the given bit in a []uintptr bitset. +func Clear(data []uintptr, bitIdx int) { + wordIdx := uint(bitIdx) / BitsPerWord + data[wordIdx] = data[wordIdx] &^ (1 << (uint(bitIdx) % BitsPerWord)) +} + +// Test returns true iff the given bit is set. +func Test(data []uintptr, bitIdx int) bool { + return (data[uint(bitIdx)/BitsPerWord] & (1 << (uint(bitIdx) % BitsPerWord))) != 0 +} + +// SetInterval sets the bits at all positions in [startIdx, limitIdx) in a +// []uintptr bitset. 
+func SetInterval(data []uintptr, startIdx, limitIdx int) { + if startIdx >= limitIdx { + return + } + startWordIdx := startIdx >> Log2BitsPerWord + startBit := uintptr(1) << uint32(startIdx&(BitsPerWord-1)) + limitWordIdx := limitIdx >> Log2BitsPerWord + limitBit := uintptr(1) << uint32(limitIdx&(BitsPerWord-1)) + if startWordIdx == limitWordIdx { + // We can't fill all bits from startBit on in the first word, since the + // limit is also within this word. + data[startWordIdx] |= limitBit - startBit + return + } + // Fill all bits from startBit on in the first word. + data[startWordIdx] |= -startBit + // Fill all bits in intermediate words. + // (todo: ensure compiler doesn't insert pointless slice bounds-checks on + // every iteration) + for wordIdx := startWordIdx + 1; wordIdx < limitWordIdx; wordIdx++ { + data[wordIdx] = ^uintptr(0) + } + // Fill just the bottom bits in the last word, if necessary. + if limitBit != 1 { + data[limitWordIdx] |= limitBit - 1 + } +} + +// ClearInterval clears the bits at all positions in [startIdx, limitIdx) in a +// []uintptr bitset. +func ClearInterval(data []uintptr, startIdx, limitIdx int) { + if startIdx >= limitIdx { + return + } + startWordIdx := startIdx >> Log2BitsPerWord + startBit := uintptr(1) << uint32(startIdx&(BitsPerWord-1)) + limitWordIdx := limitIdx >> Log2BitsPerWord + limitBit := uintptr(1) << uint32(limitIdx&(BitsPerWord-1)) + if startWordIdx == limitWordIdx { + // We can't clear all bits from startBit on in the first word, since the + // limit is also within this word. + data[startWordIdx] &= ^(limitBit - startBit) + return + } + // Clear all bits from startBit on in the first word. + data[startWordIdx] &= startBit - 1 + // Clear all bits in intermediate words. + for wordIdx := startWordIdx + 1; wordIdx < limitWordIdx; wordIdx++ { + data[wordIdx] = 0 + } + // Clear just the bottom bits in the last word, if necessary. 
+ if limitBit != 1 { + data[limitWordIdx] &= -limitBit + } +} + +// NewClearBits creates a []uintptr bitset with capacity for at least nBit +// bits, and all bits clear. +func NewClearBits(nBit int) []uintptr { + nWord := (nBit + BitsPerWord - 1) / BitsPerWord + return make([]uintptr, nWord) +} + +// NewSetBits creates a []uintptr bitset with capacity for at least nBit bits, +// and all bits at positions [0, nBit) set. +func NewSetBits(nBit int) []uintptr { + data := NewClearBits(nBit) + SetInterval(data, 0, nBit) + return data +} + +// NonzeroWordScanner iterates over and clears the set bits in a bitset, with +// the somewhat unusual precondition that the number of nonzero words is known +// in advance. The 'BitsetScanner' name is being reserved for a scanner which +// expects the number of set bits to be known instead. +// +// Note that, when many bits are set, a more complicated double-loop based +// around a function like willf/bitset.NextSetMany() has ~40% less overhead (at +// least with Go 1.10 on a Mac), and you can do even better with manual +// inlining of the iteration logic. As a consequence, it shouldn't be used +// when the bit iteration/clearing process is actually the dominant +// computational cost (and neither should NextSetMany(), manual inlining is +// 2-6x better without much more code, see bitsetManualInlineSubtask() in +// bitset_test.go for an example). However, it's a good choice everywhere +// else, outperforming the other scanners I'm aware of with similar ease of +// use, and maybe a future Go version will inline it properly. +type NonzeroWordScanner struct { + // data is the original bitset. + data []uintptr + // bitIdxOffset is BitsPerWord times the current data[] array index. + bitIdxOffset int + // bitWord is bits[bitIdxOffset / BitsPerWord], with already-iterated-over + // bits cleared. + bitWord uintptr + // nNonzeroWord is the number of nonzero words remaining in data[]. 
+ nNonzeroWord int +} + +// NewNonzeroWordScanner returns a NonzeroWordScanner for the given bitset, +// along with the position of the first bit. (This interface has been chosen +// to make for loops with properly-scoped variables easy to write.) +// +// The bitset is expected to be nonempty; otherwise this will crash the program +// with an out-of-bounds slice access. Similarly, if nNonzeroWord is larger +// than the actual number of nonzero words, or initially <= 0, the standard for +// loop will crash the program. (If nNonzeroWord is smaller but >0, the last +// nonzero words will be ignored.) +func NewNonzeroWordScanner(data []uintptr, nNonzeroWord int) (NonzeroWordScanner, int) { + for wordIdx := 0; ; wordIdx++ { + bitWord := data[wordIdx] + if bitWord != 0 { + bitIdxOffset := wordIdx * BitsPerWord + return NonzeroWordScanner{ + data: data, + bitIdxOffset: bitIdxOffset, + bitWord: bitWord & (bitWord - 1), + nNonzeroWord: nNonzeroWord, + }, bits.TrailingZeros64(uint64(bitWord)) + bitIdxOffset + } + } +} + +// Next returns the position of the next set bit, or -1 if there aren't any. +func (s *NonzeroWordScanner) Next() int { + bitWord := s.bitWord + if bitWord == 0 { + wordIdx := int(uint(s.bitIdxOffset) / BitsPerWord) + s.data[wordIdx] = 0 + s.nNonzeroWord-- + if s.nNonzeroWord == 0 { + // All words with set bits are accounted for, we can exit early. + // This is deliberately == 0 instead of <= 0 since it'll only be less + // than zero if there's a bug in the caller. We want to crash with an + // out-of-bounds access in that case. 
+ return -1 + } + for { + wordIdx++ + bitWord = s.data[wordIdx] + if bitWord != 0 { + break + } + } + s.bitIdxOffset = wordIdx * BitsPerWord + } + s.bitWord = bitWord & (bitWord - 1) + return bits.TrailingZeros64(uint64(bitWord)) + s.bitIdxOffset +} diff --git a/bitset/bitset_test.go b/bitset/bitset_test.go new file mode 100644 index 00000000..0b7f2679 --- /dev/null +++ b/bitset/bitset_test.go @@ -0,0 +1,308 @@ +// Copyright 2018 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache-2.0 +// license that can be found in the LICENSE file. + +package bitset_test + +import ( + "math/bits" + "math/rand" + "runtime" + "testing" + + gbitset "github.com/grailbio/base/bitset" + "github.com/grailbio/testutil/expect" + "github.com/willf/bitset" +) + +func TestSetAndClearIntervals(t *testing.T) { + rand.Seed(1) + nTrialsPerNWord := 100 + for nWord := 1; nWord <= 9; nWord++ { + bs := make([]uintptr, nWord) + nBits := nWord * gbitset.BitsPerWord + expectedBits := make([]bool, nBits) + for trialIdx := 0; trialIdx < nTrialsPerNWord; trialIdx++ { + // We perform a bunch of random SetInterval and ClearInterval operations + // on a []uintptr bitset, use for-loops to update the simpler + // expectedBits slice to what we expect, and use gbitset.Test to verify + // semantic equivalence. 
+ startIdx := rand.Intn(nBits) + limitIdx := startIdx + rand.Intn(nBits-startIdx) + gbitset.SetInterval(bs, startIdx, limitIdx) + for i := startIdx; i < limitIdx; i++ { + expectedBits[i] = true + } + for i := 0; i < nBits; i++ { + expect.EQ(t, gbitset.Test(bs, i), expectedBits[i]) + } + startIdx = rand.Intn(nBits) + limitIdx = startIdx + rand.Intn(nBits-startIdx) + gbitset.ClearInterval(bs, startIdx, limitIdx) + for i := startIdx; i < limitIdx; i++ { + expectedBits[i] = false + } + for i := 0; i < nBits; i++ { + expect.EQ(t, gbitset.Test(bs, i), expectedBits[i]) + } + } + } +} + +/* +Initial benchmark results: + MacBook Pro (15-inch, 2016) + 2.7 GHz Intel Core i7, 16 GB 2133 MHz LPDDR3 + +Benchmark_NonzeroWordLowDensity1-8 5 318053789 ns/op +Benchmark_NonzeroWordLowDensity4-8 20 92268360 ns/op +Benchmark_NonzeroWordLowDensityMax-8 20 75435109 ns/op +Benchmark_NonzeroWordHighDensity1-8 5 338889681 ns/op +Benchmark_NonzeroWordHighDensity4-8 20 93980434 ns/op +Benchmark_NonzeroWordHighDensityMax-8 20 85158994 ns/op + +For comparison, using github.com/willf/bitset.NextSet(): +Benchmark_NonzeroWordLowDensity1-8 5 295363742 ns/op +Benchmark_NonzeroWordLowDensity4-8 20 78013901 ns/op +Benchmark_NonzeroWordLowDensityMax-8 20 73992701 ns/op +Benchmark_NonzeroWordHighDensity1-8 2 600711815 ns/op +Benchmark_NonzeroWordHighDensity4-8 10 156621467 ns/op +Benchmark_NonzeroWordHighDensityMax-8 10 109333530 ns/op + +github.com/willf/bitset.NextSetMany(): +Benchmark_NonzeroWordLowDensity1-8 3 362510428 ns/op +Benchmark_NonzeroWordLowDensity4-8 20 98390731 ns/op +Benchmark_NonzeroWordLowDensityMax-8 20 89888478 ns/op +Benchmark_NonzeroWordHighDensity1-8 10 202346572 ns/op +Benchmark_NonzeroWordHighDensity4-8 20 57818033 ns/op +Benchmark_NonzeroWordHighDensityMax-8 30 49601154 ns/op + +Manual inlining: +Benchmark_NonzeroWordLowDensity1-8 20 66941143 ns/op +Benchmark_NonzeroWordLowDensity4-8 100 17791558 ns/op +Benchmark_NonzeroWordLowDensityMax-8 100 17825100 ns/op 
+Benchmark_NonzeroWordHighDensity1-8 20 101415506 ns/op +Benchmark_NonzeroWordHighDensity4-8 50 27927527 ns/op +Benchmark_NonzeroWordHighDensityMax-8 50 23895500 ns/op +*/ + +func nonzeroWordSubtask(dst, src []uintptr, nIter int) int { + tot := 0 + nzwPop := 0 + for _, bitWord := range src { + if bitWord != 0 { + nzwPop++ + } + } + for iter := 0; iter < nIter; iter++ { + copy(dst, src) + for s, i := gbitset.NewNonzeroWordScanner(dst, nzwPop); i != -1; i = s.Next() { + tot += i + } + } + return tot +} + +func willfNextSetSubtask(dst, src []uintptr, nIter int) int { + nBits := uint(len(src) * gbitset.BitsPerWord) + bsetSrc := bitset.New(nBits) + for i := uint(0); i != nBits; i++ { + if gbitset.Test(src, int(i)) { + bsetSrc.Set(i) + } + } + bsetDst := bitset.New(nBits) + + tot := uint(0) + for iter := 0; iter < nIter; iter++ { + bsetSrc.Copy(bsetDst) + for i, e := bsetDst.NextSet(0); e; i, e = bsetDst.NextSet(i + 1) { + tot += i + } + bsetDst.ClearAll() + } + return int(tot) +} + +func willfNextSetManySubtask(dst, src []uintptr, nIter int) int { + nBits := uint(len(src) * gbitset.BitsPerWord) + bsetSrc := bitset.New(nBits) + for i := uint(0); i != nBits; i++ { + if gbitset.Test(src, int(i)) { + bsetSrc.Set(i) + } + } + bsetDst := bitset.New(nBits) + + tot := uint(0) + // tried other buffer sizes, 256 seems to be a sweet spot + var buffer [256]uint + for iter := 0; iter < nIter; iter++ { + bsetSrc.Copy(bsetDst) + for i, buf := bsetDst.NextSetMany(0, buffer[:]); len(buf) > 0; i, buf = bsetDst.NextSetMany(i+1, buf) { + for j := range buf { + tot += buf[j] + } + } + bsetDst.ClearAll() + } + return int(tot) +} + +func bitsetManualInlineSubtask(dst, src []uintptr, nIter int) int { + tot := 0 + nzwPop := 0 + for _, bitWord := range src { + if bitWord != 0 { + nzwPop++ + } + } + for iter := 0; iter < nIter; iter++ { + copy(dst, src) + nNonzeroWord := nzwPop + for i, bitWord := range dst { + if bitWord != 0 { + bitIdxOffset := i * gbitset.BitsPerWord + for { + tot += 
bits.TrailingZeros64(uint64(bitWord)) + bitIdxOffset + bitWord &= bitWord - 1 + if bitWord == 0 { + break + } + } + dst[i] = 0 + } + nNonzeroWord-- + if nNonzeroWord == 0 { + break + } + } + } + return tot +} + +func nonzeroWordSubtaskFuture(dst, src []uintptr, nIter int) chan int { + future := make(chan int) + // go func() { future <- nonzeroWordSubtask(dst, src, nIter) }() + // go func() { future <- willfNextSetSubtask(dst, src, nIter) }() + // go func() { future <- willfNextSetManySubtask(dst, src, nIter) }() + go func() { future <- bitsetManualInlineSubtask(dst, src, nIter) }() + return future +} + +func multiNonzeroWord(dsts, srcs [][]uintptr, cpus int, nJob int) { + sumFutures := make([]chan int, cpus) + shardSizeBase := nJob / cpus + shardRemainder := nJob - shardSizeBase*cpus + shardSizeP1 := shardSizeBase + 1 + var taskIdx int + for ; taskIdx < shardRemainder; taskIdx++ { + sumFutures[taskIdx] = nonzeroWordSubtaskFuture(dsts[taskIdx], srcs[taskIdx], shardSizeP1) + } + for ; taskIdx < cpus; taskIdx++ { + sumFutures[taskIdx] = nonzeroWordSubtaskFuture(dsts[taskIdx], srcs[taskIdx], shardSizeBase) + } + var sum int + for taskIdx = 0; taskIdx < cpus; taskIdx++ { + sum += <-sumFutures[taskIdx] + } +} + +func benchmarkNonzeroWord(cpus, nWord, spacing, nJob int, b *testing.B) { + if cpus > runtime.NumCPU() { + b.Skipf("only have %v cpus", runtime.NumCPU()) + } + + dstSlices := make([][]uintptr, cpus) + srcSlices := make([][]uintptr, cpus) + nBits := nWord * gbitset.BitsPerWord + for ii := range dstSlices { + // 7 extra capacity to prevent false sharing. 
+ newDst := make([]uintptr, nWord, nWord+7) + newSrc := make([]uintptr, nWord, nWord+7) + for i := spacing - 1; i < nBits; i += spacing { + gbitset.Set(newSrc, i) + } + dstSlices[ii] = newDst + srcSlices[ii] = newSrc + } + for i := 0; i < b.N; i++ { + multiNonzeroWord(dstSlices, srcSlices, cpus, nJob) + } +} + +func Benchmark_NonzeroWordLowDensity1(b *testing.B) { + benchmarkNonzeroWord(1, 16, 369, 9999999, b) +} + +func Benchmark_NonzeroWordLowDensity4(b *testing.B) { + benchmarkNonzeroWord(4, 16, 369, 9999999, b) +} + +func Benchmark_NonzeroWordLowDensityMax(b *testing.B) { + benchmarkNonzeroWord(runtime.NumCPU(), 16, 369, 9999999, b) +} + +func Benchmark_NonzeroWordHighDensity1(b *testing.B) { + benchmarkNonzeroWord(1, 16, 1, 99999, b) +} + +func Benchmark_NonzeroWordHighDensity4(b *testing.B) { + benchmarkNonzeroWord(4, 16, 1, 99999, b) +} + +func Benchmark_NonzeroWordHighDensityMax(b *testing.B) { + benchmarkNonzeroWord(runtime.NumCPU(), 16, 1, 99999, b) +} + +func naiveBitScanAdder(dst []uintptr) int { + nBits := len(dst) * gbitset.BitsPerWord + tot := 0 + for i := 0; i != nBits; i++ { + if gbitset.Test(dst, i) { + tot += i + } + } + return tot +} + +func TestNonzeroWord(t *testing.T) { + maxSize := 500 + nIter := 200 + srcArr := make([]uintptr, maxSize) + dstArr := make([]uintptr, maxSize) + for iter := 0; iter < nIter; iter++ { + sliceStart := rand.Intn(maxSize) + sliceEnd := sliceStart + rand.Intn(maxSize-sliceStart) + srcSlice := srcArr[sliceStart:sliceEnd] + dstSlice := dstArr[sliceStart:sliceEnd] + + for i := range srcSlice { + srcSlice[i] = uintptr(rand.Uint64()) + } + copy(dstSlice, srcSlice) + nzwPop := 0 + for _, bitWord := range dstSlice { + if bitWord != 0 { + nzwPop++ + } + } + if nzwPop == 0 { + continue + } + + tot1 := 0 + for s, i := gbitset.NewNonzeroWordScanner(dstSlice, nzwPop); i != -1; i = s.Next() { + tot1 += i + } + tot2 := naiveBitScanAdder(srcSlice) + if tot1 != tot2 { + t.Fatal("Mismatched bit-index sums.") + } + for _, bitWord := 
range dstSlice { + if bitWord != 0 { + t.Fatal("NonzeroWordScanner failed to clear all words.") + } + } + } +} diff --git a/bitset/doc.go b/bitset/doc.go new file mode 100644 index 00000000..d784fe85 --- /dev/null +++ b/bitset/doc.go @@ -0,0 +1,7 @@ +// Copyright 2018 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache-2.0 +// license that can be found in the LICENSE file. + +// Package bitset provides support for treating a []uintptr as a bitset. It's +// essentially a less-abstracted variant of github.com/willf/bitset. +package bitset diff --git a/cloud/awssession/provider.go b/cloud/awssession/provider.go index 2f29a652..cf6aa868 100644 --- a/cloud/awssession/provider.go +++ b/cloud/awssession/provider.go @@ -9,9 +9,9 @@ import ( "time" "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/grailbio/base/errors" "github.com/grailbio/base/security/ticket" "v.io/v23/context" - "v.io/x/lib/vlog" ) // Provider implements the aws/credentials.Provider interface using a GRAIL @@ -38,6 +38,9 @@ type Provider struct { // ExpiryWindow allows triggering a refresh before the AWS credentials // actually expire. 
ExpiryWindow time.Duration + + // Rationale indicates the reason for accessing a ticket + Rationale string } var _ credentials.Provider = (*Provider)(nil) @@ -52,7 +55,13 @@ func (p *Provider) Retrieve() (credentials.Value, error) { defer cancel() } var err error - p.Ticket, err = ticket.TicketServiceClient(p.TicketPath).Get(ctx) + if p.Rationale != "" { + p.Ticket, err = ticket.TicketServiceClient(p.TicketPath).GetWithArgs(ctx, map[string]string{ + ticket.ControlRationale.String(): p.Rationale, + }) + } else { + p.Ticket, err = ticket.TicketServiceClient(p.TicketPath).Get(ctx) + } if err != nil { return credentials.Value{}, err } @@ -64,7 +73,6 @@ func (p *Provider) Retrieve() (credentials.Value, error) { func (p *Provider) retrieve() (credentials.Value, error) { awsTicket, ok := p.Ticket.(ticket.TicketAwsTicket) if !ok { - vlog.Info("%q: bad ticket type %T, want %T", p.TicketPath, p.Ticket, awsTicket) return credentials.Value{}, fmt.Errorf("bad ticket type %T for %q, want %T", p.Ticket, p.TicketPath, awsTicket) } @@ -72,9 +80,8 @@ func (p *Provider) retrieve() (credentials.Value, error) { var err error p.Expiration, err = time.Parse(time.RFC3339, awsTicket.Value.AwsCredentials.Expiration) if err != nil { - vlog.Infof("%q: error parsing %q: %s", p.TicketPath, awsTicket.Value.AwsCredentials.Expiration, err) p.Ticket = nil - return credentials.Value{}, err + return credentials.Value{}, errors.E(err, fmt.Sprintf("%q: error parsing %q", p.TicketPath, awsTicket.Value.AwsCredentials.Expiration)) } } return credentials.Value{ @@ -88,16 +95,11 @@ func (p *Provider) retrieve() (credentials.Value, error) { // IsExpired implements the github.com/aws/aws-sdk-go/aws/credentials.Provider // interface. 
func (p *Provider) IsExpired() bool { - r := false + var r bool if p.Ticket == nil { r = true - } else if p.Expiration.IsZero() { - r = false - } else { + } else if !p.Expiration.IsZero() { r = time.Now().Add(p.ExpiryWindow).After(p.Expiration) } - if r { - vlog.VI(1).Infof("%q is expired", p.TicketPath) - } return r } diff --git a/cloud/awssession/session.go b/cloud/awssession/session.go index 81074f7a..913d844b 100644 --- a/cloud/awssession/session.go +++ b/cloud/awssession/session.go @@ -7,7 +7,6 @@ package awssession import ( - "os" "time" "github.com/aws/aws-sdk-go/aws" @@ -17,45 +16,27 @@ import ( ) const ( - region = "us-west-2" - grailTicketPath = "GRAIL_AWS_TICKET_PATH" - defaultTimeout = 10 * time.Second + region = "us-west-2" + defaultTimeout = 10 * time.Second ) -// NewWithTicket creates a AWS session using a GRAIL ticket. This is helper -// that uses inside a Provider with a timeout of 10 seconds. The region will be -// set to 'us-west-2' and can be overriden by passing an appropriate -// *aws.Config. +// NewWithTicket creates an AWS session using a GRAIL ticket. The returned +// session uses a Provider with a timeout of 10 seconds. The region will be set +// to 'us-west-2' and can be overridden by passing an appropriate *aws.Config. func NewWithTicket(ctx *context.T, ticketPath string, cfgs ...*aws.Config) (*session.Session, error) { + cfg := NewConfigWithTicket(ctx, ticketPath) + cfgs = append([]*aws.Config{cfg}, cfgs...) + return session.NewSession(cfgs...) +} + +// NewConfigWithTicket creates an AWS configuration using a GRAIL ticket. The +// returned configuration uses a Provider with a timeout of 10 seconds. The +// region will be set to 'us-west-2'. +func NewConfigWithTicket(ctx *context.T, ticketPath string) *aws.Config { creds := credentials.NewCredentials(&Provider{ Ctx: ctx, Timeout: defaultTimeout, TicketPath: ticketPath, }) - cfg := aws.NewConfig().WithCredentials(creds).WithRegion(region) - cfgs = append([]*aws.Config{cfg}, cfgs...) 
- return session.NewSession(cfgs...) -} - -// NewWithProvider creates a AWS session using a provider. The region will be -// set to 'us-west-2' and can be overriden by passing an appropriate -// *aws.Config. -func NewWithProvider(provider credentials.Provider, cfgs ...*aws.Config) (*session.Session, error) { - creds := credentials.NewCredentials(provider) - cfg := aws.NewConfig().WithCredentials(creds).WithRegion(region) - cfgs = append([]*aws.Config{cfg}, cfgs...) - return session.NewSession(cfgs...) -} - -// New creates a new AWS session using a GRAIL ticket indicated by the -// GRAIL_AWS_TICKET_PATH env variable if the variable is defined or by the -// default AWS session otherwise. The region will be set to 'us-west-2' and can -// be overriden by passing an appropriate *aws.Config. -func New(ctx *context.T, cfgs ...*aws.Config) (*session.Session, error) { - ticketPath := os.Getenv(grailTicketPath) - if ticketPath != "" { - return NewWithTicket(ctx, ticketPath, cfgs...) - } - cfgs = append([]*aws.Config{aws.NewConfig().WithRegion(region)}, cfgs...) - return session.NewSession(cfgs...) 
+ return aws.NewConfig().WithCredentials(creds).WithRegion(region) } diff --git a/cloud/ec2util/ec2util.go b/cloud/ec2util/ec2util.go index 2602cf97..a1b5b768 100644 --- a/cloud/ec2util/ec2util.go +++ b/cloud/ec2util/ec2util.go @@ -22,8 +22,8 @@ import ( "time" "github.com/aws/aws-sdk-go/service/ec2" - "github.com/fullsailor/pkcs7" "v.io/x/lib/vlog" + "go.mozilla.org/pkcs7" ) type IdentityDocument struct { @@ -48,7 +48,7 @@ func init() { awsPublicCertificates = []*x509.Certificate{cert} } -func getInstance(output *ec2.DescribeInstancesOutput) (*ec2.Instance, error) { +func GetInstance(output *ec2.DescribeInstancesOutput) (*ec2.Instance, error) { if len(output.Reservations) != 1 { return nil, fmt.Errorf("unexpected number of Reservations (want 1): %+v", output) } @@ -66,47 +66,85 @@ func getInstance(output *ec2.DescribeInstancesOutput) (*ec2.Instance, error) { return instance, nil } -// GetIamInstanceProfileARN extracts the ARN from the output of a call to +// GetIamInstanceProfileARN extracts the ARN from the `instance` output of a call to // DescribeInstances. The ARN is expected to be non-empty. 
-func GetIamInstanceProfileARN(output *ec2.DescribeInstancesOutput) (string, error) { - instance, err := getInstance(output) - if err != nil { - return "", err +func GetIamInstanceProfileARN(instance *ec2.Instance) (string, error) { + if instance == nil { + return "", fmt.Errorf("non-nil instance is required: %+v", instance) } if instance.IamInstanceProfile == nil { - return "", fmt.Errorf("non-nil IamInstanceProfile is required: %+v", output) + return "", fmt.Errorf("non-nil IamInstanceProfile is required: %+v", instance) } profile := instance.IamInstanceProfile if profile.Arn == nil { - return "", fmt.Errorf("non-nil Arn is required: %+v", output) + return "", fmt.Errorf("non-nil Arn is required: %+v", instance) } if len(*profile.Arn) == 0 { - return "", fmt.Errorf("non-empty Arn is required: %+v", output) + return "", fmt.Errorf("non-empty Arn is required: %+v", instance) } return *profile.Arn, nil } // GetPublicIPAddress extracts the public IP address from the output of a call -// to DescribeInstances. The response is expected to be non-empty. -func GetPublicIPAddress(output *ec2.DescribeInstancesOutput) (string, error) { - instance, err := getInstance(output) - if err != nil { - return "", err +// to DescribeInstances Instance. The response is expected to be non-empty if the +// instance has a public IP and empty ("") if the instance is private. 
+func GetPublicIPAddress(instance *ec2.Instance) (string, error) {
+	if instance == nil {
+		return "", fmt.Errorf("non-nil instance is required: %+v", instance)
 	}
-	if instance.PublicIpAddress == nil {
-		return "", fmt.Errorf("non-nil PublicIpAddress is required: %+v", output)
+	if instance.PublicIpAddress == nil || len(*instance.PublicIpAddress) == 0 {
+		return "", nil
 	}
-	if len(*instance.PublicIpAddress) == 0 {
-		return "", fmt.Errorf("non-empty PublicIpAddress is required: %+v", output)
+	return *instance.PublicIpAddress, nil
+}
+
+// GetPrivateIPAddress extracts the private IP address from the output of a call
+// to DescribeInstances Instance. The response is expected to be the first private IP
+// attached to the instance.
+// If the instance has no attached interfaces, the value is empty ("")
+func GetPrivateIPAddress(instance *ec2.Instance) (string, error) {
+	if instance == nil {
+		return "", fmt.Errorf("non-nil instance is required: %+v", instance)
 	}
-	return *instance.PublicIpAddress, nil
+	if instance.PrivateIpAddress == nil || len(*instance.PrivateIpAddress) == 0 {
+		return "", nil
+	}
+
+	return *instance.PrivateIpAddress, nil
+}
+
+// GetTags returns the Key/Value pair tags attached to the instance as a slice
+func GetTags(instance *ec2.Instance) ([]*ec2.Tag, error) {
+	if instance == nil {
+		return nil, fmt.Errorf("non-nil instance is required: %+v", instance)
+	}
+
+	if instance.Tags == nil || len(instance.Tags) == 0 {
+		return nil, nil
+	}
+
+	return instance.Tags, nil
+}
+
+// GetInstanceId returns the instanceID from the output of a call
+// to DescribeInstances Instance. 
+func GetInstanceId(instance *ec2.Instance) (string, error) { + if instance == nil { + return "", fmt.Errorf("non-nil instance is required: %+v", instance) + } + + if instance.InstanceId == nil || len(*instance.InstanceId) == 0 { + return "", nil + } + + return *instance.InstanceId, nil } // ValidateInstance checks if an EC2 instance exists and it has the expected @@ -116,17 +154,28 @@ func GetPublicIPAddress(output *ec2.DescribeInstancesOutput) (string, error) { func ValidateInstance(output *ec2.DescribeInstancesOutput, doc IdentityDocument, remoteAddr string) (role string, err error) { vlog.Infof("reservations:\n%+v", output.Reservations) - if remoteAddr != "" { - publicIP, err := GetPublicIPAddress(output) - if err != nil { - return "", err - } + instance, err := GetInstance(output) + if err != nil { + return "", err + } + + publicIP, err := GetPublicIPAddress(instance) + if err != nil { + return "", err + } + + // Instances that do not have a public IP should be able to authenticate + // with ticket server. Connections from such instances are routed through a + // NAT gateway with an Elastic IP. The following check which ensures the + // remoteAddr from which the connection originates is same as the public IP + // of the instance is skipped for private instances. 
+ if remoteAddr != "" && publicIP != "" { if !strings.HasPrefix(remoteAddr, publicIP+":") { return "", fmt.Errorf("mismatch between the real peer address (%s) and public IP of the instance (%s)", remoteAddr, publicIP) } } - arn, err := GetIamInstanceProfileARN(output) + arn, err := GetIamInstanceProfileARN(instance) if err != nil { return "", err } diff --git a/cloud/ec2util/ec2util_test.go b/cloud/ec2util/ec2util_test.go index b62d75ff..d5ccffe0 100644 --- a/cloud/ec2util/ec2util_test.go +++ b/cloud/ec2util/ec2util_test.go @@ -15,7 +15,7 @@ import ( "github.com/grailbio/base/cloud/ec2util" ) -func TestGetARN(t *testing.T) { +func TestGetInstance(t *testing.T) { cases := []struct { output *ec2.DescribeInstancesOutput @@ -53,19 +53,29 @@ func TestGetARN(t *testing.T) { }, }, }, "", "non-nil IamInstanceProfile"}, - {&ec2.DescribeInstancesOutput{ - Reservations: []*ec2.Reservation{ - &ec2.Reservation{ - Instances: []*ec2.Instance{ - &ec2.Instance{ - IamInstanceProfile: &ec2.IamInstanceProfile{}, - }, - }, - }, - }, - }, "", "non-nil Arn"}, - {newDescribeInstancesOutput("", ""), "", "non-empty Arn"}, - {newDescribeInstancesOutput("dummy", ""), "dummy", ""}, + } + + for _, c := range cases { + _, err := ec2util.GetInstance(c.output) + if err != nil && (c.errPrefix == "" || !strings.HasPrefix(err.Error(), c.errPrefix)) { + t.Errorf("GetInstance: got %q, want %q", err, c.errPrefix) + } + } +} + +func TestGetARN(t *testing.T) { + cases := []struct { + output *ec2.Instance + + arn string + errPrefix string + }{ + { + &ec2.Instance{ + IamInstanceProfile: &ec2.IamInstanceProfile{}, + }, "", "non-nil Arn"}, + {newInstancesOutput("", "", ""), "", "non-empty Arn"}, + {newInstancesOutput("", "dummy", ""), "dummy", ""}, } for _, c := range cases { @@ -79,6 +89,77 @@ func TestGetARN(t *testing.T) { } } +func TestGetInstanceId(t *testing.T) { + cases := []struct { + output *ec2.Instance + + instanceId string + errPrefix string + }{ + {nil, "", "non-nil"}, + 
{newInstancesOutput("i-1234", "", ""), + "i-1234", ""}, + } + + for _, c := range cases { + instanceId, err := ec2util.GetInstanceId(c.output) + if err != nil && (c.errPrefix == "" || !strings.HasPrefix(err.Error(), c.errPrefix)) { + t.Errorf("GetInstanceId: got %q, want %q", err, c.errPrefix) + } + if instanceId != c.instanceId { + t.Errorf("GetInstanceId: got %q, want %q", instanceId, c.instanceId) + } + } +} + +func TestGetPublicIPAddress(t *testing.T) { + cases := []struct { + output *ec2.Instance + + publicIp string + errPrefix string + }{ + {nil, "", "non-nil"}, + {newInstancesOutput("", "", "192.168.1.1"), + "192.168.1.1", ""}, + } + + for _, c := range cases { + publicIp, err := ec2util.GetPublicIPAddress(c.output) + if err != nil && (c.errPrefix == "" || !strings.HasPrefix(err.Error(), c.errPrefix)) { + t.Errorf("GetPublicIPAddress: got %q, want %q", err, c.errPrefix) + } + if publicIp != c.publicIp { + t.Errorf("GetPublicIPAddress: got %q, want %q", publicIp, c.publicIp) + } + } +} + +// TODO(aeiser) Implement test checking for tags +func TestGetTags(t *testing.T) { + cases := []struct { + output *ec2.Instance + + tags string + errPrefix string + }{ + {nil, "", "non-nil"}, + {&ec2.Instance{ + IamInstanceProfile: &ec2.IamInstanceProfile{}, + }, "", "non-nil Arn"}, + } + + for _, c := range cases { + _, err := ec2util.GetTags(c.output) + if err != nil && (c.errPrefix == "" || !strings.HasPrefix(err.Error(), c.errPrefix)) { + t.Errorf("GetTags: got %q, want %q", err, c.errPrefix) + } + // if tags != c.tags { + // t.Errorf("GetTags: got %q, want %q", tags, c.tags) + // } + } +} + func TestValidateInstance(t *testing.T) { cases := []struct { describeInstances *ec2.DescribeInstancesOutput @@ -104,6 +185,14 @@ func TestValidateInstance(t *testing.T) { "dummyRole", "", }, + //Instance that does not have a public IP + { + newDescribeInstancesOutput("arn:aws:iam::987654321012:instance-profile/dummyRole", ""), + ec2util.IdentityDocument{AccountID: "987654321012"}, + 
"52.215.119.108:", + "dummyRole", + "", + }, } for _, c := range cases { @@ -197,14 +286,19 @@ func newDescribeInstancesOutput(arn string, publicIP string) *ec2.DescribeInstan Reservations: []*ec2.Reservation{ &ec2.Reservation{ Instances: []*ec2.Instance{ - &ec2.Instance{ - IamInstanceProfile: &ec2.IamInstanceProfile{ - Arn: aws.String(arn), - }, - PublicIpAddress: aws.String(publicIP), - }, + newInstancesOutput("", arn, publicIP), }, }, }, } } + +func newInstancesOutput(instanceId string, arn string, publicIP string) *ec2.Instance { + return &ec2.Instance{ + InstanceId: &instanceId, + IamInstanceProfile: &ec2.IamInstanceProfile{ + Arn: aws.String(arn), + }, + PublicIpAddress: aws.String(publicIP), + } +} diff --git a/cloud/ec2util/iid.go b/cloud/ec2util/iid.go new file mode 100644 index 00000000..40f15b61 --- /dev/null +++ b/cloud/ec2util/iid.go @@ -0,0 +1,50 @@ +package ec2util + +import ( + "fmt" + "net/http" + "time" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/ec2metadata" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/grailbio/base/sync/once" +) + +var ( + iidOnce once.Task + iid ec2metadata.EC2InstanceIdentityDocument +) + +// GetInstanceIdentityDocument returns the EC2 Instance ID document (if the current +// process is running within an EC2 instance) or an error. +// Unlike the SDK's implementation, this will use longer timeouts and multiple retries +// to improve the reliability of getting the Instance ID document. +// The first result, whether success or failure, is cached for the lifetime of the process. 
+func GetInstanceIdentityDocument(sess *session.Session) (doc ec2metadata.EC2InstanceIdentityDocument, err error) { + err = iidOnce.Do(func() (oerr error) { + // Use HTTP client with custom timeout and max retries to prevent the SDK from + // using an HTTP client with a small timeout and a small number of retries for + // the ec2metadata client + metaClient := ec2metadata.New(sess, &aws.Config{ + HTTPClient: &http.Client{Timeout: 5 * time.Second}, + MaxRetries: aws.Int(5), + }) + for retries := 0; retries < 5; retries++ { + iid, oerr = metaClient.GetInstanceIdentityDocument() + if oerr == nil { + break + } + } + return + }) + switch { + case err != nil: + err = fmt.Errorf("ec2util.GetInstanceIdentityDocument: %v", err) + case iid.InstanceID == "": + err = fmt.Errorf("ec2util.GetInstanceIdentityDocument: Unable to get EC2InstanceIdentityDocument") + default: + doc = iid + } + return +} diff --git a/cloud/spotadvisor/README.md b/cloud/spotadvisor/README.md new file mode 100644 index 00000000..d66684fe --- /dev/null +++ b/cloud/spotadvisor/README.md @@ -0,0 +1,23 @@ +# spotadvisor + +This package provides an interface for fetching and utilizing [AWS Spot Advisor](https://aws.amazon.com/ec2/spot/instance-advisor/) +data. + +The data is sourced from: https://spot-bid-advisor.s3.amazonaws.com/spot-advisor-data.json + +The available data are: + +1. Interrupt rate data, described as: + > Frequency of interruption represents the rate at which Spot has reclaimed + capacity during the trailing month. They are in ranges of < 5%, 5-10%, 10-15%, + 15-20% and >20%. + +2. Savings data, described as: + + > Savings compared to On-Demand are calculated over the last 30 days. Please + note that price history data is averaged across Availability Zones and may be + delayed. To view current Spot prices, visit the Spot Price History in the AWS + Management Console for up to date pricing information for each Availability + Zone. 
+ +See [godoc](https://godoc.org/github.com/grailbio/base/cloud/spotadvisor) for usage. diff --git a/cloud/spotadvisor/export_test.go b/cloud/spotadvisor/export_test.go new file mode 100644 index 00000000..10b1acf7 --- /dev/null +++ b/cloud/spotadvisor/export_test.go @@ -0,0 +1,10 @@ +// Copyright 2021 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. + +package spotadvisor + +// Only for use in unit tests. +func SetSpotAdvisorDataUrl(url string) { + spotAdvisorDataUrl = url +} diff --git a/cloud/spotadvisor/spotadvisor.go b/cloud/spotadvisor/spotadvisor.go new file mode 100644 index 00000000..70c8b842 --- /dev/null +++ b/cloud/spotadvisor/spotadvisor.go @@ -0,0 +1,316 @@ +// Copyright 2021 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. + +// Package spotadvisor provides an interface for utilizing spot instance +// interrupt rate data and savings data from AWS. +package spotadvisor + +import ( + "encoding/json" + "fmt" + "net/http" + "sync" + "time" +) + +var spotAdvisorDataUrl = "https://spot-bid-advisor.s3.amazonaws.com/spot-advisor-data.json" + +const ( + // Spot Advisor data is only updated a few times a day, so we just refresh once an hour. + // Might need to revisit this value if data is updated more frequently in the future. + defaultRefreshInterval = 1 * time.Hour + defaultRequestTimeout = 10 * time.Second + + Linux = OsType("Linux") + Windows = OsType("Windows") +) + +// These need to be in their own const block to ensure iota starts at 0. +const ( + ZeroToFivePct InterruptRange = iota + FiveToTenPct + TenToFifteenPct + FifteenToTwentyPct + GreaterThanTwentyPct +) + +// These need to be in their own const block to ensure iota starts at 0. 
+const (
+	LessThanFivePct InterruptProbability = iota
+	LessThanTenPct
+	LessThanFifteenPct
+	LessThanTwentyPct
+	Any
+)
+
+type interruptRange struct {
+	Label string `json:"label"`
+	Index int    `json:"index"`
+	Dots  int    `json:"dots"`
+	Max   int    `json:"max"`
+}
+
+type instanceType struct {
+	Cores int     `json:"cores"`
+	Emr   bool    `json:"emr"`
+	RamGb float32 `json:"ram_gb"`
+}
+
+type instanceData struct {
+	RangeIdx int `json:"r"`
+	Savings  int `json:"s"`
+}
+
+type osGroups struct {
+	Windows map[string]instanceData `json:"Windows"`
+	Linux   map[string]instanceData `json:"Linux"`
+}
+
+type advisorData struct {
+	Ranges []interruptRange `json:"ranges"`
+	// key is an EC2 instance type name like "r5a.large"
+	InstanceTypes map[string]instanceType `json:"instance_types"`
+	// key is an AWS region name like "us-west-2"
+	SpotAdvisor map[string]osGroups `json:"spot_advisor"`
+}
+
+type aggKey struct {
+	ot OsType
+	ar AwsRegion
+	ip InterruptProbability
+}
+
+func (k aggKey) String() string {
+	return fmt.Sprintf("{%s, %s, %s}", k.ot, k.ar, k.ip)
+}
+
+// OsType should only be used via the pre-defined constants in this package.
+type OsType string
+
+// AwsRegion is an AWS region name like "us-west-2".
+type AwsRegion string
+
+// InstanceType is an EC2 instance type name like "r5a.large".
+type InstanceType string
+
+// InterruptRange is the AWS defined interrupt range for an instance type; it
+// should only be used via the pre-defined constants in this package.
+type InterruptRange int
+
+func (ir InterruptRange) String() string {
+	switch ir {
+	case ZeroToFivePct:
+		return "0-5%"
+	case FiveToTenPct:
+		return "5-10%"
+	case TenToFifteenPct:
+		return "10-15%"
+	case FifteenToTwentyPct:
+		return "15-20%"
+	case GreaterThanTwentyPct:
+		return "> 20%"
+	default:
+		return "invalid interrupt range"
+	}
+}
+
+// InterruptProbability is an upper bound used to indicate multiple interrupt
+// ranges; it should only be used via the pre-defined constants in this package. 
+type InterruptProbability int + +func (ir InterruptProbability) String() string { + switch ir { + case LessThanFivePct: + return "< 5%" + case LessThanTenPct: + return "< 10%" + case LessThanFifteenPct: + return "< 15%" + case LessThanTwentyPct: + return "< 20%" + case Any: + return "Any" + default: + return "invalid interrupt probability" + } +} + +// SpotAdvisor provides an interface for utilizing spot instance interrupt rate +// data and savings data from AWS. +type SpotAdvisor struct { + mu sync.RWMutex + // rawData is the decoded spot advisor json response + rawData advisorData + // aggData maps each aggKey to a slice of instance types aggregated by interrupt + // probability. For example, if aggKey.ip=LessThanTenPct, then the mapped value + // would contain all instance types which have an interrupt range of + // LessThanFivePct or FiveToTenPct. + aggData map[aggKey][]string + + // TODO: incorporate spot advisor savings data +} + +// SimpleLogger is a bare-bones logger interface which allows many logger +// implementations to be used with SpotAdvisor. The default Go log.Logger and +// grailbio/base/log.Logger implement this interface. +type SimpleLogger interface { + Printf(string, ...interface{}) +} + +// NewSpotAdvisor initializes and returns a SpotAdvisor instance. If +// initialization fails, a nil SpotAdvisor is returned with an error. The +// underlying data is asynchronously updated, until the done channel is closed. +// Errors during updates are non-fatal and will not prevent future updates. 
+func NewSpotAdvisor(log SimpleLogger, done <-chan struct{}) (*SpotAdvisor, error) { + sa := SpotAdvisor{} + // initial load + if err := sa.refresh(); err != nil { + return nil, fmt.Errorf("error fetching spot advisor data: %s", err) + } + + go func() { + ticker := time.NewTicker(defaultRefreshInterval) + defer ticker.Stop() + for { + select { + case <-done: + return + case <-ticker.C: + if err := sa.refresh(); err != nil { + log.Printf("error refreshing spot advisor data (will try again later): %s", err) + } + } + } + }() + return &sa, nil +} + +func (sa *SpotAdvisor) refresh() (err error) { + // fetch + client := &http.Client{Timeout: defaultRequestTimeout} + resp, err := client.Get(spotAdvisorDataUrl) + if err != nil { + return err + } + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("GET %s response StatusCode: %s", spotAdvisorDataUrl, http.StatusText(resp.StatusCode)) + } + var rawData advisorData + err = json.NewDecoder(resp.Body).Decode(&rawData) + if err != nil { + return err + } + err = resp.Body.Close() + if err != nil { + return err + } + + // update internal data structures + aggData := make(map[aggKey][]string) + for r, o := range rawData.SpotAdvisor { + region := AwsRegion(r) + // transform the raw data so that the values of aggData will contain just the instances in a given range + for instance, data := range o.Linux { + k := aggKey{Linux, region, InterruptProbability(data.RangeIdx)} + aggData[k] = append(aggData[k], instance) + } + for instance, data := range o.Windows { + k := aggKey{Windows, region, InterruptProbability(data.RangeIdx)} + aggData[k] = append(aggData[k], instance) + } + + // aggregate instances by the upper bound interrupt probability of each key + for i := 1; i <= int(Any); i++ { + { + lk := aggKey{Linux, region, InterruptProbability(i)} + lprevk := aggKey{Linux, region, InterruptProbability(i - 1)} + aggData[lk] = append(aggData[lk], aggData[lprevk]...) 
+ } + { + wk := aggKey{Windows, region, InterruptProbability(i)} + wprevk := aggKey{Windows, region, InterruptProbability(i - 1)} + aggData[wk] = append(aggData[wk], aggData[wprevk]...) + } + } + } + sa.mu.Lock() + sa.rawData = rawData + sa.aggData = aggData + sa.mu.Unlock() + return nil +} + +// FilterByMaxInterruptProbability returns a subset of the input candidates by +// removing instance types which have a probability of interruption greater than ip. +func (sa *SpotAdvisor) FilterByMaxInterruptProbability(ot OsType, ar AwsRegion, candidates []string, ip InterruptProbability) (filtered []string, err error) { + if ip == Any { + // There's a chance we may not have spot advisor data for some instances in + // the candidates, so just return as is without doing a set difference. + return candidates, nil + } + allowed, err := sa.GetInstancesWithMaxInterruptProbability(ot, ar, ip) + if err != nil { + return nil, err + } + for _, c := range candidates { + if allowed[c] { + filtered = append(filtered, c) + } + } + return filtered, nil +} + +// GetInstancesWithMaxInterruptProbability returns the set of spot instance types +// with an interrupt probability less than or equal to ip, with the given OS and region. +func (sa *SpotAdvisor) GetInstancesWithMaxInterruptProbability(ot OsType, region AwsRegion, ip InterruptProbability) (map[string]bool, error) { + if ip < LessThanFivePct || ip > Any { + return nil, fmt.Errorf("invalid InterruptProbability: %d", ip) + } + k := aggKey{ot, region, ip} + sa.mu.RLock() + defer sa.mu.RUnlock() + ts, ok := sa.aggData[k] + if !ok { + return nil, fmt.Errorf("no spot advisor data for: %s", k) + } + tsMap := make(map[string]bool, len(ts)) + for _, t := range ts { + tsMap[t] = true + } + return tsMap, nil +} + +// GetInterruptRange returns the interrupt range for the instance type with the +// given OS and region. 
+func (sa *SpotAdvisor) GetInterruptRange(ot OsType, ar AwsRegion, it InstanceType) (InterruptRange, error) { + sa.mu.RLock() + defer sa.mu.RUnlock() + osg, ok := sa.rawData.SpotAdvisor[string(ar)] + if !ok { + return -1, fmt.Errorf("no spot advisor data for: %s", ar) + } + var m map[string]instanceData + switch ot { + case Linux: + m = osg.Linux + case Windows: + m = osg.Windows + default: + return -1, fmt.Errorf("invalid OS: %s", ot) + } + + d, ok := m[string(it)] + if !ok { + return -1, fmt.Errorf("no spot advisor data for %s instance type '%s' in %s", ot, it, ar) + } + return InterruptRange(d.RangeIdx), nil +} + +// GetMaxInterruptProbability is a helper method to easily get the max interrupt +// probability of an instance type (i.e. the upper bound of the interrupt range +// for that instance type). +func (sa *SpotAdvisor) GetMaxInterruptProbability(ot OsType, ar AwsRegion, it InstanceType) (InterruptProbability, error) { + ir, err := sa.GetInterruptRange(ot, ar, it) + return InterruptProbability(ir), err +} diff --git a/cloud/spotadvisor/spotadvisor_test.go b/cloud/spotadvisor/spotadvisor_test.go new file mode 100644 index 00000000..708da184 --- /dev/null +++ b/cloud/spotadvisor/spotadvisor_test.go @@ -0,0 +1,352 @@ +// Copyright 2021 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. + +package spotadvisor_test + +import ( + "context" + "fmt" + "io/ioutil" + "log" + "net/http" + "net/http/httptest" + "reflect" + "sort" + "testing" + + sa "github.com/grailbio/base/cloud/spotadvisor" +) + +// Contains an abridged version of a real response to make got/want comparisons easier. +const testDataPath = "./testdata/test-spot-advisor-data.json" + +// TestGetAndFilterByInterruptRate tests both GetInstancesWithMaxInterruptProbability and FilterByMaxInterruptProbability. 
+func TestGetAndFilterByInterruptRate(t *testing.T) { + defer setupMockTestServer(t).Close() + adv, err := sa.NewSpotAdvisor(testLogger, context.Background().Done()) + if err != nil { + t.Fatalf(err.Error()) + } + tests := []struct { + name string + osType sa.OsType + region sa.AwsRegion + maxInterruptProb sa.InterruptProbability + candidates []string + want []string + wantErr error + }{ + { + name: "simple", + osType: sa.Windows, + region: sa.AwsRegion("eu-west-2"), + candidates: testAvailableInstanceTypes, + maxInterruptProb: sa.LessThanFivePct, + want: []string{"r4.xlarge"}, + }, + { + name: "<5%", + osType: sa.Linux, + region: sa.AwsRegion("eu-west-2"), + candidates: testAvailableInstanceTypes, + maxInterruptProb: sa.LessThanFivePct, + want: []string{"m5a.4xlarge"}, + }, + { + name: "<10%", + osType: sa.Linux, + region: sa.AwsRegion("eu-west-2"), + candidates: testAvailableInstanceTypes, + maxInterruptProb: sa.LessThanTenPct, + want: []string{"m5a.4xlarge", "t3.nano"}, + }, + { + name: "<15%", + osType: sa.Linux, + region: sa.AwsRegion("eu-west-2"), + candidates: testAvailableInstanceTypes, + maxInterruptProb: sa.LessThanFifteenPct, + want: []string{"m5a.4xlarge", "t3.nano", "g4dn.12xlarge"}, + }, + { + name: "<20%", + osType: sa.Linux, + region: sa.AwsRegion("eu-west-2"), + candidates: testAvailableInstanceTypes, + maxInterruptProb: sa.LessThanTwentyPct, + want: []string{"m5a.4xlarge", "t3.nano", "g4dn.12xlarge", "r5d.8xlarge"}, + }, + { + name: "Any", + osType: sa.Linux, + region: sa.AwsRegion("eu-west-2"), + candidates: testAvailableInstanceTypes, + maxInterruptProb: sa.Any, + want: testAvailableInstanceTypes, + }, + { + name: "bad_interrupt_prob_neg", + osType: sa.Linux, + region: sa.AwsRegion("eu-west-2"), + candidates: testAvailableInstanceTypes, + maxInterruptProb: sa.InterruptProbability(-1), + want: nil, + wantErr: fmt.Errorf("invalid InterruptProbability: -1"), + }, + { + name: "bad_interrupt_prob_pos", + osType: sa.Linux, + region: 
sa.AwsRegion("eu-west-2"), + candidates: testAvailableInstanceTypes, + maxInterruptProb: sa.InterruptProbability(6), + want: nil, + wantErr: fmt.Errorf("invalid InterruptProbability: 6"), + }, + { + name: "bad_instance_region", + osType: sa.Linux, + region: sa.AwsRegion("us-foo-2"), + candidates: testAvailableInstanceTypes, + maxInterruptProb: sa.LessThanFifteenPct, + want: nil, + wantErr: fmt.Errorf("no spot advisor data for: {Linux, us-foo-2, < 15%%}"), + }, + } + for _, tt := range tests { + name := fmt.Sprintf("%s_%s_%s_%d", tt.name, tt.osType, tt.region, tt.maxInterruptProb) + t.Run(name, func(t *testing.T) { + got, gotErr := adv.FilterByMaxInterruptProbability(tt.osType, tt.region, tt.candidates, tt.maxInterruptProb) + checkErr(t, tt.wantErr, gotErr) + if tt.wantErr == nil { + checkEqual(t, tt.want, got) + } + }) + } +} + +func TestGetInterruptRange(t *testing.T) { + defer setupMockTestServer(t).Close() + adv, err := sa.NewSpotAdvisor(testLogger, context.Background().Done()) + if err != nil { + t.Fatalf(err.Error()) + } + tests := []struct { + name string + osType sa.OsType + region sa.AwsRegion + instanceType sa.InstanceType + want sa.InterruptRange + wantErr error + }{ + { + name: "simple", + osType: sa.Windows, + region: sa.AwsRegion("us-west-2"), + instanceType: "c5a.24xlarge", + want: sa.TenToFifteenPct, + }, + { + name: "bad_region", + osType: sa.Windows, + region: sa.AwsRegion("us-foo-2"), + instanceType: "c5a.24xlarge", + want: -1, + wantErr: fmt.Errorf("no spot advisor data for: us-foo-2"), + }, + { + name: "bad_os", + osType: sa.OsType("Unix"), + region: sa.AwsRegion("us-west-2"), + instanceType: "c5a.24xlarge", + want: -1, + wantErr: fmt.Errorf("invalid OS: Unix"), + }, + { + name: "bad_instance_type", + osType: sa.Linux, + region: sa.AwsRegion("us-west-2"), + instanceType: "foo.bar", + want: -1, + wantErr: fmt.Errorf("no spot advisor data for Linux instance type 'foo.bar' in us-west-2"), + }, + } + for _, tt := range tests { + name := 
fmt.Sprintf("%s_%s_%s", tt.name, tt.osType, tt.region) + t.Run(name, func(t *testing.T) { + got, gotErr := adv.GetInterruptRange(tt.osType, tt.region, tt.instanceType) + checkErr(t, tt.wantErr, gotErr) + if tt.wantErr == nil && tt.want != got { + t.Fatalf("want: %s, got: %s", tt.want, got) + } + }) + } +} + +func TestGetMaxInterruptProbability(t *testing.T) { + defer setupMockTestServer(t).Close() + adv, err := sa.NewSpotAdvisor(testLogger, context.Background().Done()) + if err != nil { + t.Fatalf(err.Error()) + } + tests := []struct { + name string + osType sa.OsType + region sa.AwsRegion + instanceType sa.InstanceType + want sa.InterruptProbability + wantErr error + }{ + { + name: "simple_<5%", + osType: sa.Linux, + region: sa.AwsRegion("eu-west-2"), + instanceType: "m5a.4xlarge", + want: sa.LessThanFivePct, + }, + { + name: "simple_<10%", + osType: sa.Linux, + region: sa.AwsRegion("eu-west-2"), + instanceType: "t3.nano", + want: sa.LessThanTenPct, + }, + { + name: "simple_<15%", + osType: sa.Linux, + region: sa.AwsRegion("eu-west-2"), + instanceType: "g4dn.12xlarge", + want: sa.LessThanFifteenPct, + }, + { + name: "simple_<20%", + osType: sa.Linux, + region: sa.AwsRegion("eu-west-2"), + instanceType: "r5d.8xlarge", + want: sa.LessThanTwentyPct, + }, + { + name: "simple_Any", + osType: sa.Linux, + region: sa.AwsRegion("eu-west-2"), + instanceType: "i3.2xlarge", + want: sa.Any, + }, + { + name: "bad_region", + osType: sa.Windows, + region: sa.AwsRegion("us-foo-2"), + instanceType: "c5a.24xlarge", + want: -1, + wantErr: fmt.Errorf("no spot advisor data for: us-foo-2"), + }, + { + name: "bad_os", + osType: sa.OsType("Unix"), + region: sa.AwsRegion("us-west-2"), + instanceType: "c5a.24xlarge", + want: -1, + wantErr: fmt.Errorf("invalid OS: Unix"), + }, + { + name: "bad_instance_type", + osType: sa.Linux, + region: sa.AwsRegion("us-west-2"), + instanceType: "foo.bar", + want: -1, + wantErr: fmt.Errorf("no spot advisor data for Linux instance type 'foo.bar' in 
us-west-2"), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, gotErr := adv.GetMaxInterruptProbability(tt.osType, tt.region, tt.instanceType) + checkErr(t, tt.wantErr, gotErr) + if tt.wantErr == nil && tt.want != got { + t.Fatalf("want: %s, got: %s", tt.want, got) + } + }) + } + +} + +// setupMockTestServer starts a test server and replaces the actual spot advisor +// data URL with the test server's URL. A request to the server will return the +// contents of the file at testDataPath. The caller is expected to call Close() +// on the returned test server. +func setupMockTestServer(t *testing.T) *httptest.Server { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + b, err := ioutil.ReadFile(testDataPath) + if err != nil { + t.Fatal(err) + } + if _, err := w.Write(b); err != nil { + t.Fatal(err) + } + })) + sa.SetSpotAdvisorDataUrl(ts.URL) + return ts +} + +func checkEqual(t *testing.T, want []string, got []string) { + if len(want) != len(got) { + t.Fatalf("\nwant:\t%s\ngot:\t%s", want, got) + } + sort.Strings(want) + sort.Strings(got) + if !reflect.DeepEqual(got, got) { + t.Fatalf("\nwant:\t%s\ngot:\t%s", want, got) + } +} + +func checkErr(t *testing.T, want error, got error) { + if want != nil && got != nil { + if want.Error() != got.Error() { + t.Fatalf("want: %s, got: %s", want, got) + } else { + return + } + } + if want != got { + t.Fatalf("want: %s, got: %s", want, got) + } +} + +var testLogger = log.New(ioutil.Discard, "", 0) + +var testAvailableInstanceTypes = []string{ + "a1.2xlarge", "a1.4xlarge", "a1.large", "a1.metal", "a1.xlarge", "c1.xlarge", "c3.2xlarge", "c3.4xlarge", "c3.8xlarge", "c3.large", "c3.xlarge", + "c4.2xlarge", "c4.4xlarge", "c4.8xlarge", "c4.large", "c4.xlarge", "c5.12xlarge", "c5.18xlarge", "c5.24xlarge", "c5.2xlarge", "c5.4xlarge", "c5.9xlarge", + "c5.large", "c5.metal", "c5.xlarge", "c5a.12xlarge", "c5a.16xlarge", "c5a.24xlarge", "c5a.2xlarge", "c5a.4xlarge", 
"c5a.8xlarge", "c5a.large", "c5a.xlarge", + "c5ad.12xlarge", "c5ad.16xlarge", "c5ad.24xlarge", "c5ad.2xlarge", "c5ad.4xlarge", "c5ad.8xlarge", "c5ad.large", "c5ad.xlarge", "c5d.12xlarge", "c5d.18xlarge", + "c5d.24xlarge", "c5d.2xlarge", "c5d.4xlarge", "c5d.9xlarge", "c5d.large", "c5d.metal", "c5d.xlarge", "c5n.18xlarge", "c5n.2xlarge", "c5n.4xlarge", "c5n.9xlarge", + "c5n.large", "c5n.metal", "c5n.xlarge", "c6g.12xlarge", "c6g.16xlarge", "c6g.2xlarge", "c6g.4xlarge", "c6g.8xlarge", "c6g.large", "c6g.metal", "c6g.xlarge", + "c6gd.12xlarge", "c6gd.16xlarge", "c6gd.2xlarge", "c6gd.4xlarge", "c6gd.8xlarge", "c6gd.large", "c6gd.metal", "c6gd.xlarge", "c6gn.12xlarge", "c6gn.16xlarge", "c6gn.2xlarge", + "c6gn.4xlarge", "c6gn.8xlarge", "c6gn.large", "c6gn.xlarge", "cr1.8xlarge", "d2.2xlarge", "d2.4xlarge", "d2.8xlarge", "d2.xlarge", "d3.2xlarge", "d3.4xlarge", + "d3.8xlarge", "d3.xlarge", "d3en.12xlarge", "d3en.2xlarge", "d3en.4xlarge", "d3en.6xlarge", "d3en.8xlarge", "d3en.xlarge", "f1.16xlarge", "f1.2xlarge", "f1.4xlarge", + "g2.2xlarge", "g2.8xlarge", "g3.16xlarge", "g3.4xlarge", "g3.8xlarge", "g3s.xlarge", "g4ad.16xlarge", "g4ad.4xlarge", "g4ad.8xlarge", "g4dn.12xlarge", "g4dn.16xlarge", + "g4dn.2xlarge", "g4dn.4xlarge", "g4dn.8xlarge", "g4dn.metal", "g4dn.xlarge", "h1.16xlarge", "h1.2xlarge", "h1.4xlarge", "h1.8xlarge", "hs1.8xlarge", "i2.2xlarge", + "i2.4xlarge", "i2.8xlarge", "i2.xlarge", "i3.16xlarge", "i3.2xlarge", "i3.4xlarge", "i3.8xlarge", "i3.large", "i3.metal", "i3.xlarge", "i3en.12xlarge", + "i3en.24xlarge", "i3en.2xlarge", "i3en.3xlarge", "i3en.6xlarge", "i3en.large", "i3en.metal", "i3en.xlarge", "inf1.24xlarge", "inf1.2xlarge", "inf1.6xlarge", "inf1.xlarge", + "m1.large", "m1.xlarge", "m2.2xlarge", "m2.4xlarge", "m2.xlarge", "m3.2xlarge", "m3.large", "m3.xlarge", "m4.10xlarge", "m4.16xlarge", "m4.2xlarge", + "m4.4xlarge", "m4.large", "m4.xlarge", "m5.12xlarge", "m5.16xlarge", "m5.24xlarge", "m5.2xlarge", "m5.4xlarge", "m5.8xlarge", "m5.large", "m5.metal", 
+ "m5.xlarge", "m5a.12xlarge", "m5a.16xlarge", "m5a.24xlarge", "m5a.2xlarge", "m5a.4xlarge", "m5a.8xlarge", "m5a.large", "m5a.xlarge", "m5ad.12xlarge", "m5ad.16xlarge", + "m5ad.24xlarge", "m5ad.2xlarge", "m5ad.4xlarge", "m5ad.8xlarge", "m5ad.large", "m5ad.xlarge", "m5d.12xlarge", "m5d.16xlarge", "m5d.24xlarge", "m5d.2xlarge", "m5d.4xlarge", + "m5d.8xlarge", "m5d.large", "m5d.metal", "m5d.xlarge", "m5dn.12xlarge", "m5dn.16xlarge", "m5dn.24xlarge", "m5dn.2xlarge", "m5dn.4xlarge", "m5dn.8xlarge", "m5dn.large", + "m5dn.metal", "m5dn.xlarge", "m5n.12xlarge", "m5n.16xlarge", "m5n.24xlarge", "m5n.2xlarge", "m5n.4xlarge", "m5n.8xlarge", "m5n.large", "m5n.metal", "m5n.xlarge", + "m5zn.12xlarge", "m5zn.2xlarge", "m5zn.3xlarge", "m5zn.6xlarge", "m5zn.large", "m5zn.metal", "m5zn.xlarge", "m6g.12xlarge", "m6g.16xlarge", "m6g.2xlarge", "m6g.4xlarge", + "m6g.8xlarge", "m6g.large", "m6g.metal", "m6g.xlarge", "m6gd.12xlarge", "m6gd.16xlarge", "m6gd.2xlarge", "m6gd.4xlarge", "m6gd.8xlarge", "m6gd.large", "m6gd.metal", + "m6gd.xlarge", "p2.16xlarge", "p2.8xlarge", "p2.xlarge", "p3.16xlarge", "p3.2xlarge", "p3.8xlarge", "p3dn.24xlarge", "p4d.24xlarge", "r3.2xlarge", "r3.4xlarge", + "r3.8xlarge", "r3.large", "r3.xlarge", "r4.16xlarge", "r4.2xlarge", "r4.4xlarge", "r4.8xlarge", "r4.large", "r4.xlarge", "r5.12xlarge", "r5.16xlarge", + "r5.24xlarge", "r5.2xlarge", "r5.4xlarge", "r5.8xlarge", "r5.large", "r5.metal", "r5.xlarge", "r5a.12xlarge", "r5a.16xlarge", "r5a.24xlarge", "r5a.2xlarge", + "r5a.4xlarge", "r5a.8xlarge", "r5a.large", "r5a.xlarge", "r5ad.12xlarge", "r5ad.16xlarge", "r5ad.24xlarge", "r5ad.2xlarge", "r5ad.4xlarge", "r5ad.8xlarge", "r5ad.large", + "r5ad.xlarge", "r5b.12xlarge", "r5b.16xlarge", "r5b.24xlarge", "r5b.2xlarge", "r5b.4xlarge", "r5b.8xlarge", "r5b.large", "r5b.metal", "r5b.xlarge", "r5d.12xlarge", + "r5d.16xlarge", "r5d.24xlarge", "r5d.2xlarge", "r5d.4xlarge", "r5d.8xlarge", "r5d.large", "r5d.metal", "r5d.xlarge", "r5dn.12xlarge", "r5dn.16xlarge", "r5dn.24xlarge", 
+ "r5dn.2xlarge", "r5dn.4xlarge", "r5dn.8xlarge", "r5dn.large", "r5dn.metal", "r5dn.xlarge", "r5n.12xlarge", "r5n.16xlarge", "r5n.24xlarge", "r5n.2xlarge", "r5n.4xlarge", + "r5n.8xlarge", "r5n.large", "r5n.metal", "r5n.xlarge", "r6g.12xlarge", "r6g.16xlarge", "r6g.2xlarge", "r6g.4xlarge", "r6g.8xlarge", "r6g.large", "r6g.metal", + "r6g.xlarge", "r6gd.12xlarge", "r6gd.16xlarge", "r6gd.2xlarge", "r6gd.4xlarge", "r6gd.8xlarge", "r6gd.large", "r6gd.metal", "r6gd.xlarge", "t1.micro", "t2.2xlarge", + "t2.large", "t2.micro", "t2.nano", "t2.xlarge", "t3.2xlarge", "t3.large", "t3.micro", "t3.nano", "t3.xlarge", "t3a.2xlarge", "t3a.large", + "t3a.micro", "t3a.nano", "t3a.xlarge", "t4g.2xlarge", "t4g.large", "t4g.micro", "t4g.nano", "t4g.xlarge", "u-12tb1.112xlarge", "u-6tb1.112xlarge", "u-6tb1.56xlarge", + "u-9tb1.112xlarge", "x1.16xlarge", "x1.32xlarge", "x1e.16xlarge", "x1e.2xlarge", "x1e.32xlarge", "x1e.4xlarge", "x1e.8xlarge", "x1e.xlarge", "x2gd.12xlarge", "x2gd.16xlarge", + "x2gd.2xlarge", "x2gd.4xlarge", "x2gd.8xlarge", "x2gd.large", "x2gd.metal", "x2gd.xlarge", "z1d.12xlarge", "z1d.2xlarge", "z1d.3xlarge", "z1d.6xlarge", "z1d.large", + "z1d.metal", "z1d.xlarge", +} diff --git a/cloud/spotadvisor/testdata/test-spot-advisor-data.json b/cloud/spotadvisor/testdata/test-spot-advisor-data.json new file mode 100644 index 00000000..95113564 --- /dev/null +++ b/cloud/spotadvisor/testdata/test-spot-advisor-data.json @@ -0,0 +1,2955 @@ +{ + "spot_advisor": { + "eu-west-2": { + "Windows": { + "r4.xlarge": { + "r": 0, + "s": 50 + }, + "t3.nano": { + "r": 1, + "s": 39 + }, + "t2.micro": { + "r": 2, + "s": 52 + }, + "r5d.metal": { + "r": 3, + "s": 52 + }, + "i3.2xlarge": { + "r": 4, + "s": 46 + } + }, + "Linux": { + "m5a.4xlarge": { + "r": 0, + "s": 67 + }, + "t3.nano": { + "r": 1, + "s": 69 + }, + "g4dn.12xlarge": { + "r": 2, + "s": 70 + }, + "r5d.8xlarge": { + "r": 3, + "s": 80 + }, + "i3.2xlarge": { + "r": 4, + "s": 70 + } + } + }, + "us-west-2": { + "Windows": { + "t3.nano": { 
+ "r": 0, + "s": 37 + }, + "r5a.large": { + "r": 1, + "s": 38 + }, + "m2.xlarge": { + "r": 4, + "s": 64 + }, + "d3.2xlarge": { + "r": 2, + "s": 51 + }, + "r5b.16xlarge": { + "r": 3, + "s": 47 + }, + "c5a.24xlarge": { + "r": 2, + "s": 27 + }, + "z1d.3xlarge": { + "r": 4, + "s": 47 + }, + "c5d.24xlarge": { + "r": 0, + "s": 34 + }, + "r5d.xlarge": { + "r": 0, + "s": 46 + }, + "c5n.9xlarge": { + "r": 1, + "s": 38 + }, + "c5d.18xlarge": { + "r": 0, + "s": 34 + }, + "m5ad.12xlarge": { + "r": 2, + "s": 35 + }, + "c4.2xlarge": { + "r": 0, + "s": 36 + }, + "r5dn.8xlarge": { + "r": 0, + "s": 51 + }, + "r4.8xlarge": { + "r": 0, + "s": 44 + }, + "m3.xlarge": { + "r": 0, + "s": 40 + }, + "m5dn.metal": { + "r": 4, + "s": 45 + }, + "m5d.2xlarge": { + "r": 1, + "s": 39 + }, + "c3.2xlarge": { + "r": 2, + "s": 40 + }, + "r5dn.large": { + "r": 1, + "s": 51 + } + }, + "Linux": { + "x2gd.16xlarge": { + "r": 1, + "s": 70 + }, + "r5a.large": { + "r": 0, + "s": 69 + }, + "a1.2xlarge": { + "r": 4, + "s": 67 + }, + "m2.xlarge": { + "r": 4, + "s": 90 + }, + "d3.2xlarge": { + "r": 2, + "s": 70 + }, + "m3.large": { + "r": 0, + "s": 77 + }, + "c5.metal": { + "r": 0, + "s": 62 + }, + "t3a.nano": { + "r": 0, + "s": 66 + }, + "c5ad.xlarge": { + "r": 0, + "s": 63 + }, + "m6g.2xlarge": { + "r": 0, + "s": 54 + }, + "m6g.metal": { + "r": 0, + "s": 54 + }, + "r3.8xlarge": { + "r": 0, + "s": 81 + }, + "r5b.xlarge": { + "r": 0, + "s": 76 + }, + "t3a.medium": { + "r": 0, + "s": 70 + }, + "c6gd.2xlarge": { + "r": 0, + "s": 56 + }, + "c5ad.8xlarge": { + "r": 1, + "s": 63 + }, + "r5n.large": { + "r": 1, + "s": 76 + }, + "m6gd.2xlarge": { + "r": 2, + "s": 61 + }, + "t2.micro": { + "r": 0, + "s": 70 + }, + "m5a.8xlarge": { + "r": 2, + "s": 60 + }, + "x2gd.8xlarge": { + "r": 1, + "s": 70 + }, + "m5n.2xlarge": { + "r": 0, + "s": 61 + }, + "c5n.4xlarge": { + "r": 4, + "s": 70 + }, + "c5ad.2xlarge": { + "r": 4, + "s": 63 + }, + "m5ad.8xlarge": { + "r": 1, + "s": 67 + }, + "t2.small": { + "r": 0, + "s": 70 + }, + 
"m2.4xlarge": { + "r": 4, + "s": 90 + }, + "r5d.4xlarge": { + "r": 1, + "s": 70 + }, + "r5d.xlarge": { + "r": 0, + "s": 75 + }, + "m5d.12xlarge": { + "r": 1, + "s": 70 + }, + "z1d.large": { + "r": 4, + "s": 70 + }, + "r5b.12xlarge": { + "r": 3, + "s": 76 + }, + "r4.xlarge": { + "r": 0, + "s": 71 + }, + "m6g.xlarge": { + "r": 0, + "s": 54 + }, + "r5b.large": { + "r": 0, + "s": 76 + }, + "m5ad.4xlarge": { + "r": 0, + "s": 65 + }, + "c5d.large": { + "r": 0, + "s": 66 + }, + "a1.metal": { + "r": 2, + "s": 67 + }, + "t4g.medium": { + "r": 0, + "s": 70 + }, + "c3.2xlarge": { + "r": 1, + "s": 72 + }, + "g4dn.16xlarge": { + "r": 0, + "s": 70 + }, + "c3.8xlarge": { + "r": 0, + "s": 72 + }, + "r4.2xlarge": { + "r": 0, + "s": 67 + }, + "m5a.2xlarge": { + "r": 0, + "s": 48 + }, + "c5d.metal": { + "r": 3, + "s": 66 + }, + "t4g.xlarge": { + "r": 0, + "s": 70 + }, + "r6gd.large": { + "r": 0, + "s": 68 + }, + "r6gd.8xlarge": { + "r": 4, + "s": 68 + }, + "m5zn.12xlarge": { + "r": 3, + "s": 80 + }, + "r5ad.large": { + "r": 1, + "s": 73 + }, + "m5n.8xlarge": { + "r": 0, + "s": 62 + }, + "m6gd.medium": { + "r": 0, + "s": 61 + }, + "m5dn.xlarge": { + "r": 0, + "s": 63 + }, + "inf1.6xlarge": { + "r": 4, + "s": 70 + }, + "d3en.xlarge": { + "r": 2, + "s": 70 + }, + "p2.16xlarge": { + "r": 4, + "s": 70 + }, + "x2gd.4xlarge": { + "r": 1, + "s": 70 + }, + "z1d.3xlarge": { + "r": 4, + "s": 70 + }, + "c5d.24xlarge": { + "r": 3, + "s": 65 + }, + "r5d.12xlarge": { + "r": 1, + "s": 75 + }, + "g4ad.16xlarge": { + "r": 2, + "s": 70 + }, + "m5d.metal": { + "r": 3, + "s": 70 + }, + "c5ad.large": { + "r": 2, + "s": 63 + }, + "g4dn.xlarge": { + "r": 4, + "s": 70 + }, + "c5.9xlarge": { + "r": 2, + "s": 62 + }, + "r5dn.metal": { + "r": 3, + "s": 79 + }, + "c5ad.12xlarge": { + "r": 0, + "s": 63 + }, + "i2.xlarge": { + "r": 4, + "s": 70 + }, + "r6gd.16xlarge": { + "r": 4, + "s": 68 + }, + "m6gd.xlarge": { + "r": 0, + "s": 61 + }, + "x2gd.large": { + "r": 1, + "s": 70 + }, + "r3.4xlarge": { + "r": 0, + "s": 
80 + }, + "c5a.24xlarge": { + "r": 0, + "s": 58 + }, + "t3a.2xlarge": { + "r": 0, + "s": 70 + }, + "c6g.12xlarge": { + "r": 0, + "s": 50 + }, + "c5n.2xlarge": { + "r": 0, + "s": 69 + }, + "c5a.16xlarge": { + "r": 4, + "s": 58 + }, + "m5dn.12xlarge": { + "r": 3, + "s": 73 + }, + "m1.medium": { + "r": 4, + "s": 90 + }, + "m6g.12xlarge": { + "r": 0, + "s": 54 + }, + "d3en.12xlarge": { + "r": 3, + "s": 70 + }, + "g3.8xlarge": { + "r": 4, + "s": 70 + }, + "m6g.8xlarge": { + "r": 0, + "s": 54 + }, + "m5a.24xlarge": { + "r": 4, + "s": 61 + }, + "z1d.6xlarge": { + "r": 4, + "s": 70 + }, + "r5n.8xlarge": { + "r": 0, + "s": 67 + }, + "m5zn.metal": { + "r": 4, + "s": 80 + }, + "m5dn.8xlarge": { + "r": 1, + "s": 65 + }, + "r5a.8xlarge": { + "r": 1, + "s": 69 + }, + "t3.2xlarge": { + "r": 0, + "s": 70 + }, + "r6gd.metal": { + "r": 0, + "s": 68 + }, + "m5d.xlarge": { + "r": 0, + "s": 70 + }, + "c6g.metal": { + "r": 0, + "s": 50 + }, + "r5dn.8xlarge": { + "r": 0, + "s": 75 + }, + "m5dn.metal": { + "r": 2, + "s": 75 + }, + "m5d.2xlarge": { + "r": 0, + "s": 64 + }, + "c6gd.4xlarge": { + "r": 0, + "s": 56 + }, + "r6g.large": { + "r": 0, + "s": 63 + }, + "t3.nano": { + "r": 0, + "s": 69 + }, + "m5a.4xlarge": { + "r": 1, + "s": 52 + }, + "m4.10xlarge": { + "r": 2, + "s": 61 + }, + "c6gd.8xlarge": { + "r": 3, + "s": 56 + }, + "i3.2xlarge": { + "r": 4, + "s": 70 + }, + "r5d.metal": { + "r": 4, + "s": 75 + }, + "m6gd.4xlarge": { + "r": 0, + "s": 61 + }, + "i2.2xlarge": { + "r": 4, + "s": 70 + }, + "c6gd.medium": { + "r": 0, + "s": 56 + }, + "c6gn.12xlarge": { + "r": 0, + "s": 61 + }, + "i3en.24xlarge": { + "r": 0, + "s": 70 + }, + "x1e.2xlarge": { + "r": 4, + "s": 70 + }, + "i3en.xlarge": { + "r": 0, + "s": 70 + }, + "a1.large": { + "r": 4, + "s": 67 + }, + "c5n.18xlarge": { + "r": 0, + "s": 70 + }, + "g4ad.4xlarge": { + "r": 2, + "s": 70 + }, + "r5n.24xlarge": { + "r": 1, + "s": 76 + }, + "x1e.32xlarge": { + "r": 4, + "s": 70 + }, + "r5ad.xlarge": { + "r": 1, + "s": 73 + }, + 
"c6g.8xlarge": { + "r": 0, + "s": 50 + }, + "c5.4xlarge": { + "r": 1, + "s": 62 + }, + "h1.16xlarge": { + "r": 4, + "s": 70 + }, + "c6gn.large": { + "r": 0, + "s": 61 + }, + "c3.xlarge": { + "r": 1, + "s": 72 + }, + "m5dn.24xlarge": { + "r": 3, + "s": 75 + }, + "m5ad.xlarge": { + "r": 0, + "s": 67 + }, + "r6g.4xlarge": { + "r": 1, + "s": 63 + }, + "c4.large": { + "r": 0, + "s": 69 + }, + "d2.2xlarge": { + "r": 4, + "s": 70 + }, + "t3.large": { + "r": 0, + "s": 70 + }, + "m6gd.16xlarge": { + "r": 0, + "s": 61 + }, + "c6gd.metal": { + "r": 4, + "s": 56 + }, + "r5a.12xlarge": { + "r": 0, + "s": 69 + }, + "m4.16xlarge": { + "r": 2, + "s": 65 + }, + "c6gn.16xlarge": { + "r": 0, + "s": 61 + }, + "m5d.24xlarge": { + "r": 0, + "s": 70 + }, + "x2gd.2xlarge": { + "r": 1, + "s": 70 + }, + "m6g.4xlarge": { + "r": 0, + "s": 54 + }, + "t4g.small": { + "r": 0, + "s": 70 + }, + "t3a.micro": { + "r": 0, + "s": 70 + }, + "x1.32xlarge": { + "r": 4, + "s": 70 + }, + "inf1.2xlarge": { + "r": 4, + "s": 70 + }, + "m4.xlarge": { + "r": 1, + "s": 67 + }, + "c4.8xlarge": { + "r": 3, + "s": 65 + }, + "x1e.16xlarge": { + "r": 4, + "s": 70 + }, + "d2.4xlarge": { + "r": 4, + "s": 70 + }, + "d3en.4xlarge": { + "r": 2, + "s": 70 + }, + "r5n.2xlarge": { + "r": 0, + "s": 68 + }, + "m5n.12xlarge": { + "r": 4, + "s": 65 + }, + "r3.2xlarge": { + "r": 1, + "s": 79 + }, + "r5a.24xlarge": { + "r": 1, + "s": 69 + }, + "a1.medium": { + "r": 4, + "s": 67 + }, + "t4g.nano": { + "r": 0, + "s": 69 + }, + "r5d.2xlarge": { + "r": 1, + "s": 75 + }, + "r5dn.24xlarge": { + "r": 3, + "s": 79 + }, + "r6gd.medium": { + "r": 0, + "s": 68 + }, + "r5ad.4xlarge": { + "r": 0, + "s": 73 + }, + "m5.4xlarge": { + "r": 4, + "s": 60 + }, + "r6g.xlarge": { + "r": 3, + "s": 63 + }, + "h1.8xlarge": { + "r": 4, + "s": 68 + }, + "g4dn.2xlarge": { + "r": 4, + "s": 70 + }, + "m5zn.6xlarge": { + "r": 4, + "s": 80 + }, + "r5b.8xlarge": { + "r": 4, + "s": 76 + }, + "c5.large": { + "r": 0, + "s": 62 + }, + "m5ad.24xlarge": { + "r": 1, + 
"s": 67 + }, + "i3en.6xlarge": { + "r": 0, + "s": 70 + }, + "r5n.12xlarge": { + "r": 4, + "s": 72 + }, + "r5dn.16xlarge": { + "r": 3, + "s": 79 + }, + "m5zn.large": { + "r": 1, + "s": 80 + }, + "c5.xlarge": { + "r": 0, + "s": 62 + }, + "m5ad.2xlarge": { + "r": 0, + "s": 57 + }, + "i3.xlarge": { + "r": 4, + "s": 62 + }, + "g3.4xlarge": { + "r": 4, + "s": 70 + }, + "p2.xlarge": { + "r": 4, + "s": 70 + }, + "m1.xlarge": { + "r": 4, + "s": 90 + }, + "t3a.large": { + "r": 1, + "s": 70 + }, + "r5b.16xlarge": { + "r": 4, + "s": 76 + }, + "r5.16xlarge": { + "r": 3, + "s": 70 + }, + "r5.12xlarge": { + "r": 2, + "s": 72 + }, + "h1.4xlarge": { + "r": 4, + "s": 70 + }, + "c5.18xlarge": { + "r": 0, + "s": 62 + }, + "i2.8xlarge": { + "r": 4, + "s": 70 + }, + "x2gd.metal": { + "r": 1, + "s": 70 + }, + "r6g.16xlarge": { + "r": 0, + "s": 63 + }, + "r6gd.4xlarge": { + "r": 3, + "s": 68 + }, + "r5.8xlarge": { + "r": 1, + "s": 71 + }, + "c5n.metal": { + "r": 1, + "s": 70 + }, + "c6g.xlarge": { + "r": 0, + "s": 50 + }, + "m5zn.2xlarge": { + "r": 0, + "s": 78 + }, + "r5ad.16xlarge": { + "r": 1, + "s": 73 + }, + "p2.8xlarge": { + "r": 4, + "s": 70 + }, + "r4.4xlarge": { + "r": 0, + "s": 68 + }, + "t2.large": { + "r": 0, + "s": 70 + }, + "c5d.xlarge": { + "r": 0, + "s": 66 + }, + "g4ad.8xlarge": { + "r": 2, + "s": 70 + }, + "m5.24xlarge": { + "r": 3, + "s": 65 + }, + "z1d.metal": { + "r": 4, + "s": 70 + }, + "r5b.2xlarge": { + "r": 2, + "s": 71 + }, + "a1.xlarge": { + "r": 4, + "s": 67 + }, + "r5dn.2xlarge": { + "r": 0, + "s": 70 + }, + "m5zn.xlarge": { + "r": 1, + "s": 73 + }, + "r5dn.large": { + "r": 1, + "s": 79 + }, + "i3.4xlarge": { + "r": 4, + "s": 70 + } + } + } + }, + "global_rate": "<5%", + "instance_types": { + "x2gd.16xlarge": { + "ram_gb": 1024, + "cores": 64, + "emr": false + }, + "r5a.large": { + "ram_gb": 16, + "cores": 2, + "emr": false + }, + "a1.2xlarge": { + "ram_gb": 16, + "cores": 8, + "emr": false + }, + "m2.xlarge": { + "ram_gb": 17.1, + "cores": 2, + "emr": true + 
}, + "d3.2xlarge": { + "ram_gb": 64, + "cores": 8, + "emr": false + }, + "m3.large": { + "ram_gb": 7.5, + "cores": 2, + "emr": false + }, + "c5.metal": { + "ram_gb": 192, + "cores": 96, + "emr": false + }, + "t3a.nano": { + "ram_gb": 0.5, + "cores": 2, + "emr": false + }, + "r4.xlarge": { + "ram_gb": 30.5, + "cores": 4, + "emr": true + }, + "m6g.2xlarge": { + "ram_gb": 32, + "cores": 8, + "emr": true + }, + "m6g.metal": { + "ram_gb": 256, + "cores": 64, + "emr": false + }, + "inf1.6xlarge": { + "ram_gb": 48, + "cores": 24, + "emr": false + }, + "r5b.xlarge": { + "ram_gb": 32, + "cores": 4, + "emr": false + }, + "t3a.medium": { + "ram_gb": 4, + "cores": 2, + "emr": false + }, + "c6gd.2xlarge": { + "ram_gb": 16, + "cores": 8, + "emr": false + }, + "c5ad.8xlarge": { + "ram_gb": 64, + "cores": 32, + "emr": false + }, + "r5n.large": { + "ram_gb": 16, + "cores": 2, + "emr": false + }, + "m6gd.2xlarge": { + "ram_gb": 32, + "cores": 8, + "emr": false + }, + "t2.micro": { + "ram_gb": 1, + "cores": 1, + "emr": false + }, + "c5ad.4xlarge": { + "ram_gb": 32, + "cores": 16, + "emr": false + }, + "u-6tb1.56xlarge": { + "ram_gb": 6144, + "cores": 224, + "emr": false + }, + "r5d.8xlarge": { + "ram_gb": 256, + "cores": 32, + "emr": true + }, + "m5n.2xlarge": { + "ram_gb": 32, + "cores": 8, + "emr": false + }, + "c5n.4xlarge": { + "ram_gb": 42, + "cores": 16, + "emr": true + }, + "m5dn.16xlarge": { + "ram_gb": 256, + "cores": 64, + "emr": false + }, + "m5zn.3xlarge": { + "ram_gb": 48, + "cores": 12, + "emr": false + }, + "m5.metal": { + "ram_gb": 384, + "cores": 96, + "emr": false + }, + "c5n.xlarge": { + "ram_gb": 10.5, + "cores": 4, + "emr": true + }, + "m5a.xlarge": { + "ram_gb": 16, + "cores": 4, + "emr": true + }, + "u-9tb1.112xlarge": { + "ram_gb": 9216, + "cores": 448, + "emr": false + }, + "c5ad.16xlarge": { + "ram_gb": 128, + "cores": 64, + "emr": false + }, + "c6gn.2xlarge": { + "ram_gb": 16, + "cores": 8, + "emr": false + }, + "r4.large": { + "ram_gb": 15.25, + "cores": 
2, + "emr": false + }, + "i3en.xlarge": { + "ram_gb": 32, + "cores": 4, + "emr": true + }, + "m5d.16xlarge": { + "ram_gb": 256, + "cores": 64, + "emr": true + }, + "d3.4xlarge": { + "ram_gb": 128, + "cores": 16, + "emr": false + }, + "c5ad.2xlarge": { + "ram_gb": 16, + "cores": 8, + "emr": false + }, + "inf1.xlarge": { + "ram_gb": 8, + "cores": 4, + "emr": false + }, + "m1.large": { + "ram_gb": 7.5, + "cores": 2, + "emr": true + }, + "i3en.large": { + "ram_gb": 16, + "cores": 2, + "emr": false + }, + "r5ad.2xlarge": { + "ram_gb": 64, + "cores": 8, + "emr": false + }, + "r6g.medium": { + "ram_gb": 8, + "cores": 1, + "emr": false + }, + "g3.16xlarge": { + "ram_gb": 488, + "cores": 64, + "emr": true + }, + "c4.xlarge": { + "ram_gb": 7.5, + "cores": 4, + "emr": true + }, + "t2.small": { + "ram_gb": 2, + "cores": 1, + "emr": false + }, + "m2.4xlarge": { + "ram_gb": 68.4, + "cores": 8, + "emr": true + }, + "r5d.4xlarge": { + "ram_gb": 128, + "cores": 16, + "emr": true + }, + "r5n.metal": { + "ram_gb": 768, + "cores": 96, + "emr": false + }, + "m3.medium": { + "ram_gb": 3.75, + "cores": 1, + "emr": false + }, + "z1d.large": { + "ram_gb": 16, + "cores": 2, + "emr": false + }, + "r5b.12xlarge": { + "ram_gb": 384, + "cores": 48, + "emr": false + }, + "c5ad.xlarge": { + "ram_gb": 8, + "cores": 4, + "emr": false + }, + "m6g.xlarge": { + "ram_gb": 16, + "cores": 4, + "emr": true + }, + "r5b.large": { + "ram_gb": 16, + "cores": 2, + "emr": false + }, + "c6gd.medium": { + "ram_gb": 2, + "cores": 1, + "emr": false + }, + "c5d.large": { + "ram_gb": 4, + "cores": 2, + "emr": false + }, + "a1.metal": { + "ram_gb": 32, + "cores": 16, + "emr": false + }, + "c3.2xlarge": { + "ram_gb": 15, + "cores": 8, + "emr": true + }, + "g4dn.16xlarge": { + "ram_gb": 256, + "cores": 64, + "emr": false + }, + "c3.8xlarge": { + "ram_gb": 60, + "cores": 32, + "emr": true + }, + "r4.2xlarge": { + "ram_gb": 61, + "cores": 8, + "emr": true + }, + "m5a.2xlarge": { + "ram_gb": 32, + "cores": 8, + "emr": true 
+ }, + "c5d.metal": { + "ram_gb": 192, + "cores": 96, + "emr": false + }, + "t4g.xlarge": { + "ram_gb": 16, + "cores": 4, + "emr": false + }, + "r6gd.large": { + "ram_gb": 16, + "cores": 2, + "emr": false + }, + "r6gd.8xlarge": { + "ram_gb": 256, + "cores": 32, + "emr": false + }, + "m3.2xlarge": { + "ram_gb": 30, + "cores": 8, + "emr": true + }, + "m5d.large": { + "ram_gb": 8, + "cores": 2, + "emr": false + }, + "c6g.16xlarge": { + "ram_gb": 128, + "cores": 64, + "emr": false + }, + "r3.8xlarge": { + "ram_gb": 244, + "cores": 32, + "emr": true + }, + "c6gd.12xlarge": { + "ram_gb": 96, + "cores": 48, + "emr": false + }, + "r5b.24xlarge": { + "ram_gb": 768, + "cores": 96, + "emr": false + }, + "t4g.micro": { + "ram_gb": 1, + "cores": 2, + "emr": false + }, + "m5zn.12xlarge": { + "ram_gb": 192, + "cores": 48, + "emr": false + }, + "d3en.8xlarge": { + "ram_gb": 128, + "cores": 32, + "emr": false + }, + "g4dn.8xlarge": { + "ram_gb": 128, + "cores": 32, + "emr": false + }, + "c6gd.16xlarge": { + "ram_gb": 128, + "cores": 64, + "emr": false + }, + "x1e.8xlarge": { + "ram_gb": 976, + "cores": 32, + "emr": false + }, + "r5ad.large": { + "ram_gb": 16, + "cores": 2, + "emr": false + }, + "m5n.8xlarge": { + "ram_gb": 128, + "cores": 32, + "emr": false + }, + "r5n.xlarge": { + "ram_gb": 32, + "cores": 4, + "emr": false + }, + "m5dn.xlarge": { + "ram_gb": 16, + "cores": 4, + "emr": false + }, + "d3en.xlarge": { + "ram_gb": 16, + "cores": 4, + "emr": false + }, + "p2.16xlarge": { + "ram_gb": 732, + "cores": 64, + "emr": true + }, + "c6gd.xlarge": { + "ram_gb": 8, + "cores": 4, + "emr": false + }, + "r6g.8xlarge": { + "ram_gb": 256, + "cores": 32, + "emr": true + }, + "r5d.12xlarge": { + "ram_gb": 384, + "cores": 48, + "emr": true + }, + "g4ad.16xlarge": { + "ram_gb": 256, + "cores": 64, + "emr": false + }, + "m5d.metal": { + "ram_gb": 384, + "cores": 96, + "emr": false + }, + "c5ad.large": { + "ram_gb": 4, + "cores": 2, + "emr": false + }, + "g4dn.xlarge": { + "ram_gb": 16, + 
"cores": 4, + "emr": false + }, + "c5.9xlarge": { + "ram_gb": 72, + "cores": 36, + "emr": true + }, + "r5dn.metal": { + "ram_gb": 768, + "cores": 96, + "emr": false + }, + "c5ad.12xlarge": { + "ram_gb": 96, + "cores": 48, + "emr": false + }, + "i2.xlarge": { + "ram_gb": 30.5, + "cores": 4, + "emr": true + }, + "r5dn.large": { + "ram_gb": 16, + "cores": 2, + "emr": false + }, + "r6gd.16xlarge": { + "ram_gb": 512, + "cores": 64, + "emr": false + }, + "m6gd.xlarge": { + "ram_gb": 16, + "cores": 4, + "emr": false + }, + "x2gd.large": { + "ram_gb": 32, + "cores": 2, + "emr": false + }, + "c5a.24xlarge": { + "ram_gb": 192, + "cores": 96, + "emr": false + }, + "i3.16xlarge": { + "ram_gb": 488, + "cores": 64, + "emr": true + }, + "x1e.4xlarge": { + "ram_gb": 488, + "cores": 16, + "emr": false + }, + "t3a.xlarge": { + "ram_gb": 16, + "cores": 4, + "emr": false + }, + "r6g.2xlarge": { + "ram_gb": 64, + "cores": 8, + "emr": false + }, + "r5a.xlarge": { + "ram_gb": 32, + "cores": 4, + "emr": true + }, + "i3en.metal": { + "ram_gb": 768, + "cores": 96, + "emr": false + }, + "g2.2xlarge": { + "ram_gb": 15, + "cores": 8, + "emr": true + }, + "m5d.4xlarge": { + "ram_gb": 64, + "cores": 16, + "emr": true + }, + "d3.8xlarge": { + "ram_gb": 256, + "cores": 32, + "emr": false + }, + "m2.2xlarge": { + "ram_gb": 34.2, + "cores": 4, + "emr": true + }, + "i3en.6xlarge": { + "ram_gb": 192, + "cores": 24, + "emr": true + }, + "t2.2xlarge": { + "ram_gb": 32, + "cores": 8, + "emr": false + }, + "c6gd.large": { + "ram_gb": 4, + "cores": 2, + "emr": false + }, + "c5.4xlarge": { + "ram_gb": 32, + "cores": 16, + "emr": true + }, + "m6gd.8xlarge": { + "ram_gb": 128, + "cores": 32, + "emr": false + }, + "m5n.4xlarge": { + "ram_gb": 64, + "cores": 16, + "emr": false + }, + "d2.8xlarge": { + "ram_gb": 244, + "cores": 36, + "emr": true + }, + "z1d.xlarge": { + "ram_gb": 32, + "cores": 4, + "emr": false + }, + "m5dn.2xlarge": { + "ram_gb": 32, + "cores": 8, + "emr": false + }, + "r5n.24xlarge": { + 
"ram_gb": 768, + "cores": 96, + "emr": false + }, + "r5ad.12xlarge": { + "ram_gb": 384, + "cores": 48, + "emr": false + }, + "m3.xlarge": { + "ram_gb": 15, + "cores": 4, + "emr": true + }, + "m5n.24xlarge": { + "ram_gb": 384, + "cores": 96, + "emr": false + }, + "p4d.24xlarge": { + "ram_gb": 1152, + "cores": 96, + "emr": false + }, + "t2.nano": { + "ram_gb": 0.5, + "cores": 1, + "emr": false + }, + "c5ad.24xlarge": { + "ram_gb": 192, + "cores": 96, + "emr": false + }, + "c3.large": { + "ram_gb": 3.75, + "cores": 2, + "emr": false + }, + "m6gd.medium": { + "ram_gb": 4, + "cores": 1, + "emr": false + }, + "t3a.2xlarge": { + "ram_gb": 32, + "cores": 8, + "emr": false + }, + "c6g.12xlarge": { + "ram_gb": 96, + "cores": 48, + "emr": false + }, + "c5n.2xlarge": { + "ram_gb": 21, + "cores": 8, + "emr": true + }, + "c5a.16xlarge": { + "ram_gb": 128, + "cores": 64, + "emr": false + }, + "f1.2xlarge": { + "ram_gb": 122, + "cores": 8, + "emr": false + }, + "d3.xlarge": { + "ram_gb": 32, + "cores": 4, + "emr": false + }, + "r5b.metal": { + "ram_gb": 768, + "cores": 96, + "emr": false + }, + "i3.large": { + "ram_gb": 15.25, + "cores": 2, + "emr": true + }, + "m5n.16xlarge": { + "ram_gb": 256, + "cores": 64, + "emr": false + }, + "m5dn.12xlarge": { + "ram_gb": 192, + "cores": 48, + "emr": false + }, + "m1.medium": { + "ram_gb": 3.75, + "cores": 1, + "emr": true + }, + "m6g.12xlarge": { + "ram_gb": 192, + "cores": 48, + "emr": true + }, + "d3en.12xlarge": { + "ram_gb": 192, + "cores": 48, + "emr": false + }, + "g3.8xlarge": { + "ram_gb": 244, + "cores": 32, + "emr": true + }, + "m6g.8xlarge": { + "ram_gb": 128, + "cores": 32, + "emr": true + }, + "m5a.24xlarge": { + "ram_gb": 384, + "cores": 96, + "emr": true + }, + "z1d.6xlarge": { + "ram_gb": 192, + "cores": 24, + "emr": false + }, + "r5n.8xlarge": { + "ram_gb": 256, + "cores": 32, + "emr": false + }, + "m5zn.metal": { + "ram_gb": 192, + "cores": 48, + "emr": false + }, + "t4g.medium": { + "ram_gb": 4, + "cores": 2, + "emr": 
false + }, + "u-12tb1.112xlarge": { + "ram_gb": 12288, + "cores": 448, + "emr": false + }, + "r5a.8xlarge": { + "ram_gb": 256, + "cores": 32, + "emr": true + }, + "t3.2xlarge": { + "ram_gb": 32, + "cores": 8, + "emr": false + }, + "r6gd.metal": { + "ram_gb": 512, + "cores": 64, + "emr": false + }, + "m5d.xlarge": { + "ram_gb": 16, + "cores": 4, + "emr": true + }, + "m5.12xlarge": { + "ram_gb": 192, + "cores": 48, + "emr": true + }, + "r5.xlarge": { + "ram_gb": 32, + "cores": 4, + "emr": true + }, + "z1d.3xlarge": { + "ram_gb": 96, + "cores": 12, + "emr": false + }, + "c5n.large": { + "ram_gb": 5.25, + "cores": 2, + "emr": false + }, + "m5dn.8xlarge": { + "ram_gb": 128, + "cores": 32, + "emr": false + }, + "m5dn.large": { + "ram_gb": 8, + "cores": 2, + "emr": false + }, + "c6g.large": { + "ram_gb": 4, + "cores": 2, + "emr": false + }, + "m1.xlarge": { + "ram_gb": 15, + "cores": 4, + "emr": true + }, + "m5.2xlarge": { + "ram_gb": 32, + "cores": 8, + "emr": true + }, + "r4.16xlarge": { + "ram_gb": 488, + "cores": 64, + "emr": true + }, + "z1d.12xlarge": { + "ram_gb": 384, + "cores": 48, + "emr": false + }, + "c6g.4xlarge": { + "ram_gb": 32, + "cores": 16, + "emr": false + }, + "t1.micro": { + "ram_gb": 0.61, + "cores": 1, + "emr": false + }, + "u-6tb1.112xlarge": { + "ram_gb": 6144, + "cores": 448, + "emr": false + }, + "c1.medium": { + "ram_gb": 1.7, + "cores": 2, + "emr": true + }, + "f1.4xlarge": { + "ram_gb": 244, + "cores": 16, + "emr": false + }, + "t3.micro": { + "ram_gb": 1, + "cores": 2, + "emr": false + }, + "x2gd.12xlarge": { + "ram_gb": 768, + "cores": 48, + "emr": false + }, + "g4dn.metal": { + "ram_gb": 384, + "cores": 96, + "emr": false + }, + "c6g.2xlarge": { + "ram_gb": 16, + "cores": 8, + "emr": false + }, + "x2gd.xlarge": { + "ram_gb": 64, + "cores": 4, + "emr": false + }, + "m5dn.4xlarge": { + "ram_gb": 64, + "cores": 16, + "emr": false + }, + "x1.16xlarge": { + "ram_gb": 976, + "cores": 64, + "emr": false + }, + "t3.medium": { + "ram_gb": 4, + 
"cores": 2, + "emr": false + }, + "m6g.16xlarge": { + "ram_gb": 256, + "cores": 64, + "emr": true + }, + "r3.large": { + "ram_gb": 15.25, + "cores": 2, + "emr": false + }, + "r5d.xlarge": { + "ram_gb": 32, + "cores": 4, + "emr": true + }, + "c5.12xlarge": { + "ram_gb": 96, + "cores": 48, + "emr": true + }, + "r6gd.12xlarge": { + "ram_gb": 384, + "cores": 48, + "emr": false + }, + "c5d.12xlarge": { + "ram_gb": 96, + "cores": 48, + "emr": true + }, + "i3en.3xlarge": { + "ram_gb": 96, + "cores": 12, + "emr": true + }, + "r4.8xlarge": { + "ram_gb": 244, + "cores": 32, + "emr": true + }, + "t4g.2xlarge": { + "ram_gb": 32, + "cores": 8, + "emr": false + }, + "g4dn.12xlarge": { + "ram_gb": 192, + "cores": 48, + "emr": false + }, + "m5.8xlarge": { + "ram_gb": 128, + "cores": 32, + "emr": false + }, + "c6gn.large": { + "ram_gb": 4, + "cores": 2, + "emr": false + }, + "c5.18xlarge": { + "ram_gb": 144, + "cores": 72, + "emr": true + }, + "m5a.16xlarge": { + "ram_gb": 256, + "cores": 64, + "emr": true + }, + "r5ad.24xlarge": { + "ram_gb": 768, + "cores": 96, + "emr": false + }, + "i3.metal": { + "ram_gb": 512, + "cores": 64, + "emr": false + }, + "c5a.12xlarge": { + "ram_gb": 96, + "cores": 48, + "emr": false + }, + "p3.16xlarge": { + "ram_gb": 488, + "cores": 64, + "emr": true + }, + "r5.12xlarge": { + "ram_gb": 384, + "cores": 48, + "emr": true + }, + "t4g.large": { + "ram_gb": 8, + "cores": 2, + "emr": false + }, + "m4.2xlarge": { + "ram_gb": 32, + "cores": 8, + "emr": true + }, + "c6g.metal": { + "ram_gb": 128, + "cores": 64, + "emr": false + }, + "r5dn.8xlarge": { + "ram_gb": 256, + "cores": 32, + "emr": false + }, + "m5dn.metal": { + "ram_gb": 384, + "cores": 96, + "emr": false + }, + "c6gd.4xlarge": { + "ram_gb": 32, + "cores": 16, + "emr": false + }, + "m5ad.8xlarge": { + "ram_gb": 128, + "cores": 32, + "emr": false + }, + "t3.nano": { + "ram_gb": 0.5, + "cores": 2, + "emr": false + }, + "m5a.4xlarge": { + "ram_gb": 64, + "cores": 16, + "emr": true + }, + 
"m4.10xlarge": { + "ram_gb": 160, + "cores": 40, + "emr": true + }, + "c6gd.8xlarge": { + "ram_gb": 64, + "cores": 32, + "emr": false + }, + "i3.2xlarge": { + "ram_gb": 61, + "cores": 8, + "emr": true + }, + "r5d.metal": { + "ram_gb": 768, + "cores": 96, + "emr": false + }, + "m6gd.4xlarge": { + "ram_gb": 64, + "cores": 16, + "emr": false + }, + "i2.2xlarge": { + "ram_gb": 61, + "cores": 8, + "emr": true + }, + "m5ad.4xlarge": { + "ram_gb": 64, + "cores": 16, + "emr": false + }, + "c6gn.12xlarge": { + "ram_gb": 96, + "cores": 48, + "emr": false + }, + "i3en.24xlarge": { + "ram_gb": 768, + "cores": 96, + "emr": true + }, + "x1e.2xlarge": { + "ram_gb": 244, + "cores": 8, + "emr": false + }, + "r6g.metal": { + "ram_gb": 512, + "cores": 64, + "emr": false + }, + "a1.large": { + "ram_gb": 4, + "cores": 2, + "emr": false + }, + "c5n.18xlarge": { + "ram_gb": 192, + "cores": 72, + "emr": true + }, + "g4ad.4xlarge": { + "ram_gb": 64, + "cores": 16, + "emr": false + }, + "d2.xlarge": { + "ram_gb": 30.5, + "cores": 4, + "emr": true + }, + "x1e.32xlarge": { + "ram_gb": 3904, + "cores": 128, + "emr": false + }, + "r5ad.xlarge": { + "ram_gb": 32, + "cores": 4, + "emr": false + }, + "c6g.8xlarge": { + "ram_gb": 64, + "cores": 32, + "emr": false + }, + "m4.large": { + "ram_gb": 8, + "cores": 2, + "emr": true + }, + "m5a.8xlarge": { + "ram_gb": 128, + "cores": 32, + "emr": false + }, + "h1.16xlarge": { + "ram_gb": 256, + "cores": 64, + "emr": true + }, + "c3.xlarge": { + "ram_gb": 7.5, + "cores": 4, + "emr": true + }, + "m5dn.24xlarge": { + "ram_gb": 384, + "cores": 96, + "emr": false + }, + "m5ad.xlarge": { + "ram_gb": 16, + "cores": 4, + "emr": false + }, + "r6g.4xlarge": { + "ram_gb": 128, + "cores": 16, + "emr": false + }, + "c4.large": { + "ram_gb": 3.75, + "cores": 2, + "emr": true + }, + "d2.2xlarge": { + "ram_gb": 61, + "cores": 8, + "emr": true + }, + "r5ad.16xlarge": { + "ram_gb": 512, + "cores": 64, + "emr": false + }, + "m6gd.16xlarge": { + "ram_gb": 256, + "cores": 64, 
+ "emr": false + }, + "c6gd.metal": { + "ram_gb": 128, + "cores": 64, + "emr": false + }, + "x2gd.4xlarge": { + "ram_gb": 256, + "cores": 16, + "emr": false + }, + "r5a.12xlarge": { + "ram_gb": 384, + "cores": 48, + "emr": true + }, + "m4.16xlarge": { + "ram_gb": 256, + "cores": 64, + "emr": true + }, + "c6gn.16xlarge": { + "ram_gb": 128, + "cores": 64, + "emr": false + }, + "m5d.24xlarge": { + "ram_gb": 384, + "cores": 96, + "emr": true + }, + "x2gd.2xlarge": { + "ram_gb": 128, + "cores": 8, + "emr": false + }, + "m6g.4xlarge": { + "ram_gb": 64, + "cores": 16, + "emr": true + }, + "t4g.small": { + "ram_gb": 2, + "cores": 2, + "emr": false + }, + "t3a.micro": { + "ram_gb": 1, + "cores": 2, + "emr": false + }, + "c5.24xlarge": { + "ram_gb": 192, + "cores": 96, + "emr": false + }, + "t3a.small": { + "ram_gb": 2, + "cores": 2, + "emr": false + }, + "c5d.2xlarge": { + "ram_gb": 16, + "cores": 8, + "emr": true + }, + "d3en.2xlarge": { + "ram_gb": 32, + "cores": 8, + "emr": false + }, + "g3s.xlarge": { + "ram_gb": 30.5, + "cores": 4, + "emr": true + }, + "c5a.large": { + "ram_gb": 4, + "cores": 2, + "emr": false + }, + "c6gn.xlarge": { + "ram_gb": 8, + "cores": 4, + "emr": false + }, + "c5d.4xlarge": { + "ram_gb": 32, + "cores": 16, + "emr": true + }, + "p3dn.24xlarge": { + "ram_gb": 768, + "cores": 96, + "emr": false + }, + "c5a.2xlarge": { + "ram_gb": 16, + "cores": 8, + "emr": false + }, + "r5dn.4xlarge": { + "ram_gb": 128, + "cores": 16, + "emr": false + }, + "m1.small": { + "ram_gb": 1.7, + "cores": 1, + "emr": true + }, + "x1.32xlarge": { + "ram_gb": 1952, + "cores": 128, + "emr": false + }, + "inf1.2xlarge": { + "ram_gb": 16, + "cores": 8, + "emr": false + }, + "m4.xlarge": { + "ram_gb": 16, + "cores": 4, + "emr": true + }, + "c4.8xlarge": { + "ram_gb": 60, + "cores": 36, + "emr": true + }, + "x1e.16xlarge": { + "ram_gb": 1952, + "cores": 64, + "emr": false + }, + "d2.4xlarge": { + "ram_gb": 122, + "cores": 16, + "emr": true + }, + "d3en.4xlarge": { + "ram_gb": 
64, + "cores": 16, + "emr": false + }, + "z1d.2xlarge": { + "ram_gb": 64, + "cores": 8, + "emr": false + }, + "m5n.12xlarge": { + "ram_gb": 192, + "cores": 48, + "emr": false + }, + "r3.2xlarge": { + "ram_gb": 61, + "cores": 8, + "emr": true + }, + "r5a.24xlarge": { + "ram_gb": 768, + "cores": 96, + "emr": true + }, + "x2gd.8xlarge": { + "ram_gb": 512, + "cores": 32, + "emr": false + }, + "a1.medium": { + "ram_gb": 2, + "cores": 1, + "emr": false + }, + "t4g.nano": { + "ram_gb": 0.5, + "cores": 2, + "emr": false + }, + "r5d.2xlarge": { + "ram_gb": 64, + "cores": 8, + "emr": true + }, + "r5dn.24xlarge": { + "ram_gb": 768, + "cores": 96, + "emr": false + }, + "r3.4xlarge": { + "ram_gb": 122, + "cores": 16, + "emr": true + }, + "r5ad.4xlarge": { + "ram_gb": 128, + "cores": 16, + "emr": false + }, + "m5.4xlarge": { + "ram_gb": 64, + "cores": 16, + "emr": true + }, + "r6g.xlarge": { + "ram_gb": 32, + "cores": 4, + "emr": true + }, + "h1.8xlarge": { + "ram_gb": 128, + "cores": 32, + "emr": true + }, + "hs1.8xlarge": { + "ram_gb": 117, + "cores": 16, + "emr": false + }, + "m5zn.6xlarge": { + "ram_gb": 96, + "cores": 24, + "emr": false + }, + "cr1.8xlarge": { + "ram_gb": 244, + "cores": 32, + "emr": false + }, + "c5a.4xlarge": { + "ram_gb": 32, + "cores": 16, + "emr": false + }, + "c5.large": { + "ram_gb": 4, + "cores": 2, + "emr": false + }, + "m5ad.24xlarge": { + "ram_gb": 384, + "cores": 96, + "emr": false + }, + "r3.xlarge": { + "ram_gb": 30.5, + "cores": 4, + "emr": true + }, + "r5n.12xlarge": { + "ram_gb": 384, + "cores": 48, + "emr": false + }, + "m5zn.large": { + "ram_gb": 8, + "cores": 2, + "emr": false + }, + "c5.xlarge": { + "ram_gb": 8, + "cores": 4, + "emr": true + }, + "m5ad.2xlarge": { + "ram_gb": 32, + "cores": 8, + "emr": false + }, + "i3.xlarge": { + "ram_gb": 30.5, + "cores": 4, + "emr": true + }, + "g3.4xlarge": { + "ram_gb": 122, + "cores": 16, + "emr": true + }, + "c5.2xlarge": { + "ram_gb": 16, + "cores": 8, + "emr": true + }, + "i3en.2xlarge": { + 
"ram_gb": 64, + "cores": 8, + "emr": true + }, + "c6gn.4xlarge": { + "ram_gb": 32, + "cores": 16, + "emr": false + }, + "c1.xlarge": { + "ram_gb": 7, + "cores": 8, + "emr": true + }, + "m5d.12xlarge": { + "ram_gb": 192, + "cores": 48, + "emr": true + }, + "r5ad.8xlarge": { + "ram_gb": 256, + "cores": 32, + "emr": false + }, + "m5n.metal": { + "ram_gb": 384, + "cores": 96, + "emr": false + }, + "h1.2xlarge": { + "ram_gb": 32, + "cores": 8, + "emr": true + }, + "h1.4xlarge": { + "ram_gb": 64, + "cores": 16, + "emr": true + }, + "c4.2xlarge": { + "ram_gb": 15, + "cores": 8, + "emr": true + }, + "z1d.metal": { + "ram_gb": 384, + "cores": 48, + "emr": false + }, + "g2.8xlarge": { + "ram_gb": 60, + "cores": 32, + "emr": false + }, + "r5d.large": { + "ram_gb": 16, + "cores": 2, + "emr": false + }, + "m5a.large": { + "ram_gb": 8, + "cores": 2, + "emr": false + }, + "i3en.12xlarge": { + "ram_gb": 384, + "cores": 48, + "emr": true + }, + "m5.16xlarge": { + "ram_gb": 256, + "cores": 64, + "emr": true + }, + "r5d.16xlarge": { + "ram_gb": 512, + "cores": 64, + "emr": true + }, + "m5.xlarge": { + "ram_gb": 16, + "cores": 4, + "emr": true + }, + "m5a.12xlarge": { + "ram_gb": 192, + "cores": 48, + "emr": true + }, + "r5n.4xlarge": { + "ram_gb": 128, + "cores": 16, + "emr": false + }, + "inf1.24xlarge": { + "ram_gb": 192, + "cores": 96, + "emr": false + }, + "r5d.24xlarge": { + "ram_gb": 768, + "cores": 96, + "emr": true + }, + "r5.4xlarge": { + "ram_gb": 128, + "cores": 16, + "emr": true + }, + "r6gd.medium": { + "ram_gb": 8, + "cores": 1, + "emr": false + }, + "x2gd.medium": { + "ram_gb": 16, + "cores": 1, + "emr": false + }, + "r5b.8xlarge": { + "ram_gb": 256, + "cores": 32, + "emr": false + }, + "r5a.2xlarge": { + "ram_gb": 64, + "cores": 8, + "emr": true + }, + "m5d.8xlarge": { + "ram_gb": 128, + "cores": 32, + "emr": false + }, + "r5a.4xlarge": { + "ram_gb": 128, + "cores": 16, + "emr": true + }, + "r5dn.xlarge": { + "ram_gb": 32, + "cores": 4, + "emr": false + }, + 
"c6gn.medium": { + "ram_gb": 2, + "cores": 1, + "emr": false + }, + "m4.4xlarge": { + "ram_gb": 64, + "cores": 16, + "emr": true + }, + "m5ad.large": { + "ram_gb": 8, + "cores": 2, + "emr": false + }, + "c5n.9xlarge": { + "ram_gb": 96, + "cores": 36, + "emr": true + }, + "i3.8xlarge": { + "ram_gb": 244, + "cores": 32, + "emr": true + }, + "c6gn.8xlarge": { + "ram_gb": 64, + "cores": 32, + "emr": false + }, + "c6g.medium": { + "ram_gb": 2, + "cores": 1, + "emr": false + }, + "m5.large": { + "ram_gb": 8, + "cores": 2, + "emr": false + }, + "r5.2xlarge": { + "ram_gb": 64, + "cores": 8, + "emr": true + }, + "g4dn.4xlarge": { + "ram_gb": 64, + "cores": 16, + "emr": false + }, + "t2.xlarge": { + "ram_gb": 16, + "cores": 4, + "emr": false + }, + "c4.4xlarge": { + "ram_gb": 30, + "cores": 16, + "emr": true + }, + "t2.medium": { + "ram_gb": 4, + "cores": 2, + "emr": false + }, + "x1e.xlarge": { + "ram_gb": 122, + "cores": 4, + "emr": false + }, + "m6gd.metal": { + "ram_gb": 256, + "cores": 64, + "emr": false + }, + "r6gd.2xlarge": { + "ram_gb": 64, + "cores": 8, + "emr": false + }, + "t3.xlarge": { + "ram_gb": 16, + "cores": 4, + "emr": false + }, + "c5a.xlarge": { + "ram_gb": 8, + "cores": 4, + "emr": false + }, + "c5a.8xlarge": { + "ram_gb": 64, + "cores": 32, + "emr": false + }, + "d3en.6xlarge": { + "ram_gb": 96, + "cores": 24, + "emr": false + }, + "m5.24xlarge": { + "ram_gb": 384, + "cores": 96, + "emr": true + }, + "r5n.16xlarge": { + "ram_gb": 512, + "cores": 64, + "emr": false + }, + "r5.24xlarge": { + "ram_gb": 768, + "cores": 96, + "emr": true + }, + "m6g.medium": { + "ram_gb": 4, + "cores": 1, + "emr": false + }, + "c3.4xlarge": { + "ram_gb": 30, + "cores": 16, + "emr": true + }, + "m5n.large": { + "ram_gb": 8, + "cores": 2, + "emr": false + }, + "r5b.4xlarge": { + "ram_gb": 128, + "cores": 16, + "emr": false + }, + "c5d.9xlarge": { + "ram_gb": 72, + "cores": 36, + "emr": true + }, + "r5a.16xlarge": { + "ram_gb": 512, + "cores": 64, + "emr": true + }, + 
"m5ad.16xlarge": { + "ram_gb": 256, + "cores": 64, + "emr": false + }, + "r5.large": { + "ram_gb": 16, + "cores": 2, + "emr": false + }, + "m5n.xlarge": { + "ram_gb": 16, + "cores": 4, + "emr": false + }, + "m6g.large": { + "ram_gb": 8, + "cores": 2, + "emr": false + }, + "r5n.2xlarge": { + "ram_gb": 64, + "cores": 8, + "emr": false + }, + "a1.4xlarge": { + "ram_gb": 32, + "cores": 16, + "emr": false + }, + "m5d.2xlarge": { + "ram_gb": 32, + "cores": 8, + "emr": true + }, + "r6g.12xlarge": { + "ram_gb": 384, + "cores": 48, + "emr": true + }, + "p2.xlarge": { + "ram_gb": 61, + "cores": 4, + "emr": true + }, + "c5d.24xlarge": { + "ram_gb": 192, + "cores": 96, + "emr": false + }, + "t3a.large": { + "ram_gb": 8, + "cores": 2, + "emr": false + }, + "r5b.16xlarge": { + "ram_gb": 512, + "cores": 64, + "emr": false + }, + "r5.16xlarge": { + "ram_gb": 512, + "cores": 64, + "emr": true + }, + "g4ad.8xlarge": { + "ram_gb": 128, + "cores": 32, + "emr": false + }, + "g4dn.2xlarge": { + "ram_gb": 32, + "cores": 8, + "emr": false + }, + "c5d.18xlarge": { + "ram_gb": 144, + "cores": 72, + "emr": true + }, + "i2.8xlarge": { + "ram_gb": 244, + "cores": 32, + "emr": true + }, + "x2gd.metal": { + "ram_gb": 1024, + "cores": 64, + "emr": false + }, + "r6g.16xlarge": { + "ram_gb": 512, + "cores": 64, + "emr": true + }, + "r6gd.4xlarge": { + "ram_gb": 128, + "cores": 16, + "emr": false + }, + "r5.8xlarge": { + "ram_gb": 256, + "cores": 32, + "emr": true + }, + "c5n.metal": { + "ram_gb": 192, + "cores": 72, + "emr": false + }, + "p3.8xlarge": { + "ram_gb": 244, + "cores": 32, + "emr": true + }, + "c6g.xlarge": { + "ram_gb": 8, + "cores": 4, + "emr": false + }, + "m5zn.2xlarge": { + "ram_gb": 32, + "cores": 8, + "emr": false + }, + "t3.large": { + "ram_gb": 8, + "cores": 2, + "emr": false + }, + "p2.8xlarge": { + "ram_gb": 488, + "cores": 32, + "emr": true + }, + "r4.4xlarge": { + "ram_gb": 122, + "cores": 16, + "emr": true + }, + "t2.large": { + "ram_gb": 8, + "cores": 2, + "emr": false + 
}, + "c5d.xlarge": { + "ram_gb": 8, + "cores": 4, + "emr": true + }, + "r5dn.16xlarge": { + "ram_gb": 512, + "cores": 64, + "emr": false + }, + "m6gd.large": { + "ram_gb": 8, + "cores": 2, + "emr": false + }, + "r5b.2xlarge": { + "ram_gb": 64, + "cores": 8, + "emr": false + }, + "a1.xlarge": { + "ram_gb": 8, + "cores": 4, + "emr": false + }, + "m6gd.12xlarge": { + "ram_gb": 192, + "cores": 48, + "emr": false + }, + "p3.2xlarge": { + "ram_gb": 61, + "cores": 8, + "emr": true + }, + "r5.metal": { + "ram_gb": 768, + "cores": 96, + "emr": false + }, + "m5ad.12xlarge": { + "ram_gb": 192, + "cores": 48, + "emr": false + }, + "i2.4xlarge": { + "ram_gb": 122, + "cores": 16, + "emr": true + }, + "r6gd.xlarge": { + "ram_gb": 32, + "cores": 4, + "emr": false + }, + "r5dn.12xlarge": { + "ram_gb": 384, + "cores": 48, + "emr": false + }, + "t3.small": { + "ram_gb": 2, + "cores": 2, + "emr": false + }, + "f1.16xlarge": { + "ram_gb": 976, + "cores": 64, + "emr": false + }, + "r5dn.2xlarge": { + "ram_gb": 64, + "cores": 8, + "emr": false + }, + "m5zn.xlarge": { + "ram_gb": 16, + "cores": 4, + "emr": false + }, + "r6g.large": { + "ram_gb": 16, + "cores": 2, + "emr": false + }, + "i3.4xlarge": { + "ram_gb": 122, + "cores": 16, + "emr": true + } + }, + "ranges": [ + { + "label": "<5%", + "index": 0, + "max": 5, + "dots": 0 + }, + { + "label": "5-10%", + "index": 1, + "max": 11, + "dots": 1 + }, + { + "label": "10-15%", + "index": 2, + "max": 16, + "dots": 2 + }, + { + "label": "15-20%", + "index": 3, + "max": 22, + "dots": 3 + }, + { + "label": ">20%", + "index": 4, + "max": 100, + "dots": 4 + } + ] +} diff --git a/cloud/spotfeed/parser.go b/cloud/spotfeed/parser.go new file mode 100644 index 00000000..bf128df6 --- /dev/null +++ b/cloud/spotfeed/parser.go @@ -0,0 +1,248 @@ +package spotfeed + +import ( + "bufio" + "fmt" + "io" + "regexp" + "strconv" + "strings" + "time" + + "github.com/grailbio/base/errors" +) + +const ( + feedFileTimestampFormat = "2006-01-02-15" +) + +var ( + 
feedFileNamePattern = regexp.MustCompile(`^[0-9]{12}\.[0-9]{4}(\-[0-9]{2}){3}\.[0-9]{3}.[a-z0-9]{8}(\.gz)?$`) +) + +type fileMeta struct { + filterable + + Name string + AccountId string + Timestamp time.Time + Version int64 + IsGzip bool +} + +func (f *fileMeta) accountId() string { + return f.AccountId +} + +func (f *fileMeta) timestamp() time.Time { + return f.Timestamp +} + +func (f *fileMeta) version() int64 { + return f.Version +} + +func parseFeedFileName(name string) (*fileMeta, error) { + if !feedFileNamePattern.MatchString(name) { + return nil, fmt.Errorf("%s does not match feed fileMeta pattern, skipping", name) + } + + fields := strings.Split(name, ".") + var isGzip bool + switch len(fields) { + case 4: + isGzip = false + case 5: + if fields[4] == "gz" { + isGzip = true + } else { + return nil, fmt.Errorf("failed to parse fileMeta name in data feed directory: %s", name) + } + default: + return nil, fmt.Errorf("failed to parse fileMeta name in data feed directory: %s", name) + } + + timestamp, err := time.Parse(feedFileTimestampFormat, fields[1]) + if err != nil { + return nil, errors.E(err, fmt.Sprintf("failed to parse timestamp for name %s", name)) + } + + version, err := strconv.ParseInt(fields[2], 10, 64) + if err != nil { + return nil, errors.E(err, fmt.Sprintf("failed to parse version for name %s", name)) + } + + return &fileMeta{ + Name: name, + AccountId: fields[0], + Timestamp: timestamp, + Version: version, + IsGzip: isGzip, + }, nil +} + +// Entry corresponds to a single line in a Spot Instance data feed file. The +// Spot Instance data feed files are tab-delimited. Each line in the data file +// corresponds to one instance hour and contains the fields listed in the +// following table. The AccountId field is not specified for each individual entry +// but is given as a prefix in the name of the spot data feed file. 
type Entry struct {
	filterable

	// AccountId is a 12-digit account number (ID) that specifies the AWS account
	// billed for this spot instance-hour.
	AccountId string

	// Timestamp is used to determine the price charged for this instance usage.
	// It is not at the hour boundary but within the hour specified by the title of
	// the data feed file that contains this Entry.
	Timestamp time.Time

	// UsageType is the type of usage and instance type being charged for. For
	// m1.small Spot Instances, this field is set to SpotUsage. For all other
	// instance types, this field is set to SpotUsage:{instance-type}. For
	// example, SpotUsage:c1.medium.
	UsageType string

	// Instance is the instance type being charged for and is a member of the
	// set of information provided by UsageType.
	Instance string

	// Operation is the product being charged for. For Linux Spot Instances,
	// this field is set to RunInstances. For Windows Spot Instances, this
	// field is set to RunInstances:0002. Spot usage is grouped according
	// to Availability Zone.
	Operation string

	// InstanceID is the ID of the Spot Instance that generated this instance
	// usage.
	InstanceID string

	// MyBidID is the ID for the Spot Instance request that generated this instance usage.
	MyBidID string

	// MyMaxPriceUSD is the maximum price specified for this Spot Instance request.
	MyMaxPriceUSD float64

	// MarketPriceUSD is the Spot price at the time specified in the Timestamp field.
	MarketPriceUSD float64

	// ChargeUSD is the price charged for this instance usage.
	ChargeUSD float64

	// Version is the version included in the data feed file name for this record.
	Version int64
}

// accountId returns the AWS account billed for this entry.
func (e *Entry) accountId() string {
	return e.AccountId
}

// timestamp returns the usage timestamp of this entry.
func (e *Entry) timestamp() time.Time {
	return e.Timestamp
}

// version returns the feed file version this entry was read from.
func (e *Entry) version() int64 {
	return e.Version
}

// parsePriceUSD parses a price in USD formatted like "6.669 USD".
func parsePriceUSD(priceField string) (float64, error) {
	// Prices arrive as "<number> USD"; strip the fixed currency suffix and
	// parse the remainder as a float.
	const currencySuffix = " USD"
	if !strings.HasSuffix(priceField, currencySuffix) {
		return 0, fmt.Errorf("failed to trim currency from %s", priceField)
	}
	return strconv.ParseFloat(priceField[:len(priceField)-len(currencySuffix)], 64)
}

// parseUsageType parses the EC2 instance type from the spot data feed column UsageType, as per the AWS documentation.
// For m1.small Spot Instances, this field is set to SpotUsage. For all other instance types, this field is set to
// SpotUsage:{instance-type}. For example, SpotUsage:c1.medium.
func parseUsageType(usageType string) (string, error) {
	fields := strings.Split(usageType, ":")
	switch len(fields) {
	case 1:
		// No ":{instance-type}" suffix: AWS uses the bare form only for m1.small.
		return "m1.small", nil
	case 2:
		return fields[1], nil
	default:
		return "", fmt.Errorf("failed to parse instance from UsageType %s", usageType)
	}
}

const (
	feedLineTimestampFormat = "2006-01-02 15:04:05 MST"
)

// parseFeedLine parses an *Entry from a line in a spot data feed file.
The content and ordering of the columns +// in this file are documented at https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/spot-data-feeds.html +func parseFeedLine(line string, accountId string) (*Entry, error) { + fields := strings.Split(line, "\t") + if len(fields) != 9 { + return nil, fmt.Errorf("failed to parse line in data feed: %s", line) + } + + timestamp, err := time.Parse(feedLineTimestampFormat, fields[0]) + if err != nil { + return nil, errors.E(err, fmt.Sprintf("failed to parse timestamp for line %s", line)) + } + + instance, err := parseUsageType(fields[1]) + if err != nil { + return nil, errors.E(err, fmt.Sprintf("failed to parse usage type for line %s", line)) + } + + myMaxPriceUSD, err := parsePriceUSD(fields[5]) + if err != nil { + return nil, errors.E(err, fmt.Sprintf("failed to parse my max price for line %s", line)) + } + + marketPriceUSD, err := parsePriceUSD(fields[6]) + if err != nil { + return nil, errors.E(err, fmt.Sprintf("failed to parse market price for line %s", line)) + } + + chargeUSD, err := parsePriceUSD(fields[7]) + if err != nil { + return nil, errors.E(err, fmt.Sprintf("failed to parse charge for line %s", line)) + } + + version, err := strconv.ParseInt(fields[8], 10, 64) + if err != nil { + return nil, errors.E(err, fmt.Sprintf("failed to parse version for line %s", line)) + } + + return &Entry{ + AccountId: accountId, + Timestamp: timestamp, + UsageType: fields[1], + Instance: instance, + Operation: fields[2], + InstanceID: fields[3], + MyBidID: fields[4], + MyMaxPriceUSD: myMaxPriceUSD, + MarketPriceUSD: marketPriceUSD, + ChargeUSD: chargeUSD, + Version: version, + }, nil +} + +func ParseFeedFile(feed io.Reader, accountId string) ([]*Entry, error) { + scn := bufio.NewScanner(feed) + + entries := make([]*Entry, 0) + for scn.Scan() { + line := scn.Text() + if strings.HasPrefix(line, "#") { + continue + } + + entry, err := parseFeedLine(scn.Text(), accountId) + if err != nil { + return nil, errors.E(err, "") + } + + entries = 
append(entries, entry) + } + + return entries, nil +} diff --git a/cloud/spotfeed/parser_test.go b/cloud/spotfeed/parser_test.go new file mode 100644 index 00000000..d8f34a4c --- /dev/null +++ b/cloud/spotfeed/parser_test.go @@ -0,0 +1,190 @@ +package spotfeed + +import ( + "strings" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestParseFeedFileSuccess(t *testing.T) { + for _, test := range []struct { + name string + feedBlob string + expectedEntries []*Entry + }{ + { + name: "empty", + feedBlob: `#Version: 1.0 +#Fields: Timestamp UsageType Operation InstanceID MyBidID MyMaxPrice MarketPrice Charge Version`, + expectedEntries: []*Entry{}, + }, + { + name: "one line", + feedBlob: `#Version: 1.0 +#Fields: Timestamp UsageType Operation InstanceID MyBidID MyMaxPrice MarketPrice Charge Version +2020-10-29 19:35:44 UTC USW2-SpotUsage:c4.large RunInstances:SV002 i-00028539b6b1de1c0 sir-yzyi9qrm 0.100 USD 0.034 USD 0.001 USD 1`, + expectedEntries: []*Entry{ + { + AccountId: testAccountId, + Timestamp: time.Date(2020, 10, 29, 19, 35, 44, 0, time.UTC), + UsageType: "USW2-SpotUsage:c4.large", + Instance: "c4.large", + Operation: "RunInstances:SV002", + InstanceID: "i-00028539b6b1de1c0", + MyBidID: "sir-yzyi9qrm", + MyMaxPriceUSD: 0.1, + MarketPriceUSD: 0.034, + ChargeUSD: 0.001, + Version: 1, + }, + }, + }, + { + name: "m1.small", + feedBlob: `#Version: 1.0 +#Fields: Timestamp UsageType Operation InstanceID MyBidID MyMaxPrice MarketPrice Charge Version +2020-10-29 19:38:33 UTC USW2-SpotUsage RunInstances:SV002 i-0c5a748cea172ba6b sir-7nigaazp 26.688 USD 8.006 USD 4.546 USD 1`, + expectedEntries: []*Entry{ + { + AccountId: testAccountId, + Timestamp: time.Date(2020, 10, 29, 19, 38, 33, 0, time.UTC), + UsageType: "USW2-SpotUsage", + Instance: "m1.small", + Operation: "RunInstances:SV002", + InstanceID: "i-0c5a748cea172ba6b", + MyBidID: "sir-7nigaazp", + MyMaxPriceUSD: 26.688, + MarketPriceUSD: 8.006, + ChargeUSD: 4.546, + Version: 1, + }, + }, + }, 
+ { + name: "multiple", + feedBlob: `#Version: 1.0 +#Fields: Timestamp UsageType Operation InstanceID MyBidID MyMaxPrice MarketPrice Charge Version +2020-10-29 19:35:44 UTC USW2-SpotUsage:c4.large RunInstances:SV002 i-00028539b6b1de1c0 sir-yzyi9qrm 0.100 USD 0.034 USD 0.001 USD 1 +2020-10-29 19:35:05 UTC USW2-SpotUsage:c4.large RunInstances:SV002 i-0003301e05abdf3c1 sir-5d78aggq 0.100 USD 0.034 USD 0.002 USD 1 +2020-10-29 19:35:44 UTC USW2-SpotUsage:c4.large RunInstances:SV002 i-0028e565fd6e3d37b sir-9g7ibe6n 0.100 USD 0.034 USD 0.002 USD 1`, + expectedEntries: []*Entry{ + { + AccountId: testAccountId, + Timestamp: time.Date(2020, 10, 29, 19, 35, 44, 0, time.UTC), + UsageType: "USW2-SpotUsage:c4.large", + Instance: "c4.large", + Operation: "RunInstances:SV002", + InstanceID: "i-00028539b6b1de1c0", + MyBidID: "sir-yzyi9qrm", + MyMaxPriceUSD: 0.1, + MarketPriceUSD: 0.034, + ChargeUSD: 0.001, + Version: 1, + }, + { + AccountId: testAccountId, + Timestamp: time.Date(2020, 10, 29, 19, 35, 05, 0, time.UTC), + UsageType: "USW2-SpotUsage:c4.large", + Instance: "c4.large", + Operation: "RunInstances:SV002", + InstanceID: "i-0003301e05abdf3c1", + MyBidID: "sir-5d78aggq", + MyMaxPriceUSD: 0.1, + MarketPriceUSD: 0.034, + ChargeUSD: 0.002, + Version: 1, + }, + { + AccountId: testAccountId, + Timestamp: time.Date(2020, 10, 29, 19, 35, 44, 0, time.UTC), + UsageType: "USW2-SpotUsage:c4.large", + Instance: "c4.large", + Operation: "RunInstances:SV002", + InstanceID: "i-0028e565fd6e3d37b", + MyBidID: "sir-9g7ibe6n", + MyMaxPriceUSD: 0.1, + MarketPriceUSD: 0.034, + ChargeUSD: 0.002, + Version: 1, + }, + }, + }, + } { + t.Run(test.name, func(t *testing.T) { + entries, err := ParseFeedFile(strings.NewReader(test.feedBlob), testAccountId) + require.NoError(t, err, "failed to parse feed blob") + require.Equal(t, test.expectedEntries, entries) + }) + } +} + +func TestParseFeedFileNameSuccess(t *testing.T) { + for _, test := range []struct { + name string + fileName string + expectedMeta 
*fileMeta + }{ + { + "no_gzip", + testAccountId + ".2021-02-23-04.004.8a6d6bb8", + &fileMeta{ + nil, + testAccountId + ".2021-02-23-04.004.8a6d6bb8", + testAccountId, + time.Date(2021, 02, 23, 04, 0, 0, 0, time.UTC), + 4, + false, + }, + }, + { + "gzip", + testAccountId + ".2021-02-23-04.004.8a6d6bb8.gz", + &fileMeta{ + nil, + testAccountId + ".2021-02-23-04.004.8a6d6bb8.gz", + testAccountId, + time.Date(2021, 02, 23, 04, 0, 0, 0, time.UTC), + 4, + true, + }, + }, + } { + t.Run(test.name, func(t *testing.T) { + meta, err := parseFeedFileName(test.fileName) + require.NoError(t, err, "failed to parse feed file name") + require.Equal(t, test.expectedMeta, meta) + }) + } +} + +func TestParseFeedFileNameError(t *testing.T) { + for _, test := range []struct { + name string + fileName string + expectedErr string + }{ + { + "fail_regex", + testAccountId + ".2021-02-23-04.004.8bb8", + "does not match", + }, + { + "invalid_extension", + testAccountId + ".2021-02-23-04.004.8a6d6bb8.zip", + "does not match", + }, + { + "invalid_date", + testAccountId + ".2021-13-23-04.004.8a6d6bb8.gz", + "failed to parse timestamp", + }, + } { + t.Run(test.name, func(t *testing.T) { + _, err := parseFeedFileName(test.fileName) + require.Error(t, err, "unexpected error while parsing feed file name %s but call to parseFeedFileName succeeded", test.fileName) + require.Contains(t, err.Error(), test.expectedErr) + }) + } +} diff --git a/cloud/spotfeed/querier.go b/cloud/spotfeed/querier.go new file mode 100644 index 00000000..c3bb508a --- /dev/null +++ b/cloud/spotfeed/querier.go @@ -0,0 +1,160 @@ +package spotfeed + +import ( + "context" + "fmt" + "sort" + "time" +) + +// ErrMissingData is the error returned if there is no data for the query time period. +var ErrMissingData = fmt.Errorf("missing data") + +// Period is a time period with a start and end time. +type Period struct { + Start, End time.Time +} + +type Cost struct { + // Period defines the time period for which this cost is applicable. 
+ Period + // ChargeUSD is the total charge over the time period specified by Period. + ChargeUSD float64 +} + +// Querier provides the ability to query for costs. +type Querier interface { + + // Query computes the cost charged for the given instanceId for given time period + // assuming that terminated was the time at which the instance was terminated. + // + // It is not required to specify terminated time. Specifying it only impacts cost + // calculations for a time period that overlaps the last partial hour of the instance's lifetime. + // + // For example, if the instance was running only for say 30m in the last partial hour, and if the + // desired time period overlaps say the first 15m of that hour, then one must specify + // terminated time to compute the cost correctly. In this example, not specifying terminated time + // would result in a cost higher than actual (ie for the entire last 30 mins instead of only 15 mins). + // + // If the given time period spans beyond the instance's actual lifetime, the returned cost will + // yet only reflect the lifetime cost. While the returned cost will have the correct start time, + // the correct end time will be set only if terminated time is provided. + // + // Query will return ErrMissingData if it has no data for the given instanceId or + // if it doesn't have data overlapping the given time period. + Query(instanceId string, p Period, terminated time.Time) (Cost, error) +} + +// NewQuerier fetches data from the given loader and returns a Querier based on the returned data. +func NewQuerier(ctx context.Context, l Loader) (Querier, error) { + entries, err := l.Fetch(ctx, false) + if err != nil { + return nil, err + } + return newQuerier(entries), nil +} + +// querier provides a Querier implementation using static list of entries provided upon initialization. 
+type querier struct { + byInstanceId map[string][]*Entry +} + +func newQuerier(all []*Entry) *querier { + byInstanceId := make(map[string][]*Entry) + for _, entry := range all { + iid := entry.InstanceID + if _, ok := byInstanceId[iid]; !ok { + byInstanceId[iid] = []*Entry{} + } + byInstanceId[iid] = append(byInstanceId[iid], entry) + } + for iid, entries := range byInstanceId { + // For each instance, first sort all the entries + sort.Slice(entries, func(i, j int) bool { + return entries[i].Timestamp.Before(entries[j].Timestamp) + }) + var ( + prev *Entry + iidEntries []*Entry + ) + // There can be multiple entries at the same Timestamp, one for each Version. + // We simply take the entry of the version that has the max cost for the same timestamp. + for _, entry := range entries { + switch { + case prev == nil: + case prev.Timestamp != entry.Timestamp: + iidEntries = append(iidEntries, prev) + case entry.ChargeUSD > prev.ChargeUSD: + default: + continue // Keep prev as-is. + } + prev = entry + } + if prev != nil { + iidEntries = append(iidEntries, prev) + } + byInstanceId[iid] = iidEntries + } + return &querier{byInstanceId: byInstanceId} +} + +// Query implements Querier interface. +func (q *querier) Query(instanceId string, p Period, terminated time.Time) (Cost, error) { + p.Start, p.End = p.Start.Truncate(time.Second), p.End.Truncate(time.Second) + entries, ok := q.byInstanceId[instanceId] + if !ok || len(entries) == 0 { + return Cost{}, ErrMissingData + } + i := sort.Search(len(entries), func(i int) bool { + // This will return an entry after p.Start, even if one exists exactly at p.Start + return p.Start.Before(entries[i].Timestamp) + }) + switch { + case i == len(entries): + // Start is past all entries, so we don't have any data for the given time period. + return Cost{}, ErrMissingData + case i == 0 && entries[i].Timestamp.After(p.End): + // End is before the first entry, so we don't have any data for the given time period. 
+ return Cost{}, ErrMissingData + case i > 0: + // Since we always get the entry after p.Start, we have to move back (if possible) + // to cover the time period starting from p.Start. + i-- + } + var ( + ended bool + cost = Cost{Period: Period{End: p.End}} + prev = entries[i] + ) + if startTs := entries[i].Timestamp; p.Start.After(startTs) { + cost.Start = p.Start + } else { + cost.Start = startTs + } + for i++; !ended && i < len(entries); i++ { + startTs := prev.Timestamp + endTs := entries[i].Timestamp + if p.Start.After(startTs) { + startTs = p.Start + } + if p.End.Before(endTs) { + ended = true + endTs = p.End + } + ratio := endTs.Sub(startTs).Seconds() / entries[i].Timestamp.Sub(prev.Timestamp).Seconds() + cost.ChargeUSD += ratio * prev.ChargeUSD + prev = entries[i] + } + if !ended { + ratio := 1.0 + switch { + case terminated.IsZero(): + case p.End.Before(terminated): + ratio = p.End.Sub(prev.Timestamp).Seconds() / terminated.Sub(prev.Timestamp).Seconds() + default: + cost.End = terminated + } + cost.ChargeUSD += ratio * prev.ChargeUSD + } + return cost, nil +} diff --git a/cloud/spotfeed/querier_test.go b/cloud/spotfeed/querier_test.go new file mode 100644 index 00000000..6d7971c7 --- /dev/null +++ b/cloud/spotfeed/querier_test.go @@ -0,0 +1,153 @@ +package spotfeed + +import ( + "testing" + "time" +) + +func TestQuerier(t *testing.T) { + now := time.Now().Truncate(time.Second) + iid, typ := "some-instance-id", "some-instance-type" + entries := []*Entry{ + {ChargeUSD: 60, Timestamp: now.Add(-60 * time.Minute), InstanceID: iid, Instance: typ}, + {ChargeUSD: 110 /*ignored*/, Timestamp: now, InstanceID: iid, Instance: typ}, + {ChargeUSD: 120, Timestamp: now, InstanceID: iid, Instance: typ}, + {ChargeUSD: 80 /*ignored*/, Timestamp: now.Add(59 * time.Minute), InstanceID: iid, Instance: typ}, + {ChargeUSD: 90, Timestamp: now.Add(59 * time.Minute), InstanceID: iid, Instance: typ}, + {ChargeUSD: 120, Timestamp: now.Add(121 * time.Minute), InstanceID: iid, Instance: 
typ}, + {ChargeUSD: 88 /*ignored*/, Timestamp: now.Add(3 * time.Hour), InstanceID: iid, Instance: typ}, + {ChargeUSD: 89 /*ignored*/, Timestamp: now.Add(3 * time.Hour), InstanceID: iid, Instance: typ}, + {ChargeUSD: 90 /*duplicate*/, Timestamp: now.Add(3 * time.Hour), InstanceID: iid, Instance: typ}, + {ChargeUSD: 90, Timestamp: now.Add(3 * time.Hour), InstanceID: iid, Instance: typ}, + } + terminated := now.Add(3*time.Hour + 30*time.Minute) + q := newQuerier(entries) + _, err := q.Query("some-other-instance-id", Period{}, time.Time{}) + if got, want := err, ErrMissingData; got != want { + t.Errorf("got %v, want %v", got, want) + } + for i, tt := range []struct { + iet time.Time + p Period + c Cost + wantE error + }{ + { // Period starting and ending before data. + time.Time{}, + Period{now.Add(-90 * time.Minute), now.Add(-70 * time.Minute)}, + Cost{}, + ErrMissingData, + }, + { // Period starting and ending after data. + time.Time{}, + Period{now.Add(3*time.Hour + 1*time.Minute), now.Add(4 * time.Hour)}, + Cost{}, + ErrMissingData, + }, + { // Period starting before data but ending within. + time.Time{}, + Period{now.Add(-90 * time.Minute), now.Add(-30 * time.Minute)}, + Cost{ + Period{now.Add(-60 * time.Minute), now.Add(-30 * time.Minute)}, + 60 * 30.0 / 60.0, + }, + nil, + }, + { // Period starting within data and going beyond. + terminated, + Period{now.Add(2 * time.Hour), now.Add(4 * time.Hour)}, + Cost{ + Period{now.Add(2 * time.Hour), terminated}, + 90*1.0/62.0 + 120 + 90, + }, + nil, + }, + { // Period starting within data and going beyond with no terminated + time.Time{}, + Period{now.Add(2 * time.Hour), now.Add(4 * time.Hour)}, + Cost{ + Period{now.Add(2 * time.Hour), now.Add(4 * time.Hour)}, + 90*1.0/62.0 + 120 + 90, + }, + nil, + }, + { // Period starting exactly at some timestamp and ending within its period. 
+ time.Time{}, + Period{now, now.Add(5 * time.Minute)}, + Cost{ + Period{now, now.Add(5 * time.Minute)}, + 120 * 5.0 / 59.0, + }, + nil, + }, + { // Period starting and within a single time period. + time.Time{}, + Period{now.Add(1 * time.Minute), now.Add(6 * time.Minute)}, + Cost{ + Period{now.Add(1 * time.Minute), now.Add(6 * time.Minute)}, + 120 * 5.0 / 59.0, + }, + nil, + }, + { // Period starting exactly at some timestamp and spanning more than one. + time.Time{}, + Period{now, now.Add(80 * time.Minute)}, + Cost{ + Period{now, now.Add(80 * time.Minute)}, + 120 + 90*21.0/62.0, + }, + nil, + }, + { // Period starting before data and ending after. + terminated, + Period{now.Add(-90 * time.Minute), now.Add(6 * time.Hour)}, + Cost{ + Period{now.Add(-60 * time.Minute), terminated}, + 60 + 120 + 90 + 120 + 90, + }, + nil, + }, + { // Period starting before data and ending after with no terminated. + time.Time{}, + Period{now.Add(-90 * time.Minute), now.Add(6 * time.Hour)}, + Cost{ + Period{now.Add(-60 * time.Minute), now.Add(6 * time.Hour)}, + 60 + 120 + 90 + 120 + 90, + }, + nil, + }, + { // Period starting within data but ending within last period before instance end time. + terminated, + Period{now.Add(-90 * time.Minute), now.Add(3*time.Hour + 15*time.Minute)}, + Cost{ + Period{now.Add(-60 * time.Minute), now.Add(3*time.Hour + 15*time.Minute)}, + 60 + 120 + 90 + 120 + 90*15/30.0, + }, + nil, + }, + { // Period starting within data but ending within last period with no terminated. 
+			time.Time{},
+			Period{now.Add(-90 * time.Minute), now.Add(3*time.Hour + 15*time.Minute)},
+			Cost{
+				Period{now.Add(-60 * time.Minute), now.Add(3*time.Hour + 15*time.Minute)},
+				60 + 120 + 90 + 120 + 90,
+			},
+			nil,
+		},
+	} {
+		c, err := q.Query(iid, tt.p, tt.iet)
+		if tt.wantE != nil {
+			if got, want := err, tt.wantE; got != want {
+				t.Errorf("[%d] got %v, want %v", i, got, want)
+			}
+			continue
+		}
+		if err != nil {
+			t.Error(err)
+			continue
+		}
+		if got, want := c, tt.c; got != want {
+			t.Errorf("[%d] got %v, want %v", i, got, want)
+		}
+	}
+}
diff --git a/cloud/spotfeed/spotfeed.go b/cloud/spotfeed/spotfeed.go
new file mode 100644
index 00000000..7b0ee4e0
--- /dev/null
+++ b/cloud/spotfeed/spotfeed.go
@@ -0,0 +1,533 @@
+// Package spotfeed is used for querying spot-data-feeds provided by AWS.
+// See https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/spot-data-feeds.html for a description of the
+// spot data feed format.
+//
+// This package provides two interfaces for interacting with the AWS spot data feed format for files hosted
+// on S3.
+//
+// 1. Fetch - makes a single blocking call to fetch feed files for some historical period, then parses and
+//            returns the results as a single slice.
+// 2. Stream - creates a goroutine that asynchronously checks (once per 30mins by default) the specified S3
+//             location for new spot data feed files (and sends parsed entries into a channel provided to
+//             the user at invocation).
+//
+// This package also provides a LocalLoader which can perform a Fetch operation against feed files already
+// downloaded to local disk. This is often useful for analyzing spot usage over long periods of time, since
+// the download phase can take some time.
+package spotfeed + +import ( + "compress/gzip" + "context" + "fmt" + "io/ioutil" + "log" + "os" + "path" + "strings" + "sync" + "time" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/request" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/ec2" + "github.com/aws/aws-sdk-go/service/s3" + "github.com/aws/aws-sdk-go/service/s3/s3iface" + "github.com/grailbio/base/errors" + "github.com/grailbio/base/retry" + "golang.org/x/sync/errgroup" + "golang.org/x/time/rate" +) + +var ( + // RetryPolicy is used to retry failed S3 API calls. + retryPolicy = retry.Backoff(time.Second, 10*time.Second, 2) + + // Used to rate limit S3 calls. + limiter = rate.NewLimiter(rate.Limit(16), 4) +) + +type filterable interface { + accountId() string + timestamp() time.Time + version() int64 +} + +type filters struct { + // AccountId configures the Loader to only return Entry objects that belong to the specified + // 12-digit AWS account number (ID). If zero, no AccountId filter is applied. + AccountId string + + // StartTime configures the Loader to only return Entry objects younger than StartTime. + // If nil, no StartTime filter is applied. + StartTime *time.Time + + // EndTime configures the Loader to only return Entry objects older than EndTime. + // If nil, no EndTime filter is applied. + EndTime *time.Time + + // Version configures the Loader to only return Entry objects with version equal to Version. + // If zero, no Version filter is applied, and if multiple feed versions declare the same + // instance-hour, de-duping based on the maximum value seen for that hour will be applied. + Version int64 +} + +// filter returns true if the entry does not match loader criteria and should be filtered out. 
+func (l *filters) filter(f filterable) bool {
+	if l.AccountId != "" && f.accountId() != l.AccountId {
+		return true
+	}
+	if l.StartTime != nil && f.timestamp().Before(*l.StartTime) { // inclusive
+		return true
+	}
+	if l.EndTime != nil && !f.timestamp().Before(*l.EndTime) { // exclusive
+		return true
+	}
+	if l.Version != 0 && f.version() != l.Version {
+		return true
+	}
+	return false
+}
+
+// filterTruncatedStartTime performs the same checks as filter but truncates the start boundary down to the hour.
+func (l *filters) filterTruncatedStartTime(f filterable) bool {
+	if l.AccountId != "" && f.accountId() != l.AccountId {
+		return true
+	}
+	if l.StartTime != nil {
+		truncatedStart := l.StartTime.Truncate(time.Hour)
+		if f.timestamp().Before(truncatedStart) { // inclusive
+			return true
+		}
+	}
+	if l.EndTime != nil && !f.timestamp().Before(*l.EndTime) { // exclusive
+		return true
+	}
+	if l.Version != 0 && f.version() != l.Version {
+		return true
+	}
+	return false
+
+}
+
+type localFile struct {
+	*fileMeta
+	path string
+}
+
+// read opens the local feed file, transparently decompressing gzip files,
+// and parses it into entries.
+func (f *localFile) read() ([]*Entry, error) {
+	fd, err := os.Open(f.path)
+	if err != nil {
+		err = errors.E(err, fmt.Sprintf("failed to open local spot feed data file %s", f.path))
+		return nil, err
+	}
+	defer func() { _ = fd.Close() }()
+
+	if f.IsGzip {
+		gz, err := gzip.NewReader(fd)
+		if err != nil {
+			return nil, fmt.Errorf("failed to read gzipped file %s", f.Name)
+		}
+		defer func() { _ = gz.Close() }()
+		return ParseFeedFile(gz, f.AccountId)
+	}
+
+	return ParseFeedFile(fd, f.AccountId)
+}
+
+type s3File struct {
+	*fileMeta
+	bucket, key string
+	client      s3iface.S3API
+}
+
+func (s *s3File) read(ctx context.Context) ([]*Entry, error) {
+	// Pull feed file from S3 with rate limiting and retries.
+ var output *s3.GetObjectOutput + for retries := 0; ; { + if err := limiter.Wait(ctx); err != nil { + return nil, err + } + var getObjErr error + if output, getObjErr = s.client.GetObjectWithContext(ctx, &s3.GetObjectInput{ + Bucket: aws.String(s.bucket), + Key: aws.String(s.key), + }); getObjErr != nil { + if !request.IsErrorThrottle(getObjErr) { + return nil, getObjErr + } + if err := retry.Wait(ctx, retryPolicy, retries); err != nil { + return nil, err + } + retries++ + continue + } + break + } + // If the file is gzipped, unpack before attempting to read. + if s.IsGzip { + gz, err := gzip.NewReader(output.Body) + if err != nil { + return nil, fmt.Errorf("failed to read gzipped file s3://%s/%s", s.bucket, s.key) + } + defer func() { _ = gz.Close() }() + return ParseFeedFile(gz, s.AccountId) + } + + return ParseFeedFile(output.Body, s.AccountId) +} + +// Loader provides an API for pulling Spot Data Feed Entry objects from some repository. +// The tolerateErr parameter configures how the Loader responds to errors parsing +// individual files or entries; if true, the Loader will continue to parse and yield Entry +// objects if an error is encountered during parsing. +type Loader interface { + // Fetch performs a single blocking call to fetch a discrete set of Entry objects. + Fetch(ctx context.Context, tolerateErr bool) ([]*Entry, error) + + // Stream asynchronously retrieves, parses and sends Entry objects on the returned channel. + // To graciously terminate the goroutine managing the Stream, the client terminates the given context. + Stream(ctx context.Context, tolerateErr bool) (<-chan *Entry, error) +} + +type s3Loader struct { + Loader + filters + + log *log.Logger + client s3iface.S3API + bucket string + rootURI string +} + +// commonFilePrefix returns the most specific prefix common to all spot feed data files that +// match the loader criteria. 
+func (s *s3Loader) commonFilePrefix() string { + if s.AccountId == "" { + return "" + } + + if s.StartTime == nil || s.EndTime == nil || s.StartTime.Year() != s.EndTime.Year() { + return s.AccountId + } + + if s.StartTime.Month() != s.EndTime.Month() { + return fmt.Sprintf("%s.%d", s.AccountId, s.StartTime.Year()) + } + + if s.StartTime.Day() != s.EndTime.Day() { + return fmt.Sprintf("%s.%d-%02d", s.AccountId, s.StartTime.Year(), s.StartTime.Month()) + } + + if s.StartTime.Hour() != s.EndTime.Hour() { + return fmt.Sprintf("%s.%d-%02d-%02d", s.AccountId, s.StartTime.Year(), s.StartTime.Month(), s.StartTime.Day()) + } + + return fmt.Sprintf("%s.%d-%02d-%02d-%02d", s.AccountId, s.StartTime.Year(), s.StartTime.Month(), s.StartTime.Day(), s.StartTime.Hour()) +} + +// timePrefix returns a prefix which matches the given time in UTC. +func (s *s3Loader) timePrefix(t time.Time) string { + if s.AccountId == "" { + panic("nowPrefix cannot be given without an account id") + } + + t = t.UTC() + return fmt.Sprintf("%s.%d-%02d-%02d-%02d", s.AccountId, t.Year(), t.Month(), t.Day(), t.Hour()) +} + +// path returns a prefix which joins the loader rootURI with the given uri. +func (s *s3Loader) path(uri string) string { + if s.rootURI == "" { + return uri + } else { + return fmt.Sprintf("%s/%s", s.rootURI, uri) + } +} + +// list queries the AWS S3 ListBucket API for feed files. 
+func (s *s3Loader) list(ctx context.Context, startAfter string, tolerateErr bool) ([]*s3File, error) { + prefix := s.path(s.commonFilePrefix()) + + s3Files := make([]*s3File, 0) + var parseMetaErr error + if err := s.client.ListObjectsV2PagesWithContext(ctx, &s3.ListObjectsV2Input{ + Bucket: aws.String(s.bucket), + Prefix: aws.String(prefix), + StartAfter: aws.String(startAfter), + }, func(output *s3.ListObjectsV2Output, lastPage bool) bool { + for _, object := range output.Contents { + filename := aws.StringValue(object.Key) + fileMeta, err := parseFeedFileName(filename) + if err != nil { + parseMetaErr = errors.E(err, fmt.Sprintf("failed to parse spot feed data file name %s", filename)) + if tolerateErr { + s.log.Print(parseMetaErr) + continue + } else { + return false + } + } + + // skips s3Files that do not match the loader criteria. Truncate the startTime of the filter to ensure that + // we do not skip files at hour HH:00 with a startTime of (i.e.) HH:30. + if s.filterTruncatedStartTime(fileMeta) { + s.log.Printf("%s does not pass fileMeta filter, skipping", filename) + continue + } + s3Files = append(s3Files, &s3File{ + fileMeta, + s.bucket, + filename, + s.client, + }) + } + return true + }); err != nil { + return nil, fmt.Errorf("list on path %s failed with error: %s", prefix, err) + } + if !tolerateErr && parseMetaErr != nil { + return nil, parseMetaErr + } + return s3Files, nil +} + +// fetchAfter builds a list of S3 feed file objects using the S3 ListBucket API. It then concurrently +// fetches and parses the feed files, observing rate and concurrency limits. 
+func (s *s3Loader) fetchAfter(ctx context.Context, startAfter string, tolerateErr bool) ([]*Entry, error) { + s3Files, err := s.list(ctx, startAfter, tolerateErr) + if err != nil { + return nil, err + } + + mu := &sync.Mutex{} + spotDataEntries := make([]*Entry, 0) + group, groupCtx := errgroup.WithContext(ctx) + for _, file := range s3Files { + file := file + group.Go(func() error { + if entries, err := file.read(groupCtx); err != nil { + err = errors.E(err, fmt.Sprintf("failed to parse spot feed data file s3://%s/%s", file.bucket, file.key)) + if tolerateErr { + s.log.Printf("encountered error %s, tolerating and skipping file s3://%s/%s", err, file.bucket, file.key) + return nil + } else { + return err + } + } else { + mu.Lock() + spotDataEntries = append(spotDataEntries, entries...) + mu.Unlock() + } + return nil + }) + } + if err := group.Wait(); err != nil { + return nil, err + } + + filteredEntries := make([]*Entry, 0) + for _, e := range spotDataEntries { + if !s.filter(e) { + filteredEntries = append(filteredEntries, e) + } + } + + return filteredEntries, nil +} + +// Fetch makes a single blocking call to fetch feed files for some historical period, +// then parses and returns the results as a single slice. The call attempts to start +// from the first entry such that Key > l.StartTime and breaks when it encounters the +// first entry such that Key > l.EndTime +func (s *s3Loader) Fetch(ctx context.Context, tolerateErr bool) ([]*Entry, error) { + prefix := s.path(s.commonFilePrefix()) + return s.fetchAfter(ctx, prefix, tolerateErr) +} + +var ( + // streamSleepDuration specifies how long to wait between calls to S3 ListBucket + streamSleepDuration = 30 * time.Minute +) + +// Stream creates a goroutine that asynchronously checks (once per 30mins by default) the specified S3 +// location for new spot data feed files (and sends parsed entries into a channel provided to the user at invocation). 
+// s3Loader must be configured with an account id to support the Stream interface. To stream events for multiple account ids +// which share a feed bucket, create multiple s3Loader objects. +// TODO: Allow caller to pass channel, allowing a single reader to manage multiple s3Loader.Stream calls. +func (s *s3Loader) Stream(ctx context.Context, tolerateErr bool) (<-chan *Entry, error) { + if s.AccountId == "" { + return nil, fmt.Errorf("s3Loader must be configured with an account id to provide asynchronous event streaming") + } + + entryChan := make(chan *Entry) + go func() { + startAfter := s.timePrefix(time.Now()) + for { + if ctx.Err() != nil { + close(entryChan) + return + } + + entries, err := s.fetchAfter(ctx, startAfter, tolerateErr) + if err != nil { + close(entryChan) + return + } + + for _, entry := range entries { + entryChan <- entry + } + + if len(entries) != 0 { + finalEntry := entries[len(entries)-1] + startAfter = s.timePrefix(finalEntry.Timestamp) + } + + time.Sleep(streamSleepDuration) + } + }() + + return entryChan, nil +} + +// NewSpotFeedLoader returns a Loader which queries the spot data feed subscription using the given session and +// returns a Loader which queries the S3 API for feed files (if a subscription does exist). +// NewSpotFeedLoader will return an error if the spot data feed subscription is missing. 
+func NewSpotFeedLoader(sess *session.Session, log *log.Logger, startTime, endTime *time.Time, version int64) (Loader, error) { + ec2api := ec2.New(sess) + resp, err := ec2api.DescribeSpotDatafeedSubscription(&ec2.DescribeSpotDatafeedSubscriptionInput{}) + if err != nil { + return nil, errors.E("DescribeSpotDatafeedSubscription", err) + } + bucket := aws.StringValue(resp.SpotDatafeedSubscription.Bucket) + rootURI := aws.StringValue(resp.SpotDatafeedSubscription.Prefix) + accountID := aws.StringValue(resp.SpotDatafeedSubscription.OwnerId) + return NewS3Loader(bucket, rootURI, s3.New(sess), log, accountID, startTime, endTime, version), nil +} + +// NewS3Loader returns a Loader which queries the S3 API for feed files. It supports the Fetch and Stream APIs. +func NewS3Loader(bucket, rootURI string, client s3iface.S3API, log *log.Logger, accountId string, startTime, endTime *time.Time, version int64) Loader { + // Remove any trailing slash from bucket and trailing/leading slash from rootURI. + if strings.HasSuffix(bucket, "/") { + bucket = bucket[:len(bucket)-1] + } + if strings.HasPrefix(rootURI, "/") { + rootURI = rootURI[1:] + } + if strings.HasSuffix(rootURI, "/") { + rootURI = rootURI[:len(rootURI)-1] + } + + return &s3Loader{ + filters: filters{ + AccountId: accountId, + StartTime: startTime, + EndTime: endTime, + Version: version, + }, + log: log, + client: client, + bucket: bucket, + rootURI: rootURI, + } +} + +type localLoader struct { + Loader + filters + + log *log.Logger + rootPath string +} + +// Fetch queries the local filesystem for feed files at the given path which match the given filename filters. +// It then parses, filters again and returns the Entry objects. +func (l *localLoader) Fetch(ctx context.Context, tolerateErr bool) ([]*Entry, error) { + // Iterate over files in directory, filter and build slice of feed files. 
+ spotFiles := make([]*localFile, 0) + items, _ := ioutil.ReadDir(l.rootPath) + for _, item := range items { + // Skip subdirectories. + if item.IsDir() { + continue + } + + p := path.Join(l.rootPath, item.Name()) + fileMeta, err := parseFeedFileName(item.Name()) + if err != nil { + err = errors.E(err, fmt.Sprintf("failed to parse spot feed data file name %s", p)) + if tolerateErr { + l.log.Printf("encountered error %s, tolerating and skipping file %s", err, p) + continue + } else { + return nil, err + } + } + + // skips files that do not match the loader criteria. Truncate the startTime of the filter to ensure that + // we do not skip files at hour HH:00 with a startTime of (i.e.) HH:30. + if l.filterTruncatedStartTime(fileMeta) { + l.log.Printf("%s does not pass fileMeta filter, skipping", p) + continue + } + + spotFiles = append(spotFiles, &localFile{ + fileMeta, + p, + }) + } + + // Concurrently iterate over spot data feed files and build a slice of entries. + mu := &sync.Mutex{} + spotDataEntries := make([]*Entry, 0) + group, _ := errgroup.WithContext(ctx) + for _, file := range spotFiles { + file := file + group.Go(func() error { + if entries, err := file.read(); err != nil { + err = errors.E(err, fmt.Sprintf("failed to parse spot feed data file %s", file.path)) + if tolerateErr { + l.log.Printf("encountered error %s, tolerating and skipping file %s", err, file.path) + return nil + } else { + return err + } + } else { + mu.Lock() + spotDataEntries = append(spotDataEntries, entries...) + mu.Unlock() + } + return nil + }) + } + if err := group.Wait(); err != nil { + return nil, err + } + + // Filter entries + filteredEntries := make([]*Entry, 0) + for _, e := range spotDataEntries { + if !l.filter(e) { + filteredEntries = append(filteredEntries, e) + } + } + + return filteredEntries, nil +} + +// NewLocalLoader returns a Loader which fetches feed files from a path on the local filesystem. It does not support +// the Stream API. 
+func NewLocalLoader(path string, log *log.Logger, accountId string, startTime, endTime *time.Time, version int64) Loader { + return &localLoader{ + filters: filters{ + AccountId: accountId, + StartTime: startTime, + EndTime: endTime, + Version: version, + }, + log: log, + rootPath: path, + } +} diff --git a/cloud/spotfeed/spotfeed_test.go b/cloud/spotfeed/spotfeed_test.go new file mode 100644 index 00000000..1af07aa2 --- /dev/null +++ b/cloud/spotfeed/spotfeed_test.go @@ -0,0 +1,322 @@ +package spotfeed + +import ( + "bytes" + "compress/gzip" + "context" + "fmt" + "io/ioutil" + "log" + "sync" + "testing" + "time" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/request" + "github.com/aws/aws-sdk-go/service/s3" + "github.com/aws/aws-sdk-go/service/s3/s3iface" + "github.com/stretchr/testify/require" +) + +const ( + testAccountId = "123456789000" +) + +func TestFilter(t *testing.T) { + ptr := func(t time.Time) *time.Time { + return &t + } + + for _, test := range []struct { + name string + filters filters + filterable filterable + expectedFilterResult bool + }{ + { + "passes_filter", + filters{ + AccountId: testAccountId, + StartTime: ptr(time.Date(2020, 01, 01, 01, 00, 00, 00, time.UTC)), + EndTime: ptr(time.Date(2020, 02, 01, 01, 00, 00, 00, time.UTC)), + Version: 1, + }, + &fileMeta{ + nil, + testAccountId + ".2021-02-23-04.004.8a6d6bb8", + testAccountId, + time.Date(2020, 01, 10, 04, 0, 0, 0, time.UTC), + 1, + false, + }, + false, + }, + { + "filtered_out_by_meta", + filters{ + Version: 1, + }, + &fileMeta{ + nil, + testAccountId + ".2021-02-23-04.004.8a6d6bb8", + testAccountId, + time.Date(2021, 02, 23, 04, 0, 0, 0, time.UTC), + 4, + false, + }, + true, + }, + { + "filtered_out_by_start_time", + filters{ + StartTime: ptr(time.Date(2020, 01, 01, 01, 00, 00, 00, time.UTC)), + }, + &fileMeta{ + nil, + testAccountId + ".2021-02-23-04.004.8a6d6bb8", + testAccountId, + time.Date(2019, 02, 23, 04, 0, 0, 0, time.UTC), + 4, + false, + }, + true, + }, 
+ } { + t.Run(test.name, func(t *testing.T) { + require.Equal(t, test.expectedFilterResult, test.filters.filter(test.filterable)) + }) + } +} + +// gzipBytes takes some bytes and performs compression with compress/gzip. +func gzipBytes(b []byte) ([]byte, error) { + var buffer bytes.Buffer + gw := gzip.NewWriter(&buffer) + if _, err := gw.Write(b); err != nil { + return nil, err + } + if err := gw.Close(); err != nil { + return nil, err + } + return buffer.Bytes(), nil +} + +type mockS3Client struct { + s3iface.S3API + + getObjectResults map[string]string // path (bucket/key) -> decompressed response contents + + listMu sync.Mutex + listObjectPageResults [][]string // [](result path slices) +} + +func (m *mockS3Client) ListObjectsV2PagesWithContext(_ aws.Context, input *s3.ListObjectsV2Input, callback func(*s3.ListObjectsV2Output, bool) bool, _ ...request.Option) error { + m.listMu.Lock() + defer m.listMu.Unlock() + + if len(m.listObjectPageResults) == 0 { + return fmt.Errorf("unexpected attempt to list s3 objects at path %s/%s", aws.StringValue(input.Bucket), aws.StringValue(input.Prefix)) + } + var currKeys []string + currKeys, m.listObjectPageResults = m.listObjectPageResults[0], m.listObjectPageResults[1:] + + resultObjects := make([]*s3.Object, len(currKeys)) + for i, key := range currKeys { + resultObjects[i] = &s3.Object{ + Key: aws.String(key), + } + } + + callback(&s3.ListObjectsV2Output{Contents: resultObjects}, true) + return nil +} + +func (m *mockS3Client) GetObjectWithContext(_ aws.Context, input *s3.GetObjectInput, _ ...request.Option) (*s3.GetObjectOutput, error) { + path := fmt.Sprintf("%s/%s", aws.StringValue(input.Bucket), aws.StringValue(input.Key)) + if v, ok := m.getObjectResults[path]; ok { + // compress the response via gzip + gzContents, err := gzipBytes([]byte(v)) + if err != nil { + return nil, fmt.Errorf("failed to compress test response at %s: %s", path, err) + } + return &s3.GetObjectOutput{ + // compress the response contents via gzip 
+ Body: ioutil.NopCloser(bytes.NewReader(gzContents)), + }, nil + } else { + return nil, fmt.Errorf("attempted to get unexpected path %s from mock s3", path) + } +} + +func TestLoaderFetch(t *testing.T) { + ctx := context.Background() + devNull := log.New(ioutil.Discard, "", 0) + for _, test := range []struct { + name string + loader Loader + expectedEntries []*Entry + }{ + { + "s3_no_source_files", + NewS3Loader("test-bucket", "", &mockS3Client{ + listObjectPageResults: [][]string{ + {}, // single empty response + }, + }, devNull, testAccountId, nil, nil, 0), + []*Entry{}, + }, + { + "s3_single_source_file", + NewS3Loader("test-bucket", "", &mockS3Client{ + listObjectPageResults: [][]string{ + { // single populated response + testAccountId + ".2021-02-25-15.002.3eb820a5.gz", + }, + }, + getObjectResults: map[string]string{ + "test-bucket/" + testAccountId + ".2021-02-25-15.002.3eb820a5.gz": `#Version: 1.0 +#Fields: Timestamp UsageType Operation InstanceID MyBidID MyMaxPrice MarketPrice Charge Version +2021-02-20 18:52:50 UTC USW2-SpotUsage:x1.16xlarge RunInstances:SV002 i-0053c2917e2afa2f0 sir-yb3gavgp 6.669 USD 2.001 USD 0.073 USD 2 +2021-02-20 18:50:58 UTC USW2-SpotUsage:x1.16xlarge RunInstances:SV002 i-07eaa4b2bf27c4b75 sir-w13gadim 6.669 USD 2.001 USD 1.741 USD 2`, + }, + }, devNull, testAccountId, nil, nil, 0), + []*Entry{ + { + AccountId: testAccountId, + Timestamp: time.Date(2021, 02, 20, 18, 52, 50, 0, time.UTC), + UsageType: "USW2-SpotUsage:x1.16xlarge", + Instance: "x1.16xlarge", + Operation: "RunInstances:SV002", + InstanceID: "i-0053c2917e2afa2f0", + MyBidID: "sir-yb3gavgp", + MyMaxPriceUSD: 6.669, + MarketPriceUSD: 2.001, + ChargeUSD: 0.073, + Version: 2, + }, + { + AccountId: testAccountId, + Timestamp: time.Date(2021, 02, 20, 18, 50, 58, 0, time.UTC), + UsageType: "USW2-SpotUsage:x1.16xlarge", + Instance: "x1.16xlarge", + Operation: "RunInstances:SV002", + InstanceID: "i-07eaa4b2bf27c4b75", + MyBidID: "sir-w13gadim", + MyMaxPriceUSD: 6.669, + 
MarketPriceUSD: 2.001, + ChargeUSD: 1.741, + Version: 2, + }, + }, + }, + { + "local_no_source_files", + NewLocalLoader( + "testdata/no_source_files", + devNull, testAccountId, nil, nil, 0), + []*Entry{}, + }, + { + "local_single_source_file", + NewLocalLoader( + "testdata/single_source_file", + devNull, testAccountId, nil, nil, 0), + []*Entry{ + { + AccountId: testAccountId, + Timestamp: time.Date(2021, 02, 20, 18, 52, 50, 0, time.UTC), + UsageType: "USW2-SpotUsage:x1.16xlarge", + Instance: "x1.16xlarge", + Operation: "RunInstances:SV002", + InstanceID: "i-0053c2917e2afa2f0", + MyBidID: "sir-yb3gavgp", + MyMaxPriceUSD: 6.669, + MarketPriceUSD: 2.001, + ChargeUSD: 0.073, + Version: 2, + }, + { + AccountId: testAccountId, + Timestamp: time.Date(2021, 02, 20, 18, 50, 58, 0, time.UTC), + UsageType: "USW2-SpotUsage:x1.16xlarge", + Instance: "x1.16xlarge", + Operation: "RunInstances:SV002", + InstanceID: "i-07eaa4b2bf27c4b75", + MyBidID: "sir-w13gadim", + MyMaxPriceUSD: 6.669, + MarketPriceUSD: 2.001, + ChargeUSD: 1.741, + Version: 2, + }, + { + AccountId: testAccountId, + Timestamp: time.Date(2021, 02, 20, 18, 56, 14, 0, time.UTC), + UsageType: "USW2-SpotUsage:x1.32xlarge", + Instance: "x1.32xlarge", + Operation: "RunInstances:SV002", + InstanceID: "i-000e2cebfe213246e", + MyBidID: "sir-fcg8btin", + MyMaxPriceUSD: 13.338, + MarketPriceUSD: 4.001, + ChargeUSD: 2.636, + Version: 2, + }, + { + AccountId: testAccountId, + Timestamp: time.Date(2021, 02, 20, 18, 56, 01, 0, time.UTC), + UsageType: "USW2-SpotUsage:x1.32xlarge", + Instance: "x1.32xlarge", + Operation: "RunInstances:SV002", + InstanceID: "i-032a1a622fb441a7b", + MyBidID: "sir-c6ag9vxn", + MyMaxPriceUSD: 13.338, + MarketPriceUSD: 4.001, + ChargeUSD: 4.001, + Version: 2, + }, + }, + }, + } { + t.Run(test.name, func(t *testing.T) { + entries, err := test.loader.Fetch(ctx, false) + require.NoError(t, err, "unexpected error fetching local feed files") + require.Equal(t, test.expectedEntries, entries) + }) + } +} + 
+func TestLoaderStream(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	devNull := log.New(ioutil.Discard, "", 0)
+
+	loader := NewS3Loader("test-bucket", "", &mockS3Client{
+		listObjectPageResults: [][]string{
+			{
+				// initial empty response
+			},
+			{ // single file response
+				testAccountId + ".2100-02-20-15.002.3eb820a5.gz",
+			},
+		},
+		getObjectResults: map[string]string{
+			"test-bucket/" + testAccountId + ".2100-02-20-15.002.3eb820a5.gz": `#Version: 1.0
+#Fields: Timestamp UsageType Operation InstanceID MyBidID MyMaxPrice MarketPrice Charge Version
+2100-02-20 18:52:50 UTC USW2-SpotUsage:x1.16xlarge RunInstances:SV002 i-0053c2917e2afa2f0 sir-yb3gavgp 6.669 USD 2.001 USD 0.073 USD 2
+2100-02-20 18:50:58 UTC USW2-SpotUsage:x1.16xlarge RunInstances:SV002 i-07eaa4b2bf27c4b75 sir-w13gadim 6.669 USD 2.001 USD 1.741 USD 2`,
+		},
+	}, devNull, testAccountId, nil, nil, 0)
+
+	// speed up sleep duration to drain list objects slice
+	streamSleepDuration = time.Second
+
+	entryChan, err := loader.Stream(ctx, false)
+	require.NoError(t, err, "unexpected err streaming s3 entries")
+
+	// test successful drain of two entries channel
+	for i := 0; i < 2; i++ {
+		<-entryChan
+	}
+
+	// cancel the context to kill the Stream goroutine
+	cancel()
+}
diff --git a/cloud/spotfeed/testdata/single_source_file/123456789000.2021-02-20-18.002.cf3d1500 b/cloud/spotfeed/testdata/single_source_file/123456789000.2021-02-20-18.002.cf3d1500
new file mode 100644
index 00000000..bbbe6edc
--- /dev/null
+++ b/cloud/spotfeed/testdata/single_source_file/123456789000.2021-02-20-18.002.cf3d1500
@@ -0,0 +1,6 @@
+#Version: 1.0
+#Fields: Timestamp UsageType Operation InstanceID MyBidID MyMaxPrice MarketPrice Charge Version
+2021-02-20 18:52:50 UTC USW2-SpotUsage:x1.16xlarge RunInstances:SV002 i-0053c2917e2afa2f0 sir-yb3gavgp 6.669 USD 2.001 USD 0.073 USD 2
+2021-02-20 18:50:58 UTC USW2-SpotUsage:x1.16xlarge RunInstances:SV002 i-07eaa4b2bf27c4b75 sir-w13gadim 6.669 USD 2.001 USD 1.741 USD 2
+2021-02-20 18:56:14 UTC 
USW2-SpotUsage:x1.32xlarge RunInstances:SV002 i-000e2cebfe213246e sir-fcg8btin 13.338 USD 4.001 USD 2.636 USD 2 +2021-02-20 18:56:01 UTC USW2-SpotUsage:x1.32xlarge RunInstances:SV002 i-032a1a622fb441a7b sir-c6ag9vxn 13.338 USD 4.001 USD 4.001 USD 2 diff --git a/cloud/url/url.go b/cloud/url/url.go deleted file mode 100644 index c7036c8b..00000000 --- a/cloud/url/url.go +++ /dev/null @@ -1,120 +0,0 @@ -// Copyright 2018 GRAIL, Inc. All rights reserved. -// Use of this source code is governed by the Apache-2.0 -// license that can be found in the LICENSE file. - -// Package url provide AWS compatible url encoding and decoding funtions. -// AWS requires that the CopySource arg to the CopyObject* s3 api calls -// are 'url encoded'. -package url - -import ( - "strconv" -) - -// EscapeError provides a custom url escape related error type. -type EscapeError string - -func (e EscapeError) Error() string { - return "invalid URL escape " + strconv.Quote(string(e)) -} - -// shouldEscape returns true if byte should be escaped -func shouldEscape(c byte) bool { - return !((c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') || - (c >= '0' && c <= '9') || c == '_' || c == '-' || c == '.' || c == '/') -} - -// Encode does uri encoding in an Amazon compatible way. 
-func Encode(s string) string { - hexCount := 0 - - for i := 0; i < len(s); i++ { - if shouldEscape(s[i]) { - hexCount++ - } - } - - if hexCount == 0 { - return s - } - - t := make([]byte, len(s)+2*hexCount) - j := 0 - for i := 0; i < len(s); i++ { - if c := s[i]; shouldEscape(c) { - t[j] = '%' - t[j+1] = "0123456789ABCDEF"[c>>4] - t[j+2] = "0123456789ABCDEF"[c&15] - j += 3 - } else { - t[j] = s[i] - j++ - } - } - return string(t) -} - -func ishex(c byte) bool { - switch { - case '0' <= c && c <= '9': - return true - case 'a' <= c && c <= 'f': - return true - case 'A' <= c && c <= 'F': - return true - } - return false -} - -func unhex(c byte) byte { - switch { - case '0' <= c && c <= '9': - return c - '0' - case 'a' <= c && c <= 'f': - return c - 'a' + 10 - case 'A' <= c && c <= 'F': - return c - 'A' + 10 - } - return 0 -} - -// Decode decodes Amazon compatible encoded urls. -func Decode(s string) (string, error) { - n := 0 - for i := 0; i < len(s); { - switch s[i] { - case '%': - n++ - if i+2 >= len(s) || !ishex(s[i+1]) || !ishex(s[i+2]) { - s = s[i:] - if len(s) > 3 { - s = s[:3] - } - return "", EscapeError(s) - } - i += 3 - default: - i++ - } - } - - if n == 0 { - return s, nil - } - - t := make([]byte, len(s)-2*n) - j := 0 - for i := 0; i < len(s); { - switch s[i] { - case '%': - t[j] = unhex(s[i+1])<<4 | unhex(s[i+2]) - j++ - i += 3 - default: - t[j] = s[i] - j++ - i++ - } - } - return string(t), nil -} diff --git a/cloud/url/url_test.go b/cloud/url/url_test.go deleted file mode 100644 index f078370c..00000000 --- a/cloud/url/url_test.go +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright 2018 GRAIL, Inc. All rights reserved. -// Use of this source code is governed by the Apache-2.0 -// license that can be found in the LICENSE file. 
- -package url - -import ( - "testing" -) - -func TestEncode(t *testing.T) { - if got, want := Encode("prefix/hi+there"), "prefix/hi%2Bthere"; got != want { - t.Errorf("wanted %s, got %s", want, got) - } - if got, want := Encode("preF:ix/&$@=,? "), "preF%3Aix/%26%24%40%3D%2C%3F%20"; got != want { - t.Errorf("wanted %s, got %s", want, got) - } -} - -func TestDecode(t *testing.T) { - got, err := Decode("prefix/hi%2Bthere") - want := "prefix/hi+there" - if err != nil { - t.Errorf("decode error %s", err) - } else if got != want { - t.Errorf("wanted %s, got %s", want, got) - } - got, err = Decode("preF%3Aix/%26%24%40%3D%2C%3F%20") - want = "preF:ix/&$@=,? " - if err != nil { - t.Errorf("decode error %s", err) - } else if got != want { - t.Errorf("wanted %s, got %s", want, got) - } -} diff --git a/cmd/gofat/gofat.go b/cmd/gofat/gofat.go new file mode 100644 index 00000000..92d5ece7 --- /dev/null +++ b/cmd/gofat/gofat.go @@ -0,0 +1,207 @@ +// Copyright 2019 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. + +// Command gofat is a simple utility to make fat binaries in the +// fatbin format (see github.com/grailbio/base/fatbin). 
+package main + +import ( + "bytes" + "flag" + "fmt" + "io" + "io/ioutil" + "os" + "os/exec" + "path" + "runtime" + "strings" + + "github.com/grailbio/base/embedbin" + "github.com/grailbio/base/fatbin" + "github.com/grailbio/base/log" +) + +func main() { + log.AddFlags() + log.SetFlags(0) + log.SetPrefix("gofat: ") + flag.Usage = func() { + fmt.Fprintf(os.Stderr, `usage: + gofat build build a fatbin binary + gofat info show fatbin binary information + gofat embed build an embedbin binary +`) + flag.PrintDefaults() + os.Exit(2) + } + flag.Parse() + if flag.NArg() == 0 { + flag.Usage() + } + + cmd, args := flag.Arg(0), flag.Args()[1:] + switch cmd { + default: + fmt.Fprintf(os.Stderr, "unknown command %s\n", cmd) + flag.Usage() + case "info": + info(args) + case "build": + build(args) + case "embed": + embed(args) + } +} + +func info(args []string) { + if len(args) == 0 { + fmt.Fprintf(os.Stderr, "usage: gofat info binaries...\n") + os.Exit(2) + } + for _, filename := range args { + f, err := os.Open(filename) + must(err) + info, err := f.Stat() + must(err) + r, err := fatbin.OpenFile(f, info.Size()) + must(err) + fmt.Println(filename, r.GOOS()+"/"+r.GOARCH(), info.Size()) + for _, info := range r.List() { + fmt.Print("\t", info.Goos, "/", info.Goarch, " ", info.Size, "\n") + } + must(f.Close()) + } +} + +func build(args []string) { + var ( + flags = flag.NewFlagSet("build", flag.ExitOnError) + goarches = flags.String("goarches", "amd64", "list of GOARCH values to build") + gooses = flags.String("gooses", "darwin,linux", "list of GOOS values to build") + out = flags.String("o", "", "build output path") + ) + flags.Usage = func() { + fmt.Fprintf(os.Stderr, "usage: gofat build [-o output] [packages]\n") + flags.PrintDefaults() + os.Exit(2) + } + + must(flags.Parse(args)) + if *out == "" { + cmd := exec.Command("go", append([]string{"list"}, flags.Args()...)...) 
+ cmd.Stderr = os.Stderr + listout, err := cmd.Output() + must(err) + *out = path.Base(string(bytes.TrimSpace(listout))) + } + + cmd := exec.Command("go", append([]string{"build", "-o", *out}, flags.Args()...)...) + cmd.Stderr = os.Stderr + must(cmd.Run()) + + f, err := os.OpenFile(*out, os.O_WRONLY|os.O_APPEND, 0777) + must(err) + info, err := f.Stat() + must(err) + fat := fatbin.NewWriter(f, info.Size(), runtime.GOOS, runtime.GOARCH) + + for _, goarch := range strings.Split(*goarches, ",") { + for _, goos := range strings.Split(*gooses, ",") { + if goarch == runtime.GOARCH && goos == runtime.GOOS { + continue + } + outfile, err := ioutil.TempFile("", *out) + must(err) + name := outfile.Name() + outfile.Close() + cmd := exec.Command("go", "build", "-o", name) + cmd.Stderr = os.Stderr + cmd.Env = append(os.Environ(), "GOOS="+goos, "GOARCH="+goarch) + must(cmd.Run()) + + outfile, err = os.Open(name) + must(err) + w, err := fat.Create(goos, goarch) + must(err) + _, err = io.Copy(w, outfile) + must(err) + must(os.Remove(name)) + must(outfile.Close()) + log.Print("append ", goos, "/", goarch) + } + } + must(fat.Close()) + must(f.Close()) +} + +func embed(args []string) { + var ( + flags = flag.NewFlagSet("embed", flag.ExitOnError) + out = flags.String("o", "", "build output path") + ) + flags.Usage = func() { + fmt.Fprintf(os.Stderr, "usage: gofat embed [-o output] [name1:path1 [name2:path2 ...]]\n") + flags.PrintDefaults() + os.Exit(2) + } + + must(flags.Parse(args)) + args = flags.Args() + if len(args) == 0 { + log.Fatal("missing path to input binary") + } + inputPath, args := args[0], args[1:] + + paths := map[string]string{} + var names []string + for _, arg := range args { + parts := strings.SplitN(arg, ":", 2) + if len(parts) != 2 { + log.Fatalf("malformed argument: %s", arg) + } + name, path := parts[0], parts[1] + if _, ok := paths[name]; ok { + log.Fatalf("duplicate name: %s", name) + } + paths[name] = path + names = append(names, name) + } + + var outF 
*os.File + var err error + if *out != "" { + outF, err = os.OpenFile(*out, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0777) + must(err) + var inF *os.File + inF, err = os.Open(inputPath) + must(err) + _, err = io.Copy(outF, inF) + must(err) + must(inF.Close()) + } else { + outF, err = os.OpenFile(inputPath, os.O_RDWR|os.O_APPEND, 0777) + must(err) + } + + ew, err := embedbin.NewFileWriter(outF) + must(err) + for _, name := range names { + embedF, err := os.Open(paths[name]) + must(err) + embedW, err := ew.Create(name) + must(err) + _, err = io.Copy(embedW, embedF) + must(err) + must(embedF.Close()) + } + must(ew.Close()) + must(outF.Close()) +} + +func must(err error) { + if err != nil { + log.Fatal(err) + } +} diff --git a/cmd/grail-access/cmd_test.go b/cmd/grail-access/cmd_test.go new file mode 100644 index 00000000..8f779bb9 --- /dev/null +++ b/cmd/grail-access/cmd_test.go @@ -0,0 +1,336 @@ +package main_test + +import ( + "bytes" + "encoding/json" + "fmt" + "io/ioutil" + "net" + "net/http" + "os" + "os/exec" + "path" + "strings" + "testing" + "time" + + ticketServerUtil "github.com/grailbio/base/cmd/ticket-server/testutil" + "github.com/grailbio/base/security/identity" + "github.com/grailbio/testutil" + _ "github.com/grailbio/v23/factories/grail" + "github.com/stretchr/testify/assert" + v23 "v.io/v23" + "v.io/v23/context" + "v.io/v23/naming" + "v.io/v23/rpc" + "v.io/v23/security" + "v.io/x/ref" + libsecurity "v.io/x/ref/lib/security" +) + +func TestCmd(t *testing.T) { + ctx, v23CleanUp := v23.Init() + defer v23CleanUp() + assert.NoError(t, ref.EnvClearCredentials()) + exe := testutil.GoExecutable(t, "//go/src/github.com/grailbio/base/cmd/grail-access/grail-access") + + // Preserve the test environment's PATH. On Darwin, Vanadium's agentlib uses `ioreg` from the + // path in the process of locking [1] and loading [2] the principal when there's no agent. 
+ // [1] https://github.com/vanadium/core/blob/694a147f5dfd7ebc2d2e5a4fb3c4fe448c7a377c/x/ref/services/agent/internal/lockutil/version1_darwin.go#L21 + // [2] https://github.com/vanadium/core/blob/694a147f5dfd7ebc2d2e5a4fb3c4fe448c7a377c/x/ref/services/agent/agentlib/principal.go#L57 + pathEnv := "PATH=" + os.Getenv("PATH") + + t.Run("help", func(t *testing.T) { + cmd := exec.Command(exe, "-help") + cmd.Env = []string{pathEnv} + stdout, stderr := ticketServerUtil.RunAndCapture(t, cmd) + assert.NotEmpty(t, stdout) + assert.Empty(t, stderr) + }) + + // TODO(josh): Test with v23agentd on the path, too. + t.Run("dump_existing_principal", func(t *testing.T) { + homeDir, cleanUp := testutil.TempDir(t, "", "") + defer cleanUp() + principalDir := path.Join(homeDir, ".v23") + principal, err := libsecurity.CreatePersistentPrincipal(principalDir, nil) + assert.NoError(t, err) + decoyPrincipalDir := path.Join(homeDir, "decoy_principal_dir") + // Create a principal in the decoyPrincipalDir, as -dir still requires + // a valid principal at $V23_CREDENTIALS. + // TODO: Consider removing -dir flag, as this is surprising behavior. + _, err = libsecurity.CreatePersistentPrincipal(decoyPrincipalDir, nil) + assert.NoError(t, err) + + const blessingName = "grail-access-test-blessing-ln7z94" + blessings, err := principal.BlessSelf(blessingName) + assert.NoError(t, err) + assert.NoError(t, principal.BlessingStore().SetDefault(blessings)) + + t.Run("flag_dir", func(t *testing.T) { + cmd := exec.Command(exe, "-dump", "-dir", principalDir) + // Set $V23_CREDENTIALS to test that -dir takes priority. 
+ cmd.Env = []string{pathEnv, "V23_CREDENTIALS=" + decoyPrincipalDir} + stdout, stderr := ticketServerUtil.RunAndCapture(t, cmd) + assert.Contains(t, stdout, blessingName) + assert.Empty(t, stderr) + }) + + t.Run("env_home", func(t *testing.T) { + cmd := exec.Command(exe, "-dump") + cmd.Env = []string{pathEnv, "HOME=" + homeDir} + stdout, stderr := ticketServerUtil.RunAndCapture(t, cmd) + assert.Contains(t, stdout, blessingName) + assert.Empty(t, stderr) + }) + }) + + t.Run("do_not_refresh/existing", func(t *testing.T) { + principalDir, cleanUp := testutil.TempDir(t, "", "") + defer cleanUp() + principal, err := libsecurity.CreatePersistentPrincipal(principalDir, nil) + assert.NoError(t, err) + + const blessingName = "grail-access-test-blessing-nuz823" + doNotRefreshDuration := time.Hour + expirationTime := time.Now().Add(2 * doNotRefreshDuration) + expiryCaveat, err := security.NewExpiryCaveat(expirationTime) + assert.NoError(t, err) + blessings, err := principal.BlessSelf(blessingName, expiryCaveat) + assert.NoError(t, err) + assert.NoError(t, principal.BlessingStore().SetDefault(blessings)) + + cmd := exec.Command(exe, + "-dir", principalDir, + "-do-not-refresh-duration", doNotRefreshDuration.String()) + cmd.Env = []string{pathEnv} + stdout, stderr := ticketServerUtil.RunAndCapture(t, cmd) + assert.Contains(t, stdout, blessingName) + assert.Empty(t, stderr) + }) + + t.Run("fake_v23_servers", func(t *testing.T) { + + t.Run("ec2", func(t *testing.T) { + const ( + wantDoc = "grailaccesstesttoken92lsl83" + serverBlessingName = "grail-access-test-blessing-laul37" + clientBlessingExtension = "ec2-test" + wantClientBlessing = serverBlessingName + ":" + clientBlessingExtension + ) + + // Run fake ticket server: accepts EC2 instance identity document, returns blessings. 
+ var blesserEndpoint naming.Endpoint + ctx, blesserEndpoint = ticketServerUtil.RunBlesserServer(ctx, t, + identity.Ec2BlesserServer(fakeBlesser( + func(gotDoc string, recipient security.PublicKey) security.Blessings { + assert.Equal(t, wantDoc, gotDoc) + p := v23.GetPrincipal(ctx) + caveat, err := security.NewExpiryCaveat(time.Now().Add(24 * time.Hour)) + assert.NoError(t, err) + localBlessings, err := p.BlessSelf(serverBlessingName) + assert.NoError(t, err) + b, err := p.Bless(recipient, localBlessings, clientBlessingExtension, caveat) + assert.NoError(t, err) + return b + }), + ), + ) + + // Run fake EC2 instance identity server. + listener, err := net.Listen("tcp", "localhost:") + assert.NoError(t, err) + defer func() { assert.NoError(t, listener.Close()) }() + go http.Serve( // nolint: errcheck + listener, + http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + _, httpErr := w.Write([]byte(wantDoc)) + assert.NoError(t, httpErr) + }), + ) + + // Run grail-access to create a principal and bless it with the EC2 flow. + principalDir, principalCleanUp := testutil.TempDir(t, "", "") + defer principalCleanUp() + cmd := exec.Command(exe, + "-dir", principalDir, + "-ec2", + "-blesser", fmt.Sprintf("/%s", blesserEndpoint.Address), + "-ec2-instance-identity-url", fmt.Sprintf("http://%s/", listener.Addr().String())) + cmd.Env = []string{pathEnv} + stdout, _ := ticketServerUtil.RunAndCapture(t, cmd) + assert.Contains(t, stdout, wantClientBlessing) + + // Make sure we got the right blessing. 
+ principal, err := libsecurity.LoadPersistentPrincipal(principalDir, nil) + assert.NoError(t, err) + defaultBlessing, _ := principal.BlessingStore().Default() + assert.Contains(t, defaultBlessing.String(), wantClientBlessing) + }) + + t.Run("google", func(t *testing.T) { + const ( + wantToken = "grailaccesstesttokensjo289d" + serverBlessingName = "grail-access-test-blessing-s8j9dk" + clientBlessingExtension = "google-test" + wantClientBlessing = serverBlessingName + ":" + clientBlessingExtension + ) + + // Run fake ticket server: accepts Google ID token, returns blessings. + var blesserEndpoint naming.Endpoint + ctx, blesserEndpoint = ticketServerUtil.RunBlesserServer(ctx, t, + identity.GoogleBlesserServer(fakeBlesser( + func(gotToken string, recipient security.PublicKey) security.Blessings { + assert.Equal(t, wantToken, gotToken) + p := v23.GetPrincipal(ctx) + caveat, err := security.NewExpiryCaveat(time.Now().Add(24 * time.Hour)) + assert.NoError(t, err) + localBlessings, err := p.BlessSelf(serverBlessingName) + assert.NoError(t, err) + b, err := p.Bless(recipient, localBlessings, clientBlessingExtension, caveat) + assert.NoError(t, err) + return b + }), + ), + ) + + // Run fake oauth server. + listener, err := net.Listen("tcp", "localhost:") + assert.NoError(t, err) + defer func() { assert.NoError(t, listener.Close()) }() + go http.Serve( // nolint: errcheck + listener, + http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + if req.URL.Path != "/token" { + assert.FailNowf(t, "fake oauth server: unexpected request: %s", req.URL.Path) + } + w.Header().Set("Content-Type", "application/json") + assert.NoError(t, json.NewEncoder(w).Encode( + map[string]interface{}{ + "access_token": "testtoken", + "expires_in": 3600, + "id_token": wantToken, + "scope": "https://www.googleapis.com/auth/userinfo.email", + }, + )) + }), + ) + + // Run grail-access to create a principal and bless it with the EC2 flow. 
+ principalDir, principalCleanUp := testutil.TempDir(t, "", "") + defer principalCleanUp() + cmd := exec.Command(exe, + "-dir", principalDir, + "-browser=false", + "-blesser", fmt.Sprintf("/%s", blesserEndpoint.Address), + "-google-oauth2-url", fmt.Sprintf("http://%s", listener.Addr().String())) + cmd.Env = []string{pathEnv} + cmd.Stdin = bytes.NewReader([]byte("testcode")) + stdout, _ := ticketServerUtil.RunAndCapture(t, cmd) + assert.Contains(t, stdout, wantClientBlessing) + + // Make sure we got the right blessing. + principal, err := libsecurity.LoadPersistentPrincipal(principalDir, nil) + assert.NoError(t, err) + defaultBlessing, _ := principal.BlessingStore().Default() + assert.Contains(t, defaultBlessing.String(), wantClientBlessing) + }) + + t.Run("k8s", func(t *testing.T) { + const ( + wantCaCrt = "caCrt" + wantNamespace = "namespace" + wantToken = "token" + wantRegion = "us-west-2" + serverBlessingName = "grail-access-test-blessing-abc123" + clientBlessingExtension = "k8s-test" + wantClientBlessing = serverBlessingName + ":" + clientBlessingExtension + ) + + // Run fake ticket server: accepts (caCrt, namespace, token), returns blessings. 
+ var blesserEndpoint naming.Endpoint + ctx, blesserEndpoint = ticketServerUtil.RunBlesserServer(ctx, t, + identity.K8sBlesserServer(fakeK8sBlesser( + func(gotCaCrt string, gotNamespace string, gotToken string, gotRegion string, recipient security.PublicKey) security.Blessings { + assert.Equal(t, gotCaCrt, wantCaCrt) + assert.Equal(t, gotNamespace, wantNamespace) + assert.Equal(t, gotToken, wantToken) + assert.Equal(t, gotRegion, wantRegion) + p := v23.GetPrincipal(ctx) + caveat, err := security.NewExpiryCaveat(time.Now().Add(24 * time.Hour)) + assert.NoError(t, err) + localBlessings, err := p.BlessSelf(serverBlessingName) + assert.NoError(t, err) + b, err := p.Bless(recipient, localBlessings, clientBlessingExtension, caveat) + assert.NoError(t, err) + return b + }), + ), + ) + + // Create caCrt, namespace, and token files + tmpDir, cleanUp := testutil.TempDir(t, "", "") + defer cleanUp() + + assert.NoError(t, ioutil.WriteFile(path.Join(tmpDir, "caCrt"), []byte(wantCaCrt), 0644)) + assert.NoError(t, ioutil.WriteFile(path.Join(tmpDir, "namespace"), []byte(wantNamespace), 0644)) + assert.NoError(t, ioutil.WriteFile(path.Join(tmpDir, "token"), []byte(wantToken), 0644)) + + // Run grail-access to create a principal and bless it with the k8s flow. + principalDir, principalCleanUp := testutil.TempDir(t, "", "") + defer principalCleanUp() + cmd := exec.Command(exe, + "-dir", principalDir, + "-blesser", fmt.Sprintf("/%s", blesserEndpoint.Address), + "-k8s", + "-ca-crt", path.Join(tmpDir, "caCrt"), + "-namespace", path.Join(tmpDir, "namespace"), + "-token", path.Join(tmpDir, "token"), + ) + cmd.Env = []string{pathEnv} + stdout, _ := ticketServerUtil.RunAndCapture(t, cmd) + assert.Contains(t, stdout, wantClientBlessing) + + // Make sure we got the right blessing. 
+ principal, err := libsecurity.LoadPersistentPrincipal(principalDir, nil) + assert.NoError(t, err) + defaultBlessing, _ := principal.BlessingStore().Default() + assert.Contains(t, defaultBlessing.String(), wantClientBlessing) + }) + + // If any of ca.crt, namespace, or token files are missing, an error should be thrown. + t.Run("k8s_missing_file_should_fail", func(t *testing.T) { + // Run grail-access to create a principal and bless it with the k8s flow. + principalDir, principalCleanUp := testutil.TempDir(t, "", "") + defer principalCleanUp() + cmd := exec.Command(exe, + "-dir", principalDir, + "-k8s", + ) + cmd.Env = []string{pathEnv} + var stderrBuf strings.Builder + cmd.Stderr = &stderrBuf + err := cmd.Run() + assert.Error(t, err) + wantStderr := "no such file or directory" + assert.True(t, strings.Contains(stderrBuf.String(), wantStderr)) + }) + + }) +} + +type fakeBlesser func(arg string, recipientKey security.PublicKey) security.Blessings + +func (f fakeBlesser) BlessEc2(_ *context.T, call rpc.ServerCall, s string) (security.Blessings, error) { + return f(s, call.Security().RemoteBlessings().PublicKey()), nil +} + +func (f fakeBlesser) BlessGoogle(_ *context.T, call rpc.ServerCall, s string) (security.Blessings, error) { + return f(s, call.Security().RemoteBlessings().PublicKey()), nil +} + +type fakeK8sBlesser func(arg1, arg2, arg3, arg4 string, recipientKey security.PublicKey) security.Blessings + +func (f fakeK8sBlesser) BlessK8s(_ *context.T, call rpc.ServerCall, s1, s2, s3, s4 string) (security.Blessings, error) { + return f(s1, s2, s3, s4, call.Security().RemoteBlessings().PublicKey()), nil +} diff --git a/cmd/grail-access/doc.go b/cmd/grail-access/doc.go index 0d3e5ae7..7300b2ec 100644 --- a/cmd/grail-access/doc.go +++ b/cmd/grail-access/doc.go @@ -1,7 +1,3 @@ -// Copyright 2018 GRAIL, Inc. All rights reserved. -// Use of this source code is governed by the Apache-2.0 -// license that can be found in the LICENSE file. 
- // This file was auto-generated via go generate. // DO NOT UPDATE MANUALLY @@ -26,29 +22,54 @@ Usage: grail-access [flags] The grail-access flags are: - -blesser-ec2=/ticket-server.eng.grail.com:8102/blesser/ec2 - Blesser to talk to for the EC2-based flow. - -blesser-google=/ticket-server.eng.grail.com:8102/blesser/google - Blesser to talk to for the Google-based flow. + -bless-remotes=true + Whether to attempt to bless remotes with local blessings; only applies to + Google blessings + -bless-remotes-targets=ec2-name:ubuntu@adhoc.jjc.* + Comma-separated list of targets to bless; targets may be + "ssh:[user@]host[:port]" SSH destinations or + "ec2-name:[user@]ec2-instance-name-filter" EC2 instance name filters; see + https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/Using_Filtering.html + -blesser= + Flow specific blesser endpoint to use. Defaults to + /ticket-server.eng.grail.com:8102/blesser/. -browser=true Attempt to open a browser. - -dir=/home/razvanm/.v23 + -ca-crt=/var/run/secrets/kubernetes.io/serviceaccount/ca.crt + Path to ca.crt file. + -dir=/mnt/home/jjc/.v23 Where to store the Vanadium credentials. NOTE: the content will be erased if the credentials are regenerated. + -do-not-refresh-duration=168h0m0s + Do not refresh credentials if they are present and do not expire within this + duration. + -dump=false + If credentials are present, dump them on the console instead of refreshing + them. -ec2=false Use the role of the EC2 VM. 
+ -ec2-instance-identity-url=http://169.254.169.254/latest/dynamic/instance-identity/pkcs7 + URL for fetching instance identity document, for testing + -expiry-caveat= + Duration of expiry caveat added to blessings (for testing); empty means no + caveat added + -google-oauth2-url=https://accounts.google.com/o/oauth2 + URL for oauth2 API calls, for testing + -internal-bless-remotes-mode= + (INTERNAL) Controls the mode in which we run for the remote blessing + protocol; one of {public-key,receive,send} + -k8s=false + Use the Kubernetes flow. + -namespace=/var/run/secrets/kubernetes.io/serviceaccount/namespace + Path to namespace file. + -region=us-west-2 + AWS EKS region to use for k8s cluster token review. + -token=/var/run/secrets/kubernetes.io/serviceaccount/token + Path to token file. The global flags are: -alsologtostderr=true log to standard error as well as files - -block-profile= - filename prefix for block profiles - -block-profile-rate=1 - rate for runtime. SetBlockProfileRate - -cpu-profile= - filename for cpu profile - -heap-profile= - filename prefix for heap profiles -log_backtrace_at=:0 when logging hits line file:N, emit a stack trace -log_dir= @@ -59,50 +80,65 @@ The global flags are: max size in bytes of the buffer to use for logging stack traces -metadata= Displays metadata for the program and exits. - -mutex-profile= - filename prefix for mutex profiles - -mutex-profile-rate=1 - rate for runtime.SetMutexProfileFraction - -pprof= - address for pprof server - -profile-interval-s=0 - If >0, output new profiles at this interval (seconds). If <=0, profiles are - written only when Write() is called -stderrthreshold=2 logs at or above this threshold go to stderr - -thread-create-profile= - filename prefix for thread create profiles -time=false Dump timing information to stderr before exiting the program. 
-v=0 log level for V logs -v23.credentials= directory to use for storing security credentials - -v23.i18n-catalogue= - 18n catalogue files to load, comma separated - -v23.namespace.root=[/(v23.grail.com:internal:mounttabled)@ns.v23.grail.com:8101] + -v23.namespace.root=[/(v23.grail.com:internal:mounttabled)@ns-0.v23.grail.com:8101,/(v23.grail.com:internal:mounttabled)@ns-1.v23.grail.com:8101,/(v23.grail.com:internal:mounttabled)@ns-2.v23.grail.com:8101] local namespace root; can be repeated to provided multiple roots - -v23.permissions.file=map[] + -v23.permissions.file= specify a perms file as : -v23.permissions.literal= explicitly specify the runtime perms as a JSON-encoded access.Permissions. - Overrides all --v23.permissions.file flags. + Overrides all --v23.permissions.file flags -v23.proxy= object name of proxy service to use to export services across network boundaries + -v23.proxy.limit=0 + max number of proxies to connect to when the policy is to connect to all + proxies; 0 implies all proxies + -v23.proxy.policy= + policy for choosing from a set of available proxy instances -v23.tcp.address= address to listen on - -v23.tcp.protocol=wsh + -v23.tcp.protocol= protocol to listen with + -v23.virtualized.advertise-private-addresses= + if set the process will also advertise its private addresses + -v23.virtualized.disallow-native-fallback=false + if set, a failure to detect the requested virtualization provider will result + in an error, otherwise, native mode is used + -v23.virtualized.dns.public-name= + if set the process will use the supplied dns name (and port) without + resolution for its entry in the mounttable + -v23.virtualized.docker= + set if the process is running in a docker container and needs to configure + itself differently therein + -v23.virtualized.provider= + the name of the virtualization/cloud provider hosting this process if the + process needs to configure itself differently therein + -v23.virtualized.tcp.public-address= + if set the 
process will use this address (resolving via dns if appropriate) + for its entry in the mounttable + -v23.virtualized.tcp.public-protocol= + if set the process will use this protocol for its entry in the mounttable -v23.vtrace.cache-size=1024 - The number of vtrace traces to store in memory. + The number of vtrace traces to store in memory -v23.vtrace.collect-regexp= Spans and annotations that match this regular expression will trigger trace - collection. + collection -v23.vtrace.dump-on-shutdown=true - If true, dump all stored traces on runtime shutdown. + If true, dump all stored traces on runtime shutdown + -v23.vtrace.enable-aws-xray=false + Enable the use of AWS x-ray integration with vtrace + -v23.vtrace.root-span-name= + Set the name of the root vtrace span created by the runtime at startup -v23.vtrace.sample-rate=0 - Rate (from 0.0 to 1.0) to sample vtrace traces. + Rate (from 0.0 to 1.0) to sample vtrace traces -v23.vtrace.v=0 The verbosity level of the log messages to be captured in traces -vmodule= diff --git a/cmd/grail-access/ec2.go b/cmd/grail-access/ec2.go index 659b5060..4370d3d1 100644 --- a/cmd/grail-access/ec2.go +++ b/cmd/grail-access/ec2.go @@ -5,81 +5,38 @@ package main import ( - "fmt" "io/ioutil" "net/http" - "os" "time" - "github.com/grailbio/base/grail/data/v23data" + "github.com/grailbio/base/errors" + "github.com/grailbio/base/log" "github.com/grailbio/base/security/identity" - "v.io/v23" "v.io/v23/context" "v.io/v23/security" - "v.io/x/lib/cmdline" - "v.io/x/lib/vlog" - libsecurity "v.io/x/ref/lib/security" ) -const instanceIdentityURL = "http://169.254.169.254/latest/dynamic/instance-identity/pkcs7" +const defaultEc2BlesserFlag = "/ticket-server.eng.grail.com:8102/blesser/ec2" -func runEc2(ctx *context.T, env *cmdline.Env, args []string) error { - // TODO(razvanm): do we need to kill the v23agentd? - - // Best-effort cleanup. 
- os.RemoveAll(credentialsDirFlag) - - principal, err := libsecurity.CreatePersistentPrincipal(credentialsDirFlag, nil) - if err != nil { - vlog.Error(err) - return err - } - - ctx, err = v23.WithPrincipal(ctx, principal) - if err != nil { - vlog.Error(err) - return err +func fetchEC2Blessings(ctx *context.T) (security.Blessings, error) { + if blesserFlag == "" { + blesserFlag = defaultEc2BlesserFlag } - - stub := identity.Ec2BlesserClient(blesserEc2Flag) - doc := identityDocumentFlag - if doc == "" { - client := http.Client{ - Timeout: time.Duration(5 * time.Second), - } - resp, err := client.Get(instanceIdentityURL) - if err != nil { - vlog.Error(err) - return fmt.Errorf("unable to talk to the EC2 metadata server (not an EC2 instance?)") - } - b, err := ioutil.ReadAll(resp.Body) - resp.Body.Close() - vlog.VI(1).Infof("pkcs7: %d bytes", len(b)) - if err != nil { - return err - } - doc = string(b) + stub := identity.Ec2BlesserClient(blesserFlag) + client := http.Client{ + Timeout: 5 * time.Second, } - blessings, err := stub.BlessEc2(ctx, doc) + resp, err := client.Get(ec2InstanceIdentityFlag) if err != nil { - vlog.Error(err) - return err + return security.Blessings{}, errors.E("unable to talk to the EC2 metadata server (not an EC2 instance?)", err) } - - principal = v23.GetPrincipal(ctx) - principal.BlessingStore().SetDefault(blessings) - principal.BlessingStore().Set(blessings, security.AllPrincipals) - if err := security.AddToRoots(principal, blessings); err != nil { - vlog.Error(err) - return fmt.Errorf("failed to add blessings to recognized roots: %v", err) + identityDocument, err := ioutil.ReadAll(resp.Body) + if err2 := resp.Body.Close(); err2 != nil { + log.Print("warning: ", err2) } - - if err := v23data.InjectPipelineBlessings(ctx); err != nil { - vlog.Error(err) - return fmt.Errorf("failed to add the pipeline roots") + log.Debug.Printf("pkcs7: %d bytes", len(identityDocument)) + if err != nil { + return security.Blessings{}, err } - - dump(ctx, env) - - 
return nil
+	return stub.BlessEc2(ctx, string(identityDocument))
 }
diff --git a/cmd/grail-access/flag.go b/cmd/grail-access/flag.go
new file mode 100644
index 00000000..81098d38
--- /dev/null
+++ b/cmd/grail-access/flag.go
@@ -0,0 +1,27 @@
+// Copyright 2022 GRAIL, Inc. All rights reserved.
+// Use of this source code is governed by the Apache-2.0
+// license that can be found in the LICENSE file.
+
+package main
+
+import (
+	"strings"
+)
+
+// FlagStrings is a comma-separated list-of-strings flag, like
+// `-myflag=foo,bar`. Subsequent flags will replace the previous list, not
+// append: `-myflag=foo,bar -myflag=baz` yields []string{"baz"}.
+type FlagStrings []string
+
+// String implements flag.Value.
+func (is FlagStrings) String() string { return strings.Join(is, ",") }
+
+// Set implements flag.Value.
+func (is *FlagStrings) Set(s string) error {
+	if s == "" {
+		*is = nil
+	} else {
+		*is = strings.Split(s, ",")
+	}
+	return nil
+}
diff --git a/cmd/grail-access/google.go b/cmd/grail-access/google.go
index c3e2ba1e..cbbfc9a8 100644
--- a/cmd/grail-access/google.go
+++ b/cmd/grail-access/google.go
@@ -11,71 +11,36 @@ import (
 	"fmt"
 	"net"
 	"net/http"
-	"os"
 	"strings"
 	"sync"
 
-	"golang.org/x/oauth2"
-	goauth2 "google.golang.org/api/oauth2/v1"
-	"github.com/grailbio/base/grail/data/v23data"
-	"github.com/grailbio/base/cmdutil"
+	"github.com/grailbio/base/log"
 	"github.com/grailbio/base/security/identity"
 	"github.com/grailbio/base/web/webutil"
-	"v.io/v23"
-	v23context "v.io/v23/context"
+	"golang.org/x/oauth2"
+	goauth2 "google.golang.org/api/oauth2/v1"
+	vcontext "v.io/v23/context"
 	"v.io/v23/security"
-	"v.io/x/lib/cmdline"
 	"v.io/x/lib/vlog"
-	libsecurity "v.io/x/ref/lib/security"
 )
 
-func runGoogle(ctx *v23context.T, env *cmdline.Env, args []string) error {
-	// TODO(razvanm): do we need to kill the v23agentd?
-
-	// Best-effort cleanup.
- os.RemoveAll(credentialsDirFlag) - - principal, err := libsecurity.CreatePersistentPrincipal(credentialsDirFlag, nil) - if err != nil { - return err - } - - ctx, err = v23.WithPrincipal(ctx, principal) - if err != nil { - return err - } +const defaultGoogleBlesserFlag = "/ticket-server.eng.grail.com:8102/blesser/google" - idToken, err := fetchIDToken() - if err != nil { - return err +func fetchGoogleBlessings(ctx *vcontext.T) (security.Blessings, error) { + if blesserFlag == "" { + blesserFlag = defaultGoogleBlesserFlag } - - stub := identity.GoogleBlesserClient(blesserGoogleFlag) - blessings, err := stub.BlessGoogle(ctx, idToken) + idToken, err := fetchIDToken(ctx) if err != nil { - return err - } - - principal = v23.GetPrincipal(ctx) - principal.BlessingStore().SetDefault(blessings) - principal.BlessingStore().Set(blessings, "...") - if err := security.AddToRoots(principal, blessings); err != nil { - return fmt.Errorf("failed to add blessings to recognized roots: %v", err) + return security.Blessings{}, err } - - if err := v23data.InjectPipelineBlessings(ctx); err != nil { - vlog.Error(err) - return fmt.Errorf("failed to add the pipeline roots") - } - - dump(ctx, env) - - return nil + stub := identity.GoogleBlesserClient(blesserFlag) + return stub.BlessGoogle(ctx, idToken) } // fetchIDToken obtains a Google ID Token using an OAuth2 flow with Google. The // user will be instructed to use and URL or a browser will automatically open. 
-func fetchIDToken() (string, error) { +func fetchIDToken(ctx context.Context) (string, error) { stateBytes := make([]byte, 16) if _, err := rand.Read(stateBytes); err != nil { return "", err @@ -90,13 +55,13 @@ func fetchIDToken() (string, error) { return } if got, want := r.FormValue("state"), state; got != want { - cmdutil.Fatalf("Bad state: got %q, want %q", got, want) + log.Fatalf("Bad state: got %q, want %q", got, want) } code = r.FormValue("code") w.Header().Set("Content-Type", "text/html") // JavaScript only allows closing windows/tab that were open via // JavaScript. - fmt.Fprintf(w, `Code received. Please close this tab/window.`) + _, _ = fmt.Fprintf(w, `Code received. Please close this tab/window.`) wg.Done() }) @@ -107,7 +72,7 @@ func fetchIDToken() (string, error) { vlog.Infof("listening: %v\n", ln.Addr().String()) port := strings.Split(ln.Addr().String(), ":")[1] server := http.Server{Addr: "localhost:"} - go server.Serve(ln.(*net.TCPListener)) + go server.Serve(ln.(*net.TCPListener)) // nolint: errcheck config := &oauth2.Config{ ClientID: clientID, @@ -115,8 +80,8 @@ func fetchIDToken() (string, error) { Scopes: []string{goauth2.UserinfoEmailScope}, RedirectURL: fmt.Sprintf("http://localhost:%s", port), Endpoint: oauth2.Endpoint{ - AuthURL: "https://accounts.google.com/o/oauth2/v2/auth", - TokenURL: "https://accounts.google.com/o/oauth2/token", + AuthURL: googleOauth2Flag + "/v2/auth", + TokenURL: googleOauth2Flag + "/token", }, } @@ -126,7 +91,9 @@ func fetchIDToken() (string, error) { fmt.Printf("Opening %q...\n", url) if webutil.StartBrowser(url) { wg.Wait() - server.Shutdown(context.Background()) + if err = server.Shutdown(ctx); err != nil { + vlog.Errorf("shutting down: %v", err) + } } else { browserFlag = false } @@ -135,14 +102,16 @@ func fetchIDToken() (string, error) { if !browserFlag { config.RedirectURL = "urn:ietf:wg:oauth:2.0:oob" url := config.AuthCodeURL(state, oauth2.AccessTypeOnline) - fmt.Printf("The attempt to automatically open a 
browser failed. Please open the following link in your browse:\n\n\t%s\n\n", url) + fmt.Printf("The attempt to automatically open a browser failed. Please open the following link:\n\n\t%s\n\n", url) fmt.Printf("Paste the received code and then press enter: ") - fmt.Scanf("%s", &code) + if _, err := fmt.Scanf("%s", &code); err != nil { + return "", err + } fmt.Println("") } vlog.VI(1).Infof("code: %+v", code) - token, err := config.Exchange(oauth2.NoContext, code) + token, err := config.Exchange(ctx, code) if err != nil { return "", err } diff --git a/cmd/grail-access/k8s.go b/cmd/grail-access/k8s.go new file mode 100644 index 00000000..d444460b --- /dev/null +++ b/cmd/grail-access/k8s.go @@ -0,0 +1,57 @@ +package main + +import ( + "fmt" + "io/ioutil" + "path/filepath" + + "github.com/grailbio/base/security/identity" + "v.io/v23/context" + "v.io/v23/security" +) + +const defaultK8sBlesserFlag = "/ticket-server.eng.grail.com:8102/blesser/k8s" + +func fetchK8sBlessings(ctx *context.T) (blessing security.Blessings, err error) { + if blesserFlag == "" { + blesserFlag = defaultK8sBlesserFlag + } + stub := identity.K8sBlesserClient(blesserFlag) + + caCrt, namespace, token, err := getFiles() + if err != nil { + return blessing, err + } + + return stub.BlessK8s(ctx, caCrt, namespace, token, regionFlag) +} + +func getFiles() (caCrt, namespace, token string, err error) { + caCrtPath, err := filepath.Abs(caCrtFlag) + if err != nil { + return "", "", "", fmt.Errorf("parsing ca.crt path: %w", err) + } + namespacePath, err := filepath.Abs(namespaceFlag) + if err != nil { + return "", "", "", fmt.Errorf("parsing namespace path: %w", err) + } + tokenPath, err := filepath.Abs(tokenFlag) + if err != nil { + return "", "", "", fmt.Errorf("parsing token path: %w", err) + } + caCrtData, err := ioutil.ReadFile(caCrtPath) + if err != nil { + return "", "", "", fmt.Errorf("opening ca.crt: %w", err) + } + namespaceData, err := ioutil.ReadFile(namespacePath) + if err != nil { + return "", 
"", "", fmt.Errorf("opening namespace file: %w", err) + + } + tokenData, err := ioutil.ReadFile(tokenPath) + if err != nil { + return "", "", "", fmt.Errorf("opening token file: %w", err) + } + + return string(caCrtData), string(namespaceData), string(tokenData), err +} diff --git a/cmd/grail-access/main.go b/cmd/grail-access/main.go index 8da99b8c..7266aa19 100644 --- a/cmd/grail-access/main.go +++ b/cmd/grail-access/main.go @@ -3,23 +3,26 @@ // license that can be found in the LICENSE file. // The following enables go generate to generate the doc.go file. -//go:generate go run $GRAIL/go/src/vendor/v.io/x/lib/cmdline/testdata/gendoc.go "--build-cmd=go install" --copyright-notice= . -help +//go:generate go run v.io/x/lib/cmdline/gendoc "--build-cmd=go install" --copyright-notice= . -help + package main import ( - "flag" "fmt" "os" + "strings" "time" - "v.io/v23" - v23context "v.io/v23/context" + "github.com/grailbio/base/cmd/grail-access/remote" + "github.com/grailbio/base/errors" + "github.com/grailbio/base/log" + _ "github.com/grailbio/v23/factories/grail" + v23 "v.io/v23" + "v.io/v23/context" "v.io/v23/security" "v.io/x/lib/cmdline" "v.io/x/ref" - "v.io/x/ref/lib/v23cmd" - _ "v.io/x/ref/runtime/factories/grail" - "v.io/x/ref/services/agent/agentlib" + libsecurity "v.io/x/ref/lib/security" ) const ( @@ -28,36 +31,49 @@ const ( // is not secret in this case because it is part of client tool. It does act // as an identifier that allows restriction based on quota on the Google // side. 
- clientID = "27162366543-edih9cqc3t8p5hn9ord1k1n7h4oajfhm.apps.googleusercontent.com" - clientSecret = "eRZyFfe5xJu0083zDk8Mlb6K" + clientID = "fake" + clientSecret = "fake" ) var ( credentialsDirFlag string - ec2Flag bool - blesserGoogleFlag string - browserFlag bool + blesserFlag string + browserFlag bool + googleOauth2Flag string + ec2Flag bool + ec2InstanceIdentityFlag string + k8sFlag bool + regionFlag string + caCrtFlag string + namespaceFlag string + tokenFlag string - blesserEc2Flag string - identityDocumentFlag string + dumpFlag bool + doNotRefreshDurationFlag time.Duration + expiryCaveatFlag string - pipelineBlessings security.Blessings - pipelineStagingBlessings security.Blessings + blessRemotesFlag bool + blessRemotesModeFlag string + blessRemotesTargetsFlag FlagStrings ) func init() { - // We disable this flag because it's initialized with the value of the - // V23_CREDENTIALS environmental variable and that directory might be empty. - flag.Set("v23.credentials", "") - - // Prevent the v23agentd from running. - os.Setenv(ref.EnvCredentialsNoAgent, "1") + blessRemotesTargetsFlag = []string{os.ExpandEnv("ec2-name:ubuntu@adhoc.${USER}.*")} } -func newCmdRoot() *cmdline.Command { +func main() { + var defaultCredentialsDir string + if dir, ok := os.LookupEnv(ref.EnvCredentials); ok { + defaultCredentialsDir = dir + } else { + // TODO(josh): This expands to /.v23 if $HOME is undefined. + // We keep this for backwards compatibility, but maybe we shouldn't. + defaultCredentialsDir = os.ExpandEnv("${HOME}/.v23") + } + cmd := &cmdline.Command{ - Runner: v23cmd.RunnerFunc(run), + Runner: cmdline.RunnerFunc(run), Name: "grail-access", Short: "Creates fresh Vanadium credentials", Long: ` @@ -79,70 +95,162 @@ a '[server]:ec2:619867110810:role:adhoc:i-0aec7b085f8432699' blessing where 'server' is the blessing of the server. 
`, } - cmd.Flags.StringVar(&blesserGoogleFlag, "blesser-google", "/ticket-server.eng.grail.com:8102/blesser/google", "Blesser to talk to for the Google-based flow.") - cmd.Flags.StringVar(&blesserEc2Flag, "blesser-ec2", "/ticket-server.eng.grail.com:8102/blesser/ec2", "Blesser to talk to for the EC2-based flow.") - cmd.Flags.StringVar(&credentialsDirFlag, "dir", os.ExpandEnv("${HOME}/.v23"), "Where to store the Vanadium credentials. NOTE: the content will be erased if the credentials are regenerated.") - cmd.Flags.BoolVar(&ec2Flag, "ec2", false, "Use the role of the EC2 VM.") + cmd.Flags.StringVar(&credentialsDirFlag, "dir", defaultCredentialsDir, "Where to store the Vanadium credentials. NOTE: the content will be erased if the credentials are regenerated.") + cmd.Flags.StringVar(&blesserFlag, "blesser", "", "Flow specific blesser endpoint to use. Defaults to /ticket-server.eng.grail.com:8102/blesser/.") cmd.Flags.BoolVar(&browserFlag, "browser", os.Getenv("SSH_CLIENT") == "", "Attempt to open a browser.") - return cmd + cmd.Flags.StringVar(&googleOauth2Flag, "google-oauth2-url", + "https://accounts.google.com/o/oauth2", + "URL for oauth2 API calls, for testing") + cmd.Flags.BoolVar(&ec2Flag, "ec2", false, "Use the role of the EC2 VM.") + cmd.Flags.StringVar(&ec2InstanceIdentityFlag, "ec2-instance-identity-url", + "http://169.254.169.254/latest/dynamic/instance-identity/pkcs7", + "URL for fetching instance identity document, for testing") + cmd.Flags.BoolVar(&k8sFlag, "k8s", false, "Use the Kubernetes flow.") + cmd.Flags.StringVar(®ionFlag, "region", "us-west-2", "AWS EKS region to use for k8s cluster token review.") + cmd.Flags.StringVar(&caCrtFlag, "ca-crt", "/var/run/secrets/kubernetes.io/serviceaccount/ca.crt", "Path to ca.crt file.") + cmd.Flags.StringVar(&namespaceFlag, "namespace", "/var/run/secrets/kubernetes.io/serviceaccount/namespace", "Path to namespace file.") + cmd.Flags.StringVar(&tokenFlag, "token", 
"/var/run/secrets/kubernetes.io/serviceaccount/token", "Path to token file.") + cmd.Flags.BoolVar(&dumpFlag, "dump", false, "If credentials are present, dump them on the console instead of refreshing them.") + cmd.Flags.DurationVar(&doNotRefreshDurationFlag, "do-not-refresh-duration", 7*24*time.Hour, "Do not refresh credentials if they are present and do not expire within this duration.") + cmd.Flags.StringVar(&expiryCaveatFlag, "expiry-caveat", "", "Duration of expiry caveat added to blessings (for testing); empty means no caveat added") + + // TODO(2022-10-18): Fix commentary generation to bring doc.go up to date. + // go.mod is currently broken such that required go tooling fails. We are + // apparently specifying old versions of protobuf related packages, which + // causes `go install` to fail, which causes doc generation to fail. + cmd.Flags.BoolVar(&blessRemotesFlag, "bless-remotes", true, "Whether to attempt to bless remotes with local blessings; only applies to Google blessings") + cmd.Flags.StringVar(&blessRemotesModeFlag, remote.FlagNameMode, "", "(INTERNAL) Controls the mode in which we run for the remote blessing protocol; one of {public-key,receive,send}") + cmd.Flags.Var(&blessRemotesTargetsFlag, "bless-remotes-targets", "Comma-separated list of targets to bless; targets may be \"ssh:[user@]host[:port]\" SSH destinations or \"ec2-name:[user@]ec2-instance-name-filter\" EC2 instance name filters; see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/Using_Filtering.html") + + cmdline.HideGlobalFlagsExcept() + cmdline.Main(cmd) } -func run(ctx *v23context.T, env *cmdline.Env, args []string) error { +func run(*cmdline.Env, []string) error { + if credentialsDirFlag == "" { + return fmt.Errorf("missing credentials dir, need -dir, $HOME, or $%s", ref.EnvCredentials) + } + if _, ok := os.LookupEnv(ref.EnvCredentials); !ok { - fmt.Printf("*******************************************************\n") - fmt.Printf("* WARNING: $V23_CREDENTIALS is not defined! 
*\n") + fmt.Print("*******************************************************\n") + fmt.Printf("* WARNING: $%s is not defined! *\n", ref.EnvCredentials) fmt.Printf("*******************************************************\n\n") - fmt.Printf("How to fix this in bash: export V23_CREDENTIALS=%s\n\n", credentialsDirFlag) + fmt.Printf("How to fix this in bash: export %s=%s\n\n", ref.EnvCredentials, credentialsDirFlag) } - agentPrincipal, err := agentlib.LoadPrincipal(credentialsDirFlag) - if err == nil { - // We have access to some credentials so we'll try to load them. - ctx, err = v23.WithPrincipal(ctx, agentPrincipal) - if err != nil { - return err - } - agentBlessings, _ := agentPrincipal.BlessingStore().Default() - if !agentBlessings.IsZero() { - principal := v23.GetPrincipal(ctx) - if err := principal.BlessingStore().SetDefault(agentBlessings); err != nil { - return err - } - if err := security.AddToRoots(principal, agentBlessings); err != nil { - return err - } + principal, err := libsecurity.LoadPersistentPrincipal(credentialsDirFlag, nil) + if err != nil { + log.Printf("INFO: Couldn't load principal from %s. Creating new one...", credentialsDirFlag) + _, createErr := libsecurity.CreatePersistentPrincipal(credentialsDirFlag, nil) + if createErr != nil { + return errors.E(fmt.Sprintf("failed to create new principal: %v, after load error: %v", createErr, err)) } - } else { - // We don't have access to credentials. Typically this happen on the first - // run when the credentials directory is empty. + principal, err = libsecurity.LoadPersistentPrincipal(credentialsDirFlag, nil) + } + if err != nil { + return errors.E("failed to load principal", err) } - b, _ := v23.GetPrincipal(ctx).BlessingStore().Default() - - now := time.Now() - if b.Expiry().After(now.Add(30 * time.Minute)) { - // If the blessing is not expired we show the current state and when - // the blessings will expire. 
- dump(ctx, env) - fmt.Printf("%s (%s)\n", b.Expiry().Local(), b.Expiry().Sub(now)) + ctx, shutDown := v23.Init() + defer shutDown() + ctx, err = v23.WithPrincipal(ctx, principal) + if err != nil { + return errors.E("failed to initialize context", err) + } + switch blessRemotesModeFlag { + case "": + // No-op. + case remote.ModePublicKey: + if err = remote.PrintPublicKey(ctx, os.Stdout); err != nil { + return errors.E("failed to print public key", err) + } + return nil + case remote.ModeReceive: + if err = remote.ReceiveBlessings(ctx, os.Stdin); err != nil { + return errors.E("failed to receive blessings", err) + } + return nil + default: + return errors.E("invalid -"+remote.FlagNameMode, blessRemotesModeFlag) + } + defaultBlessings, _ := principal.BlessingStore().Default() + if dumpFlag || defaultBlessings.Expiry().After(time.Now().Add(doNotRefreshDurationFlag)) { + dump(principal) + if err = maybeBlessRemotes(ctx); err != nil { + return err + } return nil } + + var blessings security.Blessings if ec2Flag { - return runEc2(ctx, env, args) + blessings, err = fetchEC2Blessings(ctx) + } else if k8sFlag { + blessings, err = fetchK8sBlessings(ctx) + } else { + blessings, err = fetchGoogleBlessings(ctx) + } + if err != nil { + return errors.E("failed to fetch blessings", err) + } + if expiryCaveatFlag != "" { + d, err := time.ParseDuration(expiryCaveatFlag) + if err != nil { + return errors.E("failed to parse expiry-caveat") + } + expiryCaveat, err := security.NewExpiryCaveat(time.Now().Add(d)) + if err != nil { + return errors.E("failed to make expiry caveat", err) + } + extension := fmt.Sprintf("expires-%v", d) + blessings, err = principal.Bless(principal.PublicKey(), blessings, extension, expiryCaveat) + if err != nil { + return errors.E("failed to make expired blessings", err) + } + } + if err = principal.BlessingStore().SetDefault(blessings); err != nil { + return errors.E(err, "failed to set default blessings") } - return runGoogle(ctx, env, args) + _, err = 
principal.BlessingStore().Set(blessings, security.AllPrincipals) + if err != nil { + return errors.E(err, "failed to set peer blessings") + } + if err := security.AddToRoots(principal, blessings); err != nil { + return errors.E(err, "failed to add blessing roots") + } + + fmt.Println("Successfully applied new blessing:") + dump(principal) + if err = maybeBlessRemotes(ctx); err != nil { + return err + } + return nil } -func dump(ctx *v23context.T, env *cmdline.Env) { - // Mimic the principal dump output. - principal := v23.GetPrincipal(ctx) +func dump(principal security.Principal) { + // Mimic the output of the v.io/x/ref/cmd/principal dump command. fmt.Printf("Public key: %s\n", principal.PublicKey()) fmt.Println("---------------- BlessingStore ----------------") - fmt.Fprintf(env.Stdout, principal.BlessingStore().DebugString()) + fmt.Print(principal.BlessingStore().DebugString()) fmt.Println("---------------- BlessingRoots ----------------") - fmt.Fprintf(env.Stdout, principal.Roots().DebugString()) + fmt.Print(principal.Roots().DebugString()) + + blessing, _ := principal.BlessingStore().Default() + fmt.Printf("Expires on %s (in %s)\n", blessing.Expiry().Local(), time.Until(blessing.Expiry())) } -func main() { - cmdline.HideGlobalFlagsExcept() - cmdline.Main(newCmdRoot()) +func maybeBlessRemotes(ctx *context.T) error { + if !blessRemotesFlag { + return nil + } + // The only use case for blessing remotes is for Google blessings, i.e. + // using the local browser for OAuth to bless a headless EC2 instance. 
+	const prefix = "v23.grail.com:google:"
+	blessings, _ := v23.GetPrincipal(ctx).BlessingStore().Default()
+	if !strings.HasPrefix(blessings.String(), prefix) {
+		return nil
+	}
+	if err := remote.Bless(ctx, blessRemotesTargetsFlag); err != nil {
+		return errors.E("failed to send blessings to instances", err)
+	}
+	return nil
 }
diff --git a/cmd/grail-access/manual_google_test.bash b/cmd/grail-access/manual_google_test.bash
new file mode 100755
index 00000000..e9cb53b4
--- /dev/null
+++ b/cmd/grail-access/manual_google_test.bash
@@ -0,0 +1,55 @@
+#!/bin/bash
+
+# The grail-access Google authentication flow is difficult to exercise in a unit test because the
+# oauth2 server isn't trivial to fake (at least not for josh@).
+
+# Instead, we provide a manual test script. Please run this yourself.
+
+set -euo pipefail
+
+echo "Instructions:"
+echo
+echo "Running manual tests. If the script exits with a non-zero error code, the tests failed."
+echo "You'll also be prompted to review output manually. If it doesn't look right, the tests failed."
+echo
+dir="$(mktemp -d)"
+echo "Using temporary directory for test: $dir"
+echo "Building grail-access for the test..."
+cd "$( dirname "${BASH_SOURCE[0]}" )"
+go build -o "$dir/grail-access" github.com/grailbio/base/cmd/grail-access
+cd "$dir"
+echo
+echo "Step 1/3: Starting grail-access Google authentication flow. Please complete it."
+echo
+echo "************************************************************************"
+./grail-access -dir ./v23
+echo "************************************************************************"
+echo
+echo "Done with authentication flow."
+echo "If it succeeded, you should see lines like these above:"
+echo "  Default Blessings v23.grail.com:google:YOUR_USERNAME@grailbio.com"
+echo "and"
+echo "  ... v23.grail.com:google:YOUR_USERNAME@grailbio.com"
+echo "and an expiration date in the future."
+echo
+read -p "Continue with next test? 
[Y] " +echo +echo "Step 2/3: Next, running the same flow, but automatically canceling." +echo +echo "************************************************************************" +set +e +cat /dev/null | ./grail-access -dir ./v23 -browser=false +set -e +echo "************************************************************************" +echo +echo "Step 3/3: Finally, make sure our Step 1 credentials survived. " +echo +echo "************************************************************************" +./grail-access -dir ./v23 -dump +echo "************************************************************************" +echo +echo "You should see the same blessing lines as in Step 1, and a consistent expiry time." +echo "If not, the tests failed." +echo +echo "Cleaning up test directory: $dir" +rm -rf "$dir" diff --git a/cmd/grail-access/remote/bless.go b/cmd/grail-access/remote/bless.go new file mode 100644 index 00000000..7720da02 --- /dev/null +++ b/cmd/grail-access/remote/bless.go @@ -0,0 +1,430 @@ +// Copyright 2022 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache-2.0 +// license that can be found in the LICENSE file. + +package remote + +import ( + "bytes" + "fmt" + "os/exec" + "strings" + "text/template" + "time" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/ec2" + "github.com/aws/aws-sdk-go/service/ec2/ec2iface" + "github.com/aws/aws-sdk-go/service/s3" + "github.com/grailbio/base/cloud/awssession" + "github.com/grailbio/base/must" + v23 "v.io/v23" + "v.io/v23/context" + "v.io/v23/security" +) + +const ( + // awsTicketPath is the path of the ticket that provides AWS credentials + // for querying AWS/EC2 for running instances. + awsTicketPath = "tickets/eng/dev/aws" + // blessingsExtension is the extension added to the blessings sent to + // remotes. 
+ blessingsExtension = "remote" + + // remoteExecS3Bucket is the bucket in which the known-compatible + // grail-access binary installed on remote targets is stored. + remoteExecS3Bucket = "grail-bin-public" + // remoteExecS3Key is the object key of the known-compatible grail-access + // binary installed on remote targets. + // TODO: Stop assuming single platform (Linux/AMD64) of targets. + remoteExecS3Key = "linux/amd64/2023-02-10.dev-201357/grail-access" + // remoteExecExpiry is the expiry of the presigned URL we generate to + // download (remoteExecS3Bucket, remoteExecS3Key). + remoteExecExpiry = 15 * time.Minute + // remoteExecSHA256 is the expected SHA-256 of the executable at + // (remoteExecS3Bucket, remoteExecS3Key). + remoteExecSHA256 = "eeede8ad76ee106735867facfe70d5ae917f645de3d7c6a7274cbd25da34460d" + // remoteExecPath is the path on the remote target at which we install and + // later invoke the grail-access executable. This string will be + // double-quoted in a bash script, so variable expansions can be used. + // + // See XDG Base Directory Specification: + // https://specifications.freedesktop.org/basedir-spec/basedir-spec-latest.html + remoteExecPath = "${XDG_DATA_HOME:-${HOME}/.local/share}/grail-access/grail-access" +) + +// Bless blesses the principals of targets with unconstrained extensions of +// the default blessings of the principal of ctx. See package documentation +// (doc.go) for a description of target strings. 
+func Bless(ctx *context.T, targets []string) error { + fmt.Println("---------------- Bless Remotes ----------------") + sess, err := awssession.NewWithTicket(ctx, awsTicketPath) + if err != nil { + return fmt.Errorf("creating AWS session: %v", err) + } + dests, err := resolveTargets(ctx, sess, targets) + if err != nil { + return fmt.Errorf("resolving targets: %v", err) + } + p := v23.GetPrincipal(ctx) + if p == nil { + return fmt.Errorf("no local principal") + } + blessings, _ := p.BlessingStore().Default() + for i, target := range targets { + fmt.Printf("%s:\n", target) + if len(dests[i]) == 0 { + fmt.Println(" ") + continue + } + for _, d := range dests[i] { + if !d.running { + fmt.Printf(" %-60s [ NOT RUNNING ]\n", d.s) + continue + } + if err := blessSSHDest(ctx, sess, p, blessings, d.s); err != nil { + return fmt.Errorf("blessing %q: %v", d.s, err) + } + fmt.Printf(" %-60s [ OK ]\n", d.s) + } + } + return nil +} + +type sshDest struct { + // s represents this destination. If running is true, then it is a valid + // SSH destination, i.e. we can connect to it using SSH. + s string + // running is false if we believe that the host is not currently running, + // e.g. because EC2 tells us so. Otherwise, it is true. + running bool +} + +// blessSSHDest uses commands over SSH to bless dest's principal. p is the +// blesser, and with are the blessings with which to bless dest's principal. 
+func blessSSHDest( + ctx *context.T, + sess *session.Session, + p security.Principal, + with security.Blessings, + dest string, +) error { + if err := ensureRemoteExec(ctx, sess, dest); err != nil { + return fmt.Errorf("ensuring remote executable (grail-access) is available: %v", err) + } + key, err := remotePublicKey(ctx, dest) + if err != nil { + return fmt.Errorf("getting remote public key: %v", err) + } + blessingSelf, err := keysEqual(key, p.PublicKey()) + if err != nil { + return fmt.Errorf("checking if blessing self: %v", err) + } + if blessingSelf { + return fmt.Errorf("cannot bless self; check that target is a remote machine/principal") + } + b, err := p.Bless(key, with, blessingsExtension, security.UnconstrainedUse()) + if err != nil { + return fmt.Errorf("blessing %v with %v: %v", key, with, err) + } + if err := sendBlessings(ctx, b, dest); err != nil { + return fmt.Errorf("sending blessings to %s: %v", dest, err) + } + return nil +} + +func ensureRemoteExec(ctx *context.T, sess *session.Session, dest string) error { + script, err := makeEnsureRemoteExecScript(sess) + if err != nil { + return fmt.Errorf( + "making script to ensure remote grail-access executable is available: %v", + err, + ) + } + cmd := sshCommand(ctx, dest, "bash -s") + cmd.Stdin = strings.NewReader(script) + output, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf( + "running installation script on %q: %v"+ + "\n--- std{err,out} ---\n%s", + dest, + err, + output, + ) + } + return nil +} + +func makeEnsureRemoteExecScript(sess *session.Session) (string, error) { + url, err := presignRemoteExecURL(sess) + if err != nil { + return "", fmt.Errorf("presigning URL of grail-access executable: %v", err) + } + // "Escape" single quotes, as we enclose the URL in single quotes in our + // generated script. 
+ url = strings.ReplaceAll(url, "'", "'\\''") + var b strings.Builder + ensureRemoteExecTemplate.Execute(&b, map[string]string{ + "url": url, + "sha256": remoteExecSHA256, + "path": remoteExecPath, + }) + return b.String(), nil +} + +// ensureRemoteExecTemplate is the template for building the script used to +// ensure that the remote has a compatible grail-access binary installed. We +// inject the configuration for installation. +var ensureRemoteExecTemplate *template.Template + +func init() { + must.True(!strings.Contains(remoteExecSHA256, "'")) + ensureRemoteExecTemplate = template.Must(template.New("script").Parse(` +set -euxo pipefail + +# url is the S3 URL from which to fetch the grail-access binary that will run +# on the target. +url='{{.url}}' +# sha256 is the expected SHA-256 hash of the grail-access binary. +sha256='{{.sha256}}' + +# path is the path at which will we ultimately place the grail-access binary. +path="{{.path}}" +dir="$(dirname "${path}")" + +sha_bad=0 +echo "${sha256} ${path}" | sha256sum --check --quiet - || sha_bad=$? +if [ $sha_bad == 0 ]; then + # We already have the right binary. Ensure that it is executable. This + # should be a no-op unless it was changed externally. 
+	chmod 700 "${path}"
+	exit
+fi
+
+mkdir --mode=700 --parents "${dir}"
+chmod 700 "${dir}"
+path_download="$(mktemp "${path}.XXXXXXXXXX")"
+trap "rm --force -- \"${path_download}\"" EXIT
+curl --fail "${url}" --output "${path_download}"
+echo "${sha256}  ${path_download}" | sha256sum --check --quiet -
+chmod 700 "${path_download}"
+mv --force "${path_download}" "${path}"
+`))
+}
+
+// remotePublicKey returns the public key of the principal of the grail-access
+// installation at dest, obtained by running the installed binary remotely in
+// ModePublicKey.
+func remotePublicKey(ctx *context.T, dest string) (security.PublicKey, error) {
+	var (
+		cmd    = remoteExecCommand(ctx, dest, ModePublicKey)
+		stderr bytes.Buffer
+	)
+	cmd.Stderr = &stderr
+	output, err := cmd.Output()
+	if err != nil {
+		return nil, fmt.Errorf(
+			"running grail-access(in mode: %s) on remote: %v;"+
+				"\n--- stderr ---\n%s",
+			ModePublicKey,
+			err,
+			stderr.String(),
+		)
+	}
+	key, err := decodePublicKey(string(output))
+	if err != nil {
+		return nil, fmt.Errorf("decoding public key %q: %v", string(output), err)
+	}
+	return key, nil
+}
+
+// keysEqual reports whether lhs and rhs are the same public key, compared by
+// their marshaled binary forms.
+func keysEqual(lhs, rhs security.PublicKey) (bool, error) {
+	lhsBytes, err := lhs.MarshalBinary()
+	if err != nil {
+		return false, fmt.Errorf("left-hand side of comparison invalid: %v", err)
+	}
+	rhsBytes, err := rhs.MarshalBinary()
+	if err != nil {
+		return false, fmt.Errorf("right-hand side of comparison invalid: %v", err)
+	}
+	return bytes.Equal(lhsBytes, rhsBytes), nil
+}
+
+// sendBlessings transmits b (on stdin) to grail-access run in ModeReceive at
+// dest, which installs the blessings on the remote principal.
+func sendBlessings(ctx *context.T, b security.Blessings, dest string) error {
+	blessingsString, err := encodeBlessings(b)
+	if err != nil {
+		return fmt.Errorf("encoding blessings: %v", err)
+	}
+	cmd := remoteExecCommand(ctx, dest, ModeReceive)
+	cmd.Stdin = strings.NewReader(blessingsString)
+	var stderr bytes.Buffer
+	cmd.Stderr = &stderr
+	if err := cmd.Run(); err != nil {
+		return fmt.Errorf(
+			"running grail-access(in mode: %s) on remote: %v;"+
+				"\n--- stderr ---\n%s",
+			ModeReceive,
+			err,
+			stderr.String(),
+		)
+	}
+	return nil
+}
+
+// remoteExecCommand returns a command that runs the installed grail-access
+// binary on dest in the given mode.
+func remoteExecCommand(ctx *context.T, dest, mode string) *exec.Cmd {
+	return sshCommand(
+		ctx,
+		dest,
+		// Set a reasonable value V23_CREDENTIALS in case the target's bash
+		// does not configure it (in non-login shells).
+		"V23_CREDENTIALS=${HOME}/.v23",
+		remoteExecPath, "-"+FlagNameMode+"="+mode,
+	)
+}
+
+// sshCommand returns a command that runs args at the SSH destination dest,
+// configured to fail fast rather than prompt interactively.
+func sshCommand(ctx *context.T, dest string, args ...string) *exec.Cmd {
+	cmdArgs := []string{
+		// Use batch mode which prevents prompting for an SSH passphrase. The
+		// prompt is more confusing than failing outright, as we run multiple
+		// SSH commands, so even if the user enters the correct passphrase,
+		// they will see more prompts.
+		"-o", "BatchMode yes",
+		// Don't check the identity of the remote host.
+		"-o", "StrictHostKeyChecking no",
+		// Don't store the identity of the remote host.
+		"-o", "UserKnownHostsFile /dev/null",
+		dest,
+	}
+	cmdArgs = append(cmdArgs, args...)
+	return exec.CommandContext(ctx, "ssh", cmdArgs...)
+}
+
+// resolveTargets resolves targets into SSH destinations. Destinations are
+// returned as a two-dimensional slice of length len(targets). Each entry
+// corresponds to the input target and is a slice of the matching SSH
+// destinations, if any.
+//
+// Note that for ec2-name targets, we make API calls to EC2 to resolve the
+// corresponding hosts. A single ec2-name target may resolve to multiple (or
+// zero) SSH destinations, as names are given as filters.
+func resolveTargets(ctx *context.T, sess *session.Session, targets []string) ([][]sshDest, error) { + var dests = make([][]sshDest, len(targets)) + for i, target := range targets { + parts := strings.SplitN(target, ":", 2) + if len(parts) != 2 { + return nil, fmt.Errorf("target not in \"type:value\" format: %v", target) + } + var ( + typ = parts[0] + val = parts[1] + ec2API = ec2.New(sess) + ) + switch typ { + case "ssh": + dests[i] = append(dests[i], sshDest{s: val, running: true}) + case "ec2-name": + ec2Dests, err := resolveEC2Target(ctx, ec2API, val) + if err != nil { + return nil, fmt.Errorf("resolving EC2 target %v: %v", val, err) + } + dests[i] = append(dests[i], ec2Dests...) + default: + return nil, fmt.Errorf("invalid target type for %q: %v", target, typ) + } + } + return dests, nil +} + +func resolveEC2Target(ctx *context.T, ec2API ec2iface.EC2API, s string) ([]sshDest, error) { + var ( + user string + name string + ) + parts := strings.SplitN(s, "@", 2) + switch len(parts) { + case 1: + user = "ubuntu" + name = parts[0] + case 2: + user = parts[0] + name = parts[1] + default: + must.Never("SplitN returned invalid result") + } + instances, err := findInstances(ctx, ec2API, name) + if err != nil { + return nil, fmt.Errorf("finding instances matching %q: %v", name, err) + } + var dests []sshDest + for _, i := range instances { + if i.InstanceId == nil { + return nil, fmt.Errorf("instance has no ID: %s", i.String()) + } + if i.State == nil || i.State.Name == nil { + return nil, fmt.Errorf("instance has no state: %s", i.String()) + } + if *i.State.Name != ec2.InstanceStateNameRunning { + dests = append(dests, sshDest{ + s: fmt.Sprintf("%s@%s", user, *i.InstanceId), + running: false, + }) + continue + } + if i.PublicIpAddress == nil { + return nil, fmt.Errorf("running instance %q has no public IP address", *i.InstanceId) + } + dests = append(dests, sshDest{ + s: fmt.Sprintf("%s@%s", user, *i.PublicIpAddress), + running: true, + }) + } + return dests, nil +} + 
+func presignRemoteExecURL(sess *session.Session) (string, error) { + s3API := s3.New(sess) + req, _ := s3API.GetObjectRequest(&s3.GetObjectInput{ + Bucket: aws.String(remoteExecS3Bucket), + Key: aws.String(remoteExecS3Key), + }) + url, err := req.Presign(remoteExecExpiry) + if err != nil { + return "", fmt.Errorf( + "presigning URL for s3://%s/%s: %v", + remoteExecS3Bucket, + remoteExecS3Key, + err, + ) + } + return url, nil +} + +func findInstances(ctx *context.T, api ec2iface.EC2API, name string) ([]*ec2.Instance, error) { + input := &ec2.DescribeInstancesInput{ + Filters: []*ec2.Filter{ + { + Name: aws.String("tag:Name"), + Values: aws.StringSlice([]string{name}), + }, + }, + } + output, err := api.DescribeInstancesWithContext(ctx, input) + if err != nil { + return nil, fmt.Errorf( + "DescribeInstances error:\n%v\nDescribeInstances request:\n%v", + err, + input, + ) + } + return reservationsInstances(output.Reservations), nil +} + +func reservationsInstances(reservations []*ec2.Reservation) []*ec2.Instance { + instances := []*ec2.Instance{} + for _, r := range reservations { + instances = append(instances, r.Instances...) + } + return instances +} diff --git a/cmd/grail-access/remote/bless_test.go b/cmd/grail-access/remote/bless_test.go new file mode 100644 index 00000000..f3c74562 --- /dev/null +++ b/cmd/grail-access/remote/bless_test.go @@ -0,0 +1,16 @@ +// Copyright 2022 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache-2.0 +// license that can be found in the LICENSE file. + +package remote_test + +import ( + "testing" + + _ "github.com/grailbio/base/cmd/grail-access/remote" +) + +// TestInit verifies that init code does not panic. +func TestInit(t *testing.T) { + // This space is intentionally left blank. 
+} diff --git a/cmd/grail-access/remote/doc.go b/cmd/grail-access/remote/doc.go new file mode 100644 index 00000000..ef35f8c7 --- /dev/null +++ b/cmd/grail-access/remote/doc.go @@ -0,0 +1,75 @@ +// Copyright 2022 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache-2.0 +// license that can be found in the LICENSE file. + +/* +Package remote implements sending (and receiving) of blessings to remote +machines over SSH. + +The remote machine must be accessible by SSH and have a version of grail-access +in $PATH that supports remote blessing. + +The protocol looks like this: + + +-------+ +---------+ + | Local | | Remote | + +-------+ +---------+ + | | + | grail-access -bless-remotes | + |---------------------------- | + | | | + |<--------------------------- | + | | + | ssh dest grail-access -bless-remotes-mode=PublicKey | + |--------------------------------------------------------->| + | | + | [remote principal public key] | + |<---------------------------------------------------------| + | | + | blessings <= bless remote principal public key | + |----------------------------------------------- | + | | | + |<---------------------------------------------- | + | | + | ssh dest grail-access -bless-remotes-mode=Receive | + |--------------------------------------------------------->| + | | + | transmit blessings (on stdout) | + |--------------------------------------------------------->| + | | + | | set blessings + | |-------------- + | | | + | |<------------- + | | + +Remote machines are specified by the -bless-remotes-targets flag which accepts +a comma-separated list of targets. There are two types of targets: SSH +destinations and EC2 names, specified with "ssh:" and "ec2-name:" respectively. + +SSH destination targets are destinations as ssh accepts, [user@]host[:port], +e.g.: + ssh:10.1.0.120 + ssh:ubuntu@ec2-34-214-222-123.us-west-2.compute.amazonaws.com + ssh:10.1.0.120:822 + +EC2 name targets use AWS EC2 instance names (i.e. 
the value of the Name tag), +[user@]instancename, e.g.: + ec2-name:my-instance-name + ec2-name:core@another-instance + +EC2 names are treated as filters, so "ec2-name:core@my-*-name" will target all +instances matching "my-*-name" (and ssh them as user "core"). See +https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/Using_Filtering.html . + +Example: + grail-access -bless-remotes -bless-remotes-targets="ssh:me@mine.com,ec2-name:my-instance-*" + +This invocation will target the SSH destination "me@mine.com" as well as all +EC2 instances whose Name tag matches "my-instance-*" (using the default ssh +username). + +Note that we don't yet support custom ports for ec2-name targets, as ':' is a +valid character in names, and we are preferring to keep the parsing simple. +*/ +package remote diff --git a/cmd/grail-access/remote/encoding.go b/cmd/grail-access/remote/encoding.go new file mode 100644 index 00000000..b01427ca --- /dev/null +++ b/cmd/grail-access/remote/encoding.go @@ -0,0 +1,55 @@ +// Copyright 2022 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache-2.0 +// license that can be found in the LICENSE file. 
+ +package remote + +import ( + "encoding/base64" + "fmt" + + "v.io/v23/security" + "v.io/v23/vom" +) + +var b64 = base64.StdEncoding + +func encodeBlessings(b security.Blessings) (string, error) { + bs, err := vom.Encode(b) + if err != nil { + return "", fmt.Errorf("vom-encoding blessings: %v", err) + } + return b64.EncodeToString(bs), nil +} + +func decodeBlessings(s string) (security.Blessings, error) { + b, err := b64.DecodeString(s) + if err != nil { + return security.Blessings{}, fmt.Errorf("base64 decoding blessings string: %v", err) + } + var blessings security.Blessings + if err := vom.Decode(b, &blessings); err != nil { + return security.Blessings{}, fmt.Errorf("vom-decoding: %v", err) + } + return blessings, nil +} + +func encodePublicKey(k security.PublicKey) (string, error) { + der, err := k.MarshalBinary() + if err != nil { + return "", fmt.Errorf("corrupted public key: %v", err) + } + return b64.EncodeToString(der), nil +} + +func decodePublicKey(s string) (security.PublicKey, error) { + bs, err := b64.DecodeString(s) + if err != nil { + return nil, fmt.Errorf("base64-decoding public key: %v", err) + } + key, err := security.UnmarshalPublicKey(bs) + if err != nil { + return nil, fmt.Errorf("unmarshalling public key: %v", err) + } + return key, nil +} diff --git a/cmd/grail-access/remote/flag.go b/cmd/grail-access/remote/flag.go new file mode 100644 index 00000000..a1106a94 --- /dev/null +++ b/cmd/grail-access/remote/flag.go @@ -0,0 +1,19 @@ +// Copyright 2022 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache-2.0 +// license that can be found in the LICENSE file. + +package remote + +const ( + // FlagNameMode is the name of the string flag used to set the mode of + // grail-access for sending and receiving blessings. + FlagNameMode = "internal-bless-remotes-mode" + // ModeSend initiates the full sender workflow. See package documentation. 
+ ModeSend = "send" + // ModePublicKey causes grail-access to print the local principal's public + // key. + ModePublicKey = "public-key" + // ModeReceive causes grail-access to read blessings from os.Stdin and set + // them as both the default and for all principal peers. + ModeReceive = "receive" +) diff --git a/cmd/grail-access/remote/publickey.go b/cmd/grail-access/remote/publickey.go new file mode 100644 index 00000000..c0e5f186 --- /dev/null +++ b/cmd/grail-access/remote/publickey.go @@ -0,0 +1,32 @@ +// Copyright 2022 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache-2.0 +// license that can be found in the LICENSE file. + +package remote + +import ( + "fmt" + "io" + + v23 "v.io/v23" + "v.io/v23/context" +) + +// PrintPublicKey prints the principal of ctx to w (to be read and decoded by +// Bless). +func PrintPublicKey(ctx *context.T, w io.Writer) error { + p := v23.GetPrincipal(ctx) + if p == nil { + // We rely on the caller to set up the principal before making this + // call. + return fmt.Errorf("no local principal to bless") + } + publicKeyString, err := encodePublicKey(p.PublicKey()) + if err != nil { + return fmt.Errorf("encoding public key: %v", err) + } + if _, err := fmt.Fprintln(w, publicKeyString); err != nil { + return fmt.Errorf("printing public key: %v", err) + } + return nil +} diff --git a/cmd/grail-access/remote/receive.go b/cmd/grail-access/remote/receive.go new file mode 100644 index 00000000..6cf112a3 --- /dev/null +++ b/cmd/grail-access/remote/receive.go @@ -0,0 +1,46 @@ +// Copyright 2022 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache-2.0 +// license that can be found in the LICENSE file. + +package remote + +import ( + "fmt" + "io" + + v23 "v.io/v23" + "v.io/v23/context" + "v.io/v23/security" +) + +// ReceiveBlessings reads encoded blessings from r and sets them as the default +// blessings and as blessings for all principal peers. 
+func ReceiveBlessings(ctx *context.T, r io.Reader) error { + p := v23.GetPrincipal(ctx) + if p == nil { + // We rely on the caller to set up the principal before making this + // call. + return fmt.Errorf("no local principal to bless") + } + // Read a single-line encoding of the received blessing, and set them as + // both the default and for all peer principals. + input, err := io.ReadAll(r) + if err != nil { + return fmt.Errorf("reading input: %v", err) + } + b, err := decodeBlessings(string(input)) + if err != nil { + return fmt.Errorf("decoding blessings string: %v", err) + } + store := p.BlessingStore() + if err := store.SetDefault(b); err != nil { + return fmt.Errorf("setting blessings %v as default: %v", b, err) + } + if _, err := store.Set(b, security.AllPrincipals); err != nil { + return fmt.Errorf("setting blessings %v for peers %v: %v", b, security.AllPrincipals, err) + } + if err := security.AddToRoots(p, b); err != nil { + return fmt.Errorf("adding blessings to recognized roots: %v", err) + } + return nil +} diff --git a/cmd/grail-file/cmd/cat.go b/cmd/grail-file/cmd/cat.go new file mode 100644 index 00000000..b44ed42d --- /dev/null +++ b/cmd/grail-file/cmd/cat.go @@ -0,0 +1,23 @@ +package cmd + +import ( + "context" + "io" + + "github.com/grailbio/base/errors" + "github.com/grailbio/base/file" +) + +func Cat(ctx context.Context, out io.Writer, args []string) (err error) { + for _, arg := range expandGlobs(ctx, args) { + f, err := file.Open(ctx, arg) + if err != nil { + return errors.E(err, "cat", arg) + } + defer errors.CleanUpCtx(ctx, f.Close, &err) + if _, err = io.Copy(out, f.Reader(ctx)); err != nil { + return errors.E(err, "cat", arg) + } + } + return nil +} diff --git a/cmd/grail-file/cmd/cmd.go b/cmd/grail-file/cmd/cmd.go new file mode 100644 index 00000000..b5c55394 --- /dev/null +++ b/cmd/grail-file/cmd/cmd.go @@ -0,0 +1,170 @@ +package cmd + +import ( + "context" + "fmt" + "io" + "os" + "runtime" + "strings" + "sync" + + 
"github.com/gobwas/glob" + "github.com/gobwas/glob/syntax" + "github.com/gobwas/glob/syntax/ast" + "github.com/grailbio/base/cmd/grail-file/cmd/internal/semaphore" + "github.com/grailbio/base/errors" + "github.com/grailbio/base/file" +) + +var commands = []struct { + name string + callback func(ctx context.Context, out io.Writer, args []string) error + help string +}{ + {"cat", Cat, `Cat prints contents of the files to the stdout. It supports globs defined in https://github.com/gobwas/glob.`}, + {"put", Put, `Put stores stdin to the provided path.`}, + {"ls", Ls, `List files`}, + {"rm", Rm, `Rm removes files. It supports globs defined in https://github.com/gobwas/glob.`}, + {"cp", Cp, `Cp copies files. It can be invoked in three forms: + +1. cp src dst +2. cp src dst/ +3. cp src.... dstdir + +The first form first tries to copy file src to dst. If dst exists as a +directory, it copies src to dst/, where is the basename of the +source file. + +The second form copies file src to dst/. + +The third form copies each of "src" to destdir/. + +This command supports globs defined in https://github.com/gobwas/glob.`}, +} + +func PrintHelp() { + fmt.Fprintln(os.Stderr, "Subcommands:") + for _, c := range commands { + fmt.Fprintf(os.Stderr, "%s: %s\n", c.name, c.help) + } +} + +func Run(ctx context.Context, args []string) error { + + if len(args) == 0 { + PrintHelp() + return errors.E("No subcommand given") + } + for _, c := range commands { + if c.name == args[0] { + return c.callback(ctx, os.Stdout, args[1:]) + } + } + PrintHelp() + return errors.E("unknown command", args[0]) +} + +// parLimiter controls concurrency as well as total memory buffer capacity. +// grail-file is used on both small-ish laptops in office, etc. and large EC2 instances, so we +// choose to scale with number of CPUs. The exact numbers are somewhat arbitrary. +// A large-ish buffer size improves S3 throughput, at least in EC2. 
+var parLimiter = semaphore.New(32*runtime.NumCPU(), 1<<20) + +// forEachFile runs the callback for every file under the directory in +// parallel. It returns any of the errors returned by the callback. +func forEachFile(ctx context.Context, dir string, callback func(path string) error) error { + err := errors.Once{} + wg := sync.WaitGroup{} + ch := make(chan string, parLimiter.Cap()*100) + for i := 0; i < parLimiter.Cap(); i++ { + wg.Add(1) + go func() { + for path := range ch { + err.Set(callback(path)) + } + wg.Done() + }() + } + + lister := file.List(ctx, dir, true /*recursive*/) + for lister.Scan() { + if !lister.IsDir() { + ch <- lister.Path() + } + } + close(ch) + err.Set(lister.Err()) + wg.Wait() + return err.Err() +} + +// parseGlob parses a string that potentially contains glob metacharacters, and +// returns (nonglobprefix, hasglob). If the string does not contain any glob +// metacharacter, this function returns (str, false). Else, it returns the +// prefix of path elements up to the element containing a glob character. +// +// For example, parseGlob("foo/bar/baz*/*.txt" returns ("foo/bar", true). +func parseGlob(str string) (string, bool) { + node, err := syntax.Parse(str) + if err != nil { + panic(err) + } + if node.Kind != ast.KindPattern || len(node.Children) == 0 { + panic(node) + } + if node.Children[0].Kind != ast.KindText { + return "", true + } + if len(node.Children) == 1 { + return str, false + } + nonGlobPrefix := node.Children[0].Value.(ast.Text).Text + if i := strings.LastIndexByte(nonGlobPrefix, '/'); i > 0 { + nonGlobPrefix = nonGlobPrefix[:i+1] + } + return nonGlobPrefix, true +} + +// expandGlob expands the given glob string. If the string does not contain a +// glob metacharacter, or on any error, it returns {str}. 
+func expandGlob(ctx context.Context, str string) []string { + nonGlobPrefix, hasGlob := parseGlob(str) + if !hasGlob { + return []string{str} + } + m, err := glob.Compile(str) + if err != nil { + return []string{str} + } + + globSuffix := str[len(nonGlobPrefix):] + if strings.HasSuffix(globSuffix, "/") { + globSuffix = globSuffix[:len(globSuffix)-1] + } + recursive := len(strings.Split(globSuffix, "/")) > 1 || strings.Contains(globSuffix, "**") + + lister := file.List(ctx, nonGlobPrefix, recursive) + matches := []string{} + for lister.Scan() { + if m.Match(lister.Path()) { + matches = append(matches, lister.Path()) + } + } + if err := lister.Err(); err != nil { + return []string{str} + } + if len(matches) == 0 { + return []string{str} + } + return matches +} + +// expandGlobs calls expandGlob on each string and unions the results. +func expandGlobs(ctx context.Context, patterns []string) []string { + matches := []string{} + for _, pattern := range patterns { + matches = append(matches, expandGlob(ctx, pattern)...) + } + return matches +} diff --git a/cmd/grail-file/cmd/cp.go b/cmd/grail-file/cmd/cp.go new file mode 100644 index 00000000..a6151128 --- /dev/null +++ b/cmd/grail-file/cmd/cp.go @@ -0,0 +1,149 @@ +package cmd + +import ( + "context" + "flag" + "fmt" + "io" + "os" + + "strings" + + "github.com/grailbio/base/errors" + "github.com/grailbio/base/file" + "github.com/grailbio/base/file/s3file" + "github.com/grailbio/base/ioctx" + "github.com/grailbio/base/traverse" +) + +func Cp(ctx context.Context, out io.Writer, args []string) error { + var ( + flags flag.FlagSet + verboseFlag = flags.Bool("v", false, "Enable verbose logging") + recursiveFlag = flags.Bool("R", false, "Recursive copy") + ) + if err := flags.Parse(args); err != nil { + return err + } + args = flags.Args() + + // Copy a regular file. The first return value is true if the source exists as + // a regular file. 
+ copyRegularFile := func(src, dst string) (bool, error) { + if *verboseFlag { + fmt.Fprintf(os.Stderr, "%s -> %s\n", src, dst) // nolint: errcheck + } + in, err := file.Open(ctx, src) + if err != nil { + return false, err + } + defer in.Close(ctx) // nolint: errcheck + // If the file "src" doesn't exist, either Open or Stat should fail. + if _, err := in.Stat(ctx); err != nil { + return false, err + } + out, err := file.Create(ctx, dst) + if err != nil { + return true, errors.E(err, fmt.Sprintf("cp %v->%v", src, dst)) + } + if err = copyFile(ctx, out, in); err != nil { + _ = out.Close(ctx) + return true, errors.E(err, fmt.Sprintf("cp %v->%v", src, dst)) + } + err = out.Close(ctx) + if err != nil { + err = errors.E(err, fmt.Sprintf("cp %v->%v", src, dst)) + } + return true, err + } + + // Copy a regular file or a directory. + copyFile := func(src, dst string) error { + if srcExists, err := copyRegularFile(src, dst); srcExists || !*recursiveFlag { + return err + } + return forEachFile(ctx, src, func(path string) error { + suffix := path[len(src):] + for len(suffix) > 0 && suffix[0] == '/' { + suffix = suffix[1:] + } + _, e := copyRegularFile(file.Join(src, suffix), file.Join(dst, suffix)) + return e + }) + } + + copyFileInDir := func(src, dstDir string) error { + return copyFile(src, file.Join(dstDir, file.Base(src))) + } + + nArg := len(args) + if nArg < 2 { + return errors.New("Usage: cp src... dst") + } + dst := args[nArg-1] + if _, hasGlob := parseGlob(dst); hasGlob { + return fmt.Errorf("cp: destination %s cannot be a glob", dst) + } + srcs := expandGlobs(ctx, args[:nArg-1]) + if len(srcs) == 1 { + // Try copying to dst. Failing that, copy to dst/. 
+ if !strings.HasSuffix(dst, "/") && copyFile(srcs[0], dst) == nil { + return nil + } + return copyFileInDir(srcs[0], dst) + } + return traverse.Each(len(srcs), func(i int) error { + return copyFileInDir(srcs[i], dst) + }) +} + +var copyFileChunkSize = int64(s3file.ReadChunkBytes()) + +// TODO: Move copyFile to a common location. +func copyFile(ctx context.Context, dst file.File, src file.File) error { + // TODO: Use dst.WriterAt(), after it's introduced. + dstAt, dstOK := dst.(ioctx.WriterAt) + if !dstOK { + return copyStream(ctx, dst.Writer(ctx), src.Reader(ctx)) + } + info, err := src.Stat(ctx) + if err != nil { + return err + } + size := info.Size() + nChunks := int((size + copyFileChunkSize - 1) / copyFileChunkSize) + return traverse.Each(nChunks, func(chunkIdx int) (err error) { + offset := int64(chunkIdx) * copyFileChunkSize + wantN := size - offset + if wantN > copyFileChunkSize { + wantN = copyFileChunkSize + } + srcR := src.OffsetReader(offset) + defer errors.CleanUpCtx(ctx, srcR.Close, &err) + return copyStream(ctx, + ioctx.ToStdWriter(ctx, &offsetWriter{at: dstAt, offset: offset}), + io.LimitReader(ioctx.ToStdReader(ctx, srcR), wantN), + ) + }) +} + +func copyStream(ctx context.Context, dst io.Writer, src io.Reader) error { + item, err := parLimiter.Acquire(ctx) + if err != nil { + return err + } + defer item.Release() + _, err = io.CopyBuffer(dst, src, item.Buf()) + return err +} + +type offsetWriter struct { + at ioctx.WriterAt + offset int64 +} + +func (w *offsetWriter) Write(ctx context.Context, p []byte) (int, error) { + n, err := w.at.WriteAt(ctx, p, w.offset) + w.offset += int64(n) + return n, err +} diff --git a/cmd/grail-file/cmd/glob_test.go b/cmd/grail-file/cmd/glob_test.go new file mode 100644 index 00000000..4a60ab92 --- /dev/null +++ b/cmd/grail-file/cmd/glob_test.go @@ -0,0 +1,56 @@ +package cmd + +import ( + "context" + "strings" + "testing" + + "github.com/grailbio/base/file" + "github.com/grailbio/testutil" + 
"github.com/grailbio/testutil/assert" +) + +func TestParseGlob(t *testing.T) { + doParse := func(str string) string { + prefix, hasGlob := parseGlob(str) + if !hasGlob { + return "none" + } + return prefix + } + assert.EQ(t, "none", doParse("s3://a/b/c")) + assert.EQ(t, "none", doParse("s3://a/b\\*/c")) + assert.EQ(t, "s3://a/", doParse("s3://a/b*/c")) + assert.EQ(t, "s3://a/b/", doParse("s3://a/b/*")) + assert.EQ(t, "s3://a/", doParse("s3://a/b?")) + assert.EQ(t, "s3://a/", doParse("s3://a/**/b")) + assert.EQ(t, "", doParse("**")) +} + +func TestExpandGlob(t *testing.T) { + ctx := context.Background() + tmpDir, cleanup := testutil.TempDir(t, "", "") + defer cleanup() + src0Path := file.Join(tmpDir, "abc/def/tmp0") + src1Path := file.Join(tmpDir, "abd/efg/hij/tmp1") + src2Path := file.Join(tmpDir, "tmp0") + assert.NoError(t, file.WriteFile(ctx, src0Path, []byte("a"))) + assert.NoError(t, file.WriteFile(ctx, src1Path, []byte("b"))) + assert.NoError(t, file.WriteFile(ctx, src2Path, []byte("c"))) + + doExpand := func(str string) string { + matches := expandGlob(ctx, tmpDir+"/"+str) + for i := range matches { + matches[i] = matches[i][len(tmpDir)+1:] // remove the tmpDir part. + } + return strings.Join(matches, ",") + } + + assert.EQ(t, "abc/def/tmp0", doExpand("abc/*/tmp0")) + assert.EQ(t, "xxx/yyy", doExpand("xxx/yyy")) + assert.EQ(t, "xxx/*", doExpand("xxx/*")) + assert.EQ(t, "abc/def/tmp0", doExpand("a*/*/tmp0")) + assert.EQ(t, "abd/efg/hij/tmp1", doExpand("abd/**/tmp*")) + assert.EQ(t, "abc/def/tmp0,abd/efg/hij/tmp1", doExpand("a*/**/tmp*")) + assert.EQ(t, "abc/def/tmp0,abd/efg/hij/tmp1,tmp0", doExpand("**")) +} diff --git a/cmd/grail-file/cmd/internal/semaphore/semaphore.go b/cmd/grail-file/cmd/internal/semaphore/semaphore.go new file mode 100644 index 00000000..0a7664f3 --- /dev/null +++ b/cmd/grail-file/cmd/internal/semaphore/semaphore.go @@ -0,0 +1,80 @@ +package semaphore + +import ( + "context" +) + +type ( + // S is a concurrency limiter. 
Holders of a resource unit (represented by + // *Item) also get temporary ownership of a memory buffer. + // TODO: Consider implementing some ordering or prioritization. For example, while copying + // many files in multiple chunks each, maybe we prefer to use available concurrency to complete + // all the chunks of one file rather than a couple of chunks of many, so that if there's an + // error we have made some useful progress. + S struct{ c chan *Item } + Item struct { + // bufSize describes the eventual size of buf, since it's allocated lazily. + bufSize int + // buf is scratch space that the (temporary) owner of this item can use while holding + // this item. That is, they must not use it after release. + buf []byte + // releaseTo is the collection that Item will be returned to, when the user is done. + // It's set only when the *Item is in use, for best-effort detection of incorrect API usage. + releaseTo S + } +) + +// New constructs S. Total possible buffer allocation is capacity*bufSize, but it's allocated lazily +// per-Item. +func New(capacity int, bufSize int) S { + s := S{make(chan *Item, capacity)} + for i := 0; i < capacity; i++ { + s.c <- &Item{bufSize: bufSize} + } + return s +} + +// Acquire waits for semaphore capacity, respecting context cancellation. +// Returns exactly one of item or error. If non-nil, Item must be released. +// +// Intended usage: +// item, err := sema.Acquire(ctx) +// if err != nil { +// return /* ..., */ err +// ) +// defer item.Release() +// doMyThing(item.Buf()) +func (s S) Acquire(ctx context.Context) (*Item, error) { + select { + case item := <-s.c: + item.releaseTo = s + return item, nil + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +// Cap returns the capacity (that was passed to New). +func (s S) Cap() int { return cap(s.c) } + +// Buf borrows this item's buffer. Caller must not use it after releasing the item. 
+func (i *Item) Buf() []byte { + if len(i.buf) < i.bufSize { + i.buf = make([]byte, i.bufSize) + } + return i.buf +} + +// Release returns capacity to the semaphore. It must be called exactly once after acquisition. +func (i *Item) Release() { + if i == nil { + panic("usage error: Release after failed Acquire") + } + if i.releaseTo.c == nil { + panic("usage error: multiple Release") + } + releaseTo := i.releaseTo + i.releaseTo = S{} + // Note: We must modify i before sending it to the channel, to avoid a race. + releaseTo.c <- i +} diff --git a/cmd/grail-file/cmd/ls.go b/cmd/grail-file/cmd/ls.go new file mode 100644 index 00000000..543b6e55 --- /dev/null +++ b/cmd/grail-file/cmd/ls.go @@ -0,0 +1,70 @@ +package cmd + +import ( + "context" + "flag" + "fmt" + "io" + + "github.com/grailbio/base/file" +) + +func Ls(ctx context.Context, out io.Writer, args []string) error { + var ( + flags flag.FlagSet + longOutputFlag = flags.Bool("l", false, "Print file size and last modification time") + recursiveFlag = flags.Bool("R", false, "Descend into directories recursively") + ) + if err := flags.Parse(args); err != nil { + return err + } + type result struct { + err error + lines chan string // stream of entries found for an arg, closed when done + } + longOutput := func(path string, info file.Info) string { + // TODO(saito) prettyprint + const iso8601 = "2006-01-02T15:04:05-0700" + return fmt.Sprintf("%s\t%d\t%s", path, info.Size(), info.ModTime().Format(iso8601)) + } + args = expandGlobs(ctx, flags.Args()) + results := make([]result, len(args)) + for i := range args { + results[i].lines = make(chan string, 10000) + go func(path string, r *result) { + defer close(r.lines) + // Check if the file is a regular file + if info, err := file.Stat(ctx, path); err == nil { + if *longOutputFlag { + r.lines <- longOutput(path, info) + } else { + r.lines <- path + } + return + } + lister := file.List(ctx, path, *recursiveFlag) + for lister.Scan() { + switch { + case lister.IsDir(): + 
r.lines <- lister.Path() + "/" + case *longOutputFlag: + r.lines <- longOutput(lister.Path(), lister.Info()) + default: + r.lines <- lister.Path() + } + } + r.err = lister.Err() + }(args[i], &results[i]) + } + // Print the results in order. + var err error + for i := range results { + for line := range results[i].lines { + _, _ = fmt.Fprintln(out, line) + } + if err2 := results[i].err; err2 != nil && err == nil { + err = err2 + } + } + return err +} diff --git a/cmd/grail-file/cmd/put.go b/cmd/grail-file/cmd/put.go new file mode 100644 index 00000000..5466a576 --- /dev/null +++ b/cmd/grail-file/cmd/put.go @@ -0,0 +1,26 @@ +package cmd + +import ( + "context" + "io" + "os" + + "github.com/grailbio/base/errors" + "github.com/grailbio/base/file" +) + +func Put(ctx context.Context, out io.Writer, args []string) (err error) { + if len(args) != 1 { + return errors.New("put requires a single path") + } + arg := args[0] + f, err := file.Create(ctx, arg) + if err != nil { + return errors.E(err, "put", arg) + } + defer errors.CleanUpCtx(ctx, f.Close, &err) + if _, err = io.Copy(f.Writer(ctx), os.Stdin); err != nil { + return errors.E(err, "put", arg) + } + return nil +} diff --git a/cmd/grail-file/cmd/rm.go b/cmd/grail-file/cmd/rm.go new file mode 100644 index 00000000..6de479a4 --- /dev/null +++ b/cmd/grail-file/cmd/rm.go @@ -0,0 +1,39 @@ +package cmd + +import ( + "context" + "flag" + "fmt" + "io" + "os" + + "github.com/grailbio/base/file" + "github.com/grailbio/base/traverse" +) + +func Rm(ctx context.Context, out io.Writer, args []string) error { + var ( + flags flag.FlagSet + verboseFlag = flags.Bool("v", false, "Enable verbose logging") + recursiveFlag = flags.Bool("R", false, "Recursive remove") + ) + if err := flags.Parse(args); err != nil { + return err + } + args = expandGlobs(ctx, flags.Args()) + return traverse.Each(len(args), func(i int) error { + path := args[i] + if *verboseFlag { + fmt.Fprintf(os.Stderr, "%s\n", path) // nolint: errcheck + } + if 
*recursiveFlag { + return forEachFile(ctx, path, func(path string) error { + if *verboseFlag { + fmt.Fprintf(os.Stderr, "%s\n", path) // nolint: errcheck + } + return file.Remove(ctx, path) + }) + } + return file.Remove(ctx, path) + }) +} diff --git a/cmd/grail-file/main.go b/cmd/grail-file/main.go index 2e36603a..0e2d475f 100644 --- a/cmd/grail-file/main.go +++ b/cmd/grail-file/main.go @@ -2,281 +2,29 @@ package main import ( "context" - "fmt" - "io" + "flag" "os" - "strings" - "github.com/gobwas/glob" - "github.com/gobwas/glob/syntax" - "github.com/gobwas/glob/syntax/ast" - "github.com/grailbio/base/cmdutil" + "github.com/grailbio/base/cmd/grail-file/cmd" "github.com/grailbio/base/file" + "github.com/grailbio/base/file/s3file" "github.com/grailbio/base/log" - "github.com/grailbio/base/traverse" - "github.com/grailbio/base/vcontext" - "github.com/pkg/errors" - "v.io/x/lib/cmdline" ) -// parseGlob parses a string that potentially contains glob metacharacters, and -// returns (nonglobprefix, hasglob). If the string does not contain any glob -// metacharacter, this function returns (str, false). Else, it returns the -// prefix of path elements up to the element containing a glob character. -// -// For example, parseGlob("foo/bar/baz*/*.txt" returns ("foo/bar", true). -func parseGlob(str string) (string, bool) { - node, err := syntax.Parse(str) - if err != nil { - panic(err) - } - if node.Kind != ast.KindPattern || len(node.Children) == 0 { - panic(node) - } - if node.Children[0].Kind != ast.KindText { - return "", true - } - if len(node.Children) == 1 { - return str, false - } - nonGlobPrefix := node.Children[0].Value.(ast.Text).Text - if i := strings.LastIndexByte(nonGlobPrefix, '/'); i > 0 { - nonGlobPrefix = nonGlobPrefix[:i+1] - } - return nonGlobPrefix, true -} - -// expandGlob expands the given glob string. If the string does not contain a -// glob metacharacter, or on any error, it returns {str}. 
-func expandGlob(ctx context.Context, str string) []string { - nonGlobPrefix, hasGlob := parseGlob(str) - if !hasGlob { - return []string{str} - } - m, err := glob.Compile(str) - if err != nil { - return []string{str} - } - - globSuffix := str[len(nonGlobPrefix):] - if strings.HasSuffix(globSuffix, "/") { - globSuffix = globSuffix[:len(globSuffix)-1] - } - recursive := len(strings.Split(globSuffix, "/")) > 1 || strings.Index(globSuffix, "**") >= 0 - - lister := file.List(ctx, nonGlobPrefix, recursive) - matches := []string{} - for lister.Scan() { - if m.Match(lister.Path()) { - matches = append(matches, lister.Path()) - } - } - if err := lister.Err(); err != nil { - return []string{str} - } - if len(matches) == 0 { - return []string{str} - } - return matches -} - -// expandGlobs calls expandGlob on each string and unions the results. -func expandGlobs(ctx context.Context, patterns []string) []string { - matches := []string{} - for _, pattern := range patterns { - matches = append(matches, expandGlob(ctx, pattern)...) - } - return matches -} - -func runCat(_ *cmdline.Env, args []string) error { - ctx := vcontext.Background() - for _, arg := range expandGlobs(ctx, args) { - f, err := file.Open(ctx, arg) - if err != nil { - return errors.Wrapf(err, "cat %v", arg) - } - defer f.Close(ctx) // nolint: errcheck - if _, err = io.Copy(os.Stdout, f.Reader(ctx)); err != nil { - return errors.Wrapf(err, "cat %v (io.Copy)", arg) - } - } - return nil -} - -func newCatCmd() *cmdline.Command { - return &cmdline.Command{ - Runner: cmdutil.RunnerFunc(runCat), - Name: "cat", - Short: "Print files to stdout", - ArgsName: "files...", - Long: ` -This command prints contents of the files to the stdout. 
It supports globs defined in https://github.com/gobwas/glob.`, - } -} - -const parallelism = 16 - -type cprmOpts struct { - verbose bool -} - -func runRm(args []string, opts cprmOpts) error { - ctx := vcontext.Background() - args = expandGlobs(ctx, args) - return traverse.Each(len(args)).Limit(parallelism).Do(func(i int) error { - if opts.verbose { - log.Printf("%s", args[i]) - } - return file.Remove(ctx, args[i]) - }) -} - -func newRmCmd() *cmdline.Command { - opts := cprmOpts{} - c := &cmdline.Command{ - Runner: cmdutil.RunnerFunc(func(_ *cmdline.Env, args []string) error { return runRm(args, opts) }), - Name: "rm", - Short: "Remove files", - ArgsName: "files...", - Long: ` -This command removes files. It supports globs defined in https://github.com/gobwas/glob.`, - } - c.Flags.BoolVar(&opts.verbose, "v", false, "Enable verbose logging") - return c -} - -func runCp(args []string, opts cprmOpts) error { - ctx := vcontext.Background() - copyFile := func(src, dst string) error { - if opts.verbose { - log.Printf("%s -> %s", src, dst) - } - in, err := file.Open(ctx, src) - if err != nil { - return err - } - defer in.Close(ctx) // nolint: errcheck - out, err := file.Create(ctx, dst) - if err != nil { - return errors.Wrapf(err, "cp %v->%v", src, dst) - } - if _, err := io.Copy(out.Writer(ctx), in.Reader(ctx)); err != nil { - _ = out.Close(ctx) - return errors.Wrapf(err, "cp %v->%v", src, dst) - } - err = out.Close(ctx) - if err != nil { - err = errors.Wrapf(err, "cp %v->%v", src, dst) - } - return err - } - - copyFileInDir := func(src, dstDir string) error { - return copyFile(src, file.Join(dstDir, file.Base(src))) - } - nArg := len(args) - if nArg < 2 { - return errors.New("Usage: cp src... dst") - } - dst := args[nArg-1] - if _, hasGlob := parseGlob(dst); hasGlob { - return fmt.Errorf("cp: destination %s cannot be a glob", dst) - } - srcs := expandGlobs(ctx, args[:nArg-1]) - if len(srcs) == 1 { - // Try copying to dst. Failing that, copy to dst/. 
- if !strings.HasSuffix(dst, "/") && copyFile(srcs[0], dst) == nil { - return nil - } - return copyFileInDir(srcs[0], dst) - } - return traverse.Each(len(srcs)).Limit(parallelism).Do(func(i int) error { - return copyFileInDir(srcs[i], dst) - }) -} - -func newCpCmd() *cmdline.Command { - opts := cprmOpts{} - c := &cmdline.Command{ - Runner: cmdutil.RunnerFunc(func(_ *cmdline.Env, args []string) error { return runCp(args, opts) }), - Name: "cp", - Short: "Copy files", - ArgsName: "srcfiles... dstfile-or-dir", - Long: ` -This command copies files. It can be invoked in two forms: - -1. cp src dst -2. cp src dst/ -3. cp src.... dstdir - -The first form first tries to copy file src to dst. If dst exists as a -directory, it copies src to dst/, where is the basename of the -source file. The second form copies file src to dst/. - -The third form copies each of "src" to destdir/. - -This command supports globs defined in https://github.com/gobwas/glob. `, - } - c.Flags.BoolVar(&opts.verbose, "v", false, "Enable verbose logging") - return c -} - -type lsOpts struct { - recursive bool - longOutput bool -} - -func runLs(out io.Writer, args []string, opts lsOpts) error { - const iso8601 = "2006-01-02T15:04:05-0700" - ctx := vcontext.Background() - args = expandGlobs(ctx, args) - for _, arg := range args { - lister := file.List(ctx, arg, opts.recursive) - for lister.Scan() { - // TODO(saito) prettyprint - switch { - case lister.IsDir(): - fmt.Fprintf(out, "%s/\n", lister.Path()) - case opts.longOutput: - info := lister.Info() - fmt.Fprintf(out, "%s\t%d\t%s\n", lister.Path(), info.Size(), info.ModTime().Format(iso8601)) - default: - fmt.Fprintf(out, "%s\n", lister.Path()) - } - } - if err := lister.Err(); err != nil { - return err - } - } - return nil -} - -func newLsCmd() *cmdline.Command { - opts := lsOpts{} - c := &cmdline.Command{ - Runner: cmdutil.RunnerFunc(func(_ *cmdline.Env, args []string) error { return runLs(os.Stdout, args, opts) }), - Name: "ls", - Short: "List 
files", - ArgsName: "prefix...", +func main() { + help := flag.Bool("help", false, "Display help about this command") + flag.Parse() + if *help { + cmd.PrintHelp() + os.Exit(0) } - c.Flags.BoolVar(&opts.longOutput, "l", false, "Print file size and last modification time") - c.Flags.BoolVar(&opts.recursive, "R", false, "Descend into directories recursively") - return c -} -func main() { log.SetFlags(log.Ldate | log.Ltime | log.Lmicroseconds | log.Lshortfile) - cmdline.HideGlobalFlagsExcept() - cmd := &cmdline.Command{ - Name: "grail-file", - Short: "Access files using grailfile", - Children: []*cmdline.Command{ - newCatCmd(), - newCpCmd(), - newLsCmd(), - newRmCmd(), - }, + file.RegisterImplementation("s3", func() file.Implementation { + return s3file.NewImplementation(s3file.NewDefaultProvider(), s3file.Options{}) + }) + err := cmd.Run(context.Background(), os.Args[1:]) + if err != nil { + log.Fatal(err) } - cmdline.Main(cmd) } diff --git a/cmd/grail-file/main_test.go b/cmd/grail-file/main_test.go index 70299a98..b9dcf2e6 100644 --- a/cmd/grail-file/main_test.go +++ b/cmd/grail-file/main_test.go @@ -8,6 +8,7 @@ import ( "strings" "testing" + "github.com/grailbio/base/cmd/grail-file/cmd" "github.com/grailbio/base/file" "github.com/grailbio/testutil" "github.com/stretchr/testify/assert" @@ -22,9 +23,9 @@ func readFile(path string) string { } func TestLs(t *testing.T) { - doLs := func(dir string, opts lsOpts) []string { + doLs := func(args ...string) []string { out := bytes.Buffer{} - assert.NoError(t, runLs(&out, []string{dir}, opts)) + assert.NoError(t, cmd.Ls(context.Background(), &out, args)) s := strings.Split(strings.TrimSpace(out.String()), "\n") sort.Strings(s) return s @@ -40,17 +41,17 @@ func TestLs(t *testing.T) { assert.NoError(t, file.WriteFile(ctx, path1, []byte("1"))) assert.Equal(t, []string{tmpDir + "/0.txt", tmpDir + "/d/"}, - doLs(tmpDir, lsOpts{})) + doLs(tmpDir)) assert.Equal(t, []string{tmpDir + "/0.txt", tmpDir + "/d/1.txt"}, - doLs(tmpDir, 
lsOpts{recursive: true})) + doLs("-R", tmpDir)) - s := doLs(tmpDir, lsOpts{longOutput: true}) + s := doLs("-l", tmpDir) assert.Equal(t, 2, len(s)) assert.Regexp(t, tmpDir+"/0.txt\t1\t20.*", s[0]) assert.Equal(t, tmpDir+"/d/", s[1]) - s = doLs(tmpDir, lsOpts{longOutput: true, recursive: true}) + s = doLs("-l", "-R", tmpDir) assert.Equal(t, 2, len(s)) assert.Regexp(t, tmpDir+"/0.txt\t1\t20.*", s[0]) assert.Regexp(t, tmpDir+"/d/1.txt\t1\t20.*", s[1]) @@ -70,26 +71,48 @@ func TestCp(t *testing.T) { // "cp xxx yyy", where yyy doesn't exist. dstPath := file.Join(tmpDir, "d0.txt") - assert.NoError(t, runCp([]string{src0Path, dstPath}, cprmOpts{})) + assert.NoError(t, cmd.Cp(ctx, os.Stdout, []string{src0Path, dstPath})) assert.Equal(t, expected0, readFile(dstPath)) // "cp x0 x1 yyy", where yyy doesn't exist dstPath = file.Join(tmpDir, "d1") - assert.NoError(t, runCp([]string{src0Path, src1Path, dstPath}, cprmOpts{})) + assert.NoError(t, cmd.Cp(ctx, os.Stdout, []string{src0Path, src1Path, dstPath})) assert.Equal(t, expected0, readFile(file.Join(dstPath, "tmp0.txt"))) assert.Equal(t, expected1, readFile(file.Join(dstPath, "tmp1.txt"))) // Try "cp xxx yyy/", where yyy doesn't exist. Cp should create file yyy/xxx. 
dstDir := file.Join(tmpDir, "testdir0") - assert.NoError(t, runCp([]string{src0Path, dstDir + "/"}, cprmOpts{})) + assert.NoError(t, cmd.Cp(ctx, os.Stdout, []string{src0Path, dstDir + "/"})) assert.Equal(t, expected0, readFile(file.Join(dstDir, "tmp0.txt"))) dstDir = tmpDir + "/d2" assert.NoError(t, os.Mkdir(dstDir, 0700)) - assert.NoError(t, runCp([]string{src0Path, dstDir}, cprmOpts{})) + assert.NoError(t, cmd.Cp(ctx, os.Stdout, []string{src0Path, dstDir})) assert.Equal(t, expected0, readFile(file.Join(dstDir, "tmp0.txt"))) } +func TestCpRecursive(t *testing.T) { + ctx := context.Background() + tmpDir, cleanup := testutil.TempDir(t, "", "") + defer cleanup() + + srcDir := file.Join(tmpDir, "dir") + path0 := "/dir/tmp0.txt" + path1 := "/dir/dir2/tmp1.txt" + path2 := "/dir/dir2/dir3/tmp2.txt" + expected0 := "tmp0" + expected1 := "tmp1" + expected2 := "tmp2" + assert.NoError(t, file.WriteFile(ctx, srcDir+path0, []byte(expected0))) + assert.NoError(t, file.WriteFile(ctx, srcDir+path1, []byte(expected1))) + assert.NoError(t, file.WriteFile(ctx, srcDir+path2, []byte(expected2))) + dstDir := file.Join(tmpDir, "dir1") + assert.NoError(t, cmd.Cp(ctx, os.Stdout, []string{"-R", srcDir, dstDir})) + assert.Equal(t, expected0, readFile(dstDir+path0)) + assert.Equal(t, expected1, readFile(dstDir+path1)) + assert.Equal(t, expected2, readFile(dstDir+path2)) +} + func TestRm(t *testing.T) { ctx := context.Background() tmpDir, cleanup := testutil.TempDir(t, "", "") @@ -105,58 +128,30 @@ func TestRm(t *testing.T) { assert.Equal(t, "1", readFile(src1Path)) assert.Equal(t, "2", readFile(src2Path)) - assert.NoError(t, runRm([]string{src0Path, src1Path}, cprmOpts{})) + assert.NoError(t, cmd.Rm(ctx, os.Stdout, []string{src0Path, src1Path})) assert.Regexp(t, "no such file", readFile(src0Path)) assert.Regexp(t, "no such file", readFile(src1Path)) assert.Equal(t, "2", readFile(src2Path)) - assert.NoError(t, runRm([]string{src2Path}, cprmOpts{})) + assert.NoError(t, cmd.Rm(ctx, os.Stdout, 
[]string{src2Path})) assert.Regexp(t, "no such file", readFile(src0Path)) assert.Regexp(t, "no such file", readFile(src1Path)) assert.Regexp(t, "no such file", readFile(src2Path)) } -func TestParseGlob(t *testing.T) { - doParse := func(str string) string { - prefix, hasGlob := parseGlob(str) - if !hasGlob { - return "none" - } - return prefix - } - assert.Equal(t, "none", doParse("s3://a/b/c")) - assert.Equal(t, "none", doParse("s3://a/b\\*/c")) - assert.Equal(t, "s3://a/", doParse("s3://a/b*/c")) - assert.Equal(t, "s3://a/b/", doParse("s3://a/b/*")) - assert.Equal(t, "s3://a/", doParse("s3://a/b?")) - assert.Equal(t, "s3://a/", doParse("s3://a/**/b")) - assert.Equal(t, "", doParse("**")) -} - -func TestExpandGlob(t *testing.T) { +func TestRmRecursive(t *testing.T) { ctx := context.Background() tmpDir, cleanup := testutil.TempDir(t, "", "") defer cleanup() - src0Path := file.Join(tmpDir, "abc/def/tmp0") - src1Path := file.Join(tmpDir, "abd/efg/hij/tmp1") - src2Path := file.Join(tmpDir, "tmp0") - assert.NoError(t, file.WriteFile(ctx, src0Path, []byte("a"))) - assert.NoError(t, file.WriteFile(ctx, src1Path, []byte("b"))) - assert.NoError(t, file.WriteFile(ctx, src2Path, []byte("c"))) - - doExpand := func(str string) string { - matches := expandGlob(ctx, tmpDir+"/"+str) - for i := range matches { - matches[i] = matches[i][len(tmpDir)+1:] // remove the tmpDir part. 
- } - return strings.Join(matches, ",") - } + src0Path := file.Join(tmpDir, "dir/tmp0.txt") + src1Path := file.Join(tmpDir, "dir/dir2/tmp1.txt") + src2Path := file.Join(tmpDir, "dir/dir2/dir3/tmp2.txt") + assert.NoError(t, file.WriteFile(ctx, src0Path, []byte("0"))) + assert.NoError(t, file.WriteFile(ctx, src1Path, []byte("1"))) + assert.NoError(t, file.WriteFile(ctx, src2Path, []byte("2"))) - assert.Equal(t, "abc/def/tmp0", doExpand("abc/*/tmp0")) - assert.Equal(t, "xxx/yyy", doExpand("xxx/yyy")) - assert.Equal(t, "xxx/*", doExpand("xxx/*")) - assert.Equal(t, "abc/def/tmp0", doExpand("a*/*/tmp0")) - assert.Equal(t, "abd/efg/hij/tmp1", doExpand("abd/**/tmp*")) - assert.Equal(t, "abc/def/tmp0,abd/efg/hij/tmp1", doExpand("a*/**/tmp*")) - assert.Equal(t, "abc/def/tmp0,abd/efg/hij/tmp1,tmp0", doExpand("**")) + assert.NoError(t, cmd.Rm(ctx, os.Stdout, []string{"-R", file.Join(tmpDir, "dir/dir2")})) + assert.Equal(t, "0", readFile(src0Path)) + assert.Regexp(t, "no such file", readFile(src1Path)) + assert.Regexp(t, "no such file", readFile(src2Path)) } diff --git a/cmd/grail-fuse/README.md b/cmd/grail-fuse/README.md new file mode 100644 index 00000000..ce3d8831 --- /dev/null +++ b/cmd/grail-fuse/README.md @@ -0,0 +1,45 @@ +# Grail-fuse + +Grail-fuse allows reading and writing S3 files as if they are on the local file +system. It supports Linux and OSX. No extra step is needed for Linux. For OSX, +install osxfuse from https://github.com/osxfuse/osxfuse/releases + +## Usage + + grail-fuse [-remote-root-dir DIR] [-daemon] [-log-dir LOGDIR] MOUNTPOINT + +Example: + + grail-fuse -daemon $HOME/s3 + +Flag `-mount-root-dir` defaults to `s3://`. Thus, after mounting grail-fuse on +`~/s3`, Toplevel directories under `~/s3` will list S3 buckets owned by the +default AWS account. You can descend into buckets to read files underneath. + +If `-daemon` is set, `grail-fuse` runs as a daemon. 
+ +## Unmounting the file system + +To unmount `grail-fuse`, run: + + fusermount -u $HOME/s3 + +on Linux, or + + umount $HOME/s3 + +on OSX. + +## Bugs and limitations + +- `mkdir` is not supported. `rmdir` is supported, but is is just a noop. You can + just create a file under a nonexisting subdirectory without `mkdir`ing it first. + +- grail-fuse caches file attributes in memory for up to 5 minutes. Thus, if the + remote file system is updated by another user, it will not be reflected for up + to five minutes. You can send a SIGHUP to `grail-fuse` to invalidate the + cache. + +- File contents are not written back to the remote file system until `close`. + +- In general, random seeks are supported during writes. diff --git a/cmd/grail-fuse/gfs/gfs.go b/cmd/grail-fuse/gfs/gfs.go new file mode 100644 index 00000000..34726410 --- /dev/null +++ b/cmd/grail-fuse/gfs/gfs.go @@ -0,0 +1,1005 @@ +// Package gfs implements FUSE on top oh grailfile. Function Main is the entry +// point. +package gfs + +import ( + "context" + "crypto/sha512" + "encoding/binary" + "fmt" + "io" + "os" + "runtime/debug" + "sync" + "sync/atomic" + "syscall" + "time" + "unsafe" + + "github.com/grailbio/base/errors" + "github.com/grailbio/base/file" + "github.com/grailbio/base/log" + gunsafe "github.com/grailbio/base/unsafe" + "github.com/hanwen/go-fuse/v2/fs" + "github.com/hanwen/go-fuse/v2/fuse" +) + +// Inode represents a file or a directory. +type inode struct { + fs.Inode + // full pathname, such as "s3://bucket/key0/key1" + path string + // dir entry as stored in the parent directory. + ent fuse.DirEntry + + mu sync.Mutex // guards the following fields. + stat cachedStat // TODO: Remove this since we're now using kernel caching. + + // nDirStreamRef tracks the usage of this inode in DirStreams. It is used + // to decide whether an inode can be reused to service LOOKUP + // operations. To handle READDIRPLUS, go-fuse interleaves LOOKUP calls for + // each directory entry. 
We allow the inode associated with the previous + // directory entry to be used in LOOKUP to avoid costly API calls. + // + // Because an inode can be the previous entry in multiple DirStreams, we + // maintain a reference count. + // + // It is possible for the inode to be forgotten, e.g. when the kernel is + // low on memory, before the LOOKUP call. If this happens, LOOKUP will not + // be able to reuse it. This seems to happen rarely, if at all, in + // practice. + nDirStreamRef int32 +} + +// Amount of time to cache directory entries and file stats (size, mtime). +const cacheExpiration = 5 * time.Minute + +// RootInode is a singleton inode created for the root mount point. +type rootInode struct { + inode + // The context to be used for all file operations. It's vcontext.Background() + // in Grail environments. + // TODO(josh): Consider removing and using operation-specific contexts instead (like readdir). + ctx context.Context + // Directory for storing tmp files. + tmpDir string +} + +// Handle represents an open file handle. +type handle struct { + // The file that the handle belongs to + inode *inode + // Open mode bits. O_WRONLY, etc. + openMode uint32 + // Size passed to Setattr, if any. -1 if not set. + requestedSize int64 + // Remembers the result of the first Flush. If Flush is called multiple times + // they will return this code. + closeErrno syscall.Errno + + // At most one of the following three will be set. Initialized lazily on + // first Read or Write. + dw *directWrite // O_WRONLY|O_TRUNC, or O_WRONLY for a new file. + dr *directRead // O_RDONLY. + tmp *tmpIO // everything else, e.g., O_RDWR or O_APPEND. +} + +// openMode is a bitmap of O_RDONLY, O_APPEND, etc. +func newHandle(inode *inode, openMode uint32) *handle { + return &handle{inode: inode, openMode: openMode, requestedSize: -1} +} + +// DirectWrite is part of open file handle. It uploads data directly to the remote +// file. 
Used when creating a new file, or overwriting an existing file with +// O_WRONLY|O_TRUNC. +type directWrite struct { + fp file.File + w io.Writer + // The next expected write offset. Calling Write on a wrong offset results in + // error (w doesn't implement a seeker). + off int64 +} + +// DirectRead is part of open file handle. It is used when reading a file +// readonly. +type directRead struct { + fp file.File + r io.ReadSeeker +} + +// TmpIO is part of open file handle. It writes data to a file in the local file +// system. On Flush (i.e., close), the file contents are copied to the remote +// file. It is used w/ O_RDWR, O_APPEND, etc. +type tmpIO struct { + fp *os.File // refers to a file in -tmp-dir. +} + +// CachedStat is stored in inode and a directory entry to provide quick access +// to basic stats. +type cachedStat struct { + expiration time.Time + size int64 + modTime time.Time +} + +func downCast(n *fs.Inode) *inode { + nn := (*inode)(unsafe.Pointer(n)) + if nn.path == "" { + log.Panicf("not an inode: %+v", n) + } + return nn +} + +var ( + _ fs.InodeEmbedder = (*inode)(nil) + + _ fs.NodeAccesser = (*inode)(nil) + _ fs.NodeCreater = (*inode)(nil) + _ fs.NodeGetattrer = (*inode)(nil) + _ fs.NodeLookuper = (*inode)(nil) + _ fs.NodeMkdirer = (*inode)(nil) + _ fs.NodeOpener = (*inode)(nil) + _ fs.NodeReaddirer = (*inode)(nil) + _ fs.NodeRmdirer = (*inode)(nil) + _ fs.NodeSetattrer = (*inode)(nil) + _ fs.NodeUnlinker = (*inode)(nil) + + _ fs.FileFlusher = (*handle)(nil) + _ fs.FileFsyncer = (*handle)(nil) + _ fs.FileLseeker = (*handle)(nil) + _ fs.FileReader = (*handle)(nil) + _ fs.FileReleaser = (*handle)(nil) + _ fs.FileWriter = (*handle)(nil) +) + +func newAttr(ino uint64, mode uint32, size uint64, optionalMtime time.Time) (attr fuse.Attr) { + const blockSize = 1 << 20 + attr.Ino = ino + attr.Mode = mode + attr.Nlink = 1 + attr.Size = size + attr.Blocks = (attr.Size-1)/blockSize + 1 + if !optionalMtime.IsZero() { + attr.SetTimes(nil, &optionalMtime, nil) + 
} + return +} + +// GetModeBits produces the persistent mode bits so that the kernel can +// distinguish regular files from directories. +func getModeBits(isDir bool) uint32 { + mode := uint32(0) + if isDir { + mode |= syscall.S_IFDIR | 0755 + } else { + mode |= syscall.S_IFREG | 0644 + } + return mode +} + +// GetIno produces a fake inode number by hashing the path. +func getIno(path string) uint64 { + h := sha512.Sum512_256(gunsafe.StringToBytes(path)) + return binary.LittleEndian.Uint64(h[:8]) +} + +// GetFileName extracts the filename part of the path. "dir" is the directory +// that the file belongs in. +func getFileName(dir *inode, path string) string { + if dir.IsRoot() { + return path[len(dir.path):] + } + return path[len(dir.path)+1:] // +1 to remove '/'. +} + +func errToErrno(err error) syscall.Errno { + if err == nil { + return 0 + } + log.Debug.Printf("error %v: stack=%s", err, string(debug.Stack())) + switch { + case err == nil: + return 0 + case errors.Is(errors.Timeout, err): + return syscall.ETIMEDOUT + case errors.Is(errors.Canceled, err): + return syscall.EINTR + case errors.Is(errors.NotExist, err): + return syscall.ENOENT + case errors.Is(errors.Exists, err): + return syscall.EEXIST + case errors.Is(errors.NotAllowed, err): + return syscall.EACCES + case errors.Is(errors.Integrity, err): + return syscall.EIO + case errors.Is(errors.Invalid, err): + return syscall.EINVAL + case errors.Is(errors.Precondition, err), errors.Is(errors.Unavailable, err): + return syscall.EAGAIN + case errors.Is(errors.Net, err): + return syscall.ENETUNREACH + case errors.Is(errors.TooManyTries, err): + log.Error.Print(err) + return syscall.EINVAL + } + return fs.ToErrno(err) +} + +// Root reports the inode of the root mountpoint. +func (n *inode) root() *rootInode { return n.Root().Operations().(*rootInode) } + +// Ctx reports the context passed from the application when mounting the +// filesystem. 
+func (n *inode) ctx() context.Context { return n.root().ctx } + +// addDirStreamRef adds a single reference to this inode. It must be eventually +// followed by a dropRef. +func (n *inode) addDirStreamRef() { + _ = atomic.AddInt32(&n.nDirStreamRef, 1) +} + +// dropDirStreamRef drops a single reference to this inode. +func (n *inode) dropDirStreamRef() { + if x := atomic.AddInt32(&n.nDirStreamRef, -1); x < 0 { + panic("negative reference count; unmatched drop") + } +} + +// previousOfAnyDirStream returns true iff the inode is the previous entry +// returned by any outstanding DirStream. +func (n *inode) previousOfAnyDirStream() bool { + return atomic.LoadInt32(&n.nDirStreamRef) > 0 +} + +// Access is called to implement access(2). +func (n *inode) Access(_ context.Context, mask uint32) syscall.Errno { + // TODO(saito) I'm not sure returning 0 blindly is ok here. + log.Debug.Printf("setattr %s: mask=%x", n.path, mask) + return 0 +} + +// Setattr is called to change file attributes. This function only supports +// changing the size. +func (n *inode) Setattr(_ context.Context, fhi fs.FileHandle, in *fuse.SetAttrIn, out *fuse.AttrOut) syscall.Errno { + n.mu.Lock() + defer n.mu.Unlock() + + usize, ok := in.GetSize() + if !ok { + // We don't support setting other attributes now. 
+ return 0 + } + size := int64(usize) + + if fhi != nil { + fh := fhi.(*handle) + switch { + case fh.dw != nil: + if size == fh.dw.off { + return 0 + } + log.Error.Printf("setattr %s: setting size to %d in directio mode not supported (request: %+v)", n.path, size, in) + return syscall.ENOSYS + case fh.dr != nil: + log.Error.Printf("setattr %s: readonly", n.path) + return syscall.EPERM + case fh.tmp != nil: + return errToErrno(fh.tmp.fp.Truncate(size)) + default: + fh.requestedSize = size + return 0 + } + } + + if size != 0 { + log.Error.Printf("setattr %s: setting size to nonzero value (%d) not supported", n.path, size) + return syscall.ENOSYS + } + ctx := n.ctx() + fp, err := file.Create(ctx, n.path) + if err != nil { + log.Error.Printf("setattr %s: %v", n.path, err) + return errToErrno(err) + } + if err := fp.Close(ctx); err != nil { + log.Error.Printf("setattr %s: %v", n.path, err) + return errToErrno(err) + } + return 0 +} + +func (n *inode) Getattr(_ context.Context, fhi fs.FileHandle, out *fuse.AttrOut) syscall.Errno { + ctx := n.ctx() + if n.ent.Ino == 0 || n.ent.Mode == 0 { + log.Panicf("node %s: ino or mode unset: %+v", n.path, n) + } + if n.IsDir() { + log.Debug.Printf("getattr %s: directory", n.path) + out.Attr = newAttr(n.ent.Ino, n.ent.Mode, 0, time.Time{}) + return 0 + } + + var fh *handle + if fhi != nil { + fh = fhi.(*handle) + } + + n.mu.Lock() + defer n.mu.Unlock() + if fh != nil { + if err := fh.maybeInitIO(); err != nil { + return errToErrno(err) + } + if t := fh.tmp; t != nil { + log.Debug.Printf("getattr %s: tmp", n.path) + stat, err := t.fp.Stat() + if err != nil { + log.Printf("getattr %s (%s): %v", n.path, t.fp.Name(), err) + return errToErrno(err) + } + out.Attr = newAttr(n.ent.Ino, n.ent.Mode, uint64(stat.Size()), stat.ModTime()) + return 0 + } + if fh.dw != nil { + out.Attr = newAttr(n.ent.Ino, n.ent.Mode, uint64(n.stat.size), n.stat.modTime) + return 0 + } + // fall through + } + stat, err := n.getCachedStat(ctx) + if err != nil { + 
log.Printf("getattr %s: err %v", n.path, err) + return errToErrno(err) + } + out.Attr = newAttr(n.ent.Ino, n.ent.Mode, uint64(stat.size), stat.modTime) + log.Debug.Printf("getattr %s: out %+v", n.path, out) + return 0 +} + +func (n *inode) getCachedStat(ctx context.Context) (cachedStat, error) { + now := time.Now() + if now.After(n.stat.expiration) { + log.Debug.Printf("getcachedstat %s: cache miss", n.path) + info, err := file.Stat(ctx, n.path) + if err != nil { + log.Printf("getcachedstat %s: err %v", n.path, err) + return cachedStat{}, err + } + n.stat = cachedStat{ + expiration: now.Add(cacheExpiration), + size: info.Size(), + modTime: info.ModTime(), + } + } else { + log.Debug.Printf("getcachedstat %s: cache hit %+v now %v", n.path, n.stat, now) + } + return n.stat, nil +} + +// MaybeInitIO is called on the first call to Read or Write after open. It +// initializes either the directio uploader or a tempfile. +// +// REQUIRES: fh.inode.mu is locked +func (fh *handle) maybeInitIO() error { + n := fh.inode + if fh.dw != nil || fh.dr != nil || fh.tmp != nil { + return nil + } + if (fh.openMode & fuse.O_ANYWRITE) == 0 { + // Readonly handle should have fh.direct set at the time of Open. + log.Panicf("open %s: uninitialized readonly handle", n.path) + } + if fh.inode == nil { + log.Panicf("open %s: nil inode: %+v", n.path, fh) + } + ctx := n.ctx() + if (fh.openMode&syscall.O_RDWR) != syscall.O_RDWR && + (fh.requestedSize == 0 || (fh.openMode&syscall.O_TRUNC == syscall.O_TRUNC)) { + // We are fully overwriting the file. Do that w/o a local tmpfile. + log.Debug.Printf("open %s: direct IO", n.path) + fp, err := file.Create(ctx, n.path) + if err != nil { + return err + } + fh.dw = &directWrite{fp: fp, w: fp.Writer(ctx)} + return nil + } + // Do all reads/writes on a local tmp file, and copy it to the remote file on + // close. 
+ log.Debug.Printf("open %s: tmp IO", n.path) + in, err := file.Open(ctx, n.path) + if err != nil { + log.Error.Printf("open %s: %v", n.path, err) + return err + } + tmpPath := file.Join(n.root().tmpDir, fmt.Sprintf("%08x", n.ent.Ino)) + tmp, err := os.OpenFile(tmpPath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600) + if err != nil { + log.Error.Printf("create %s (open %s): %v", tmpPath, n.path, err) + _ = in.Close(ctx) + return errToErrno(err) + } + inSize, err := io.Copy(tmp, in.Reader(ctx)) + log.Debug.Printf("copy %s->%s: n+%d, %v", n.path, tmp.Name(), inSize, err) + if err != nil { + _ = in.Close(ctx) + _ = tmp.Close() + return errToErrno(err) + } + if err := in.Close(ctx); err != nil { + _ = tmp.Close() + return errToErrno(err) + } + now := time.Now() + n.stat.expiration = now.Add(cacheExpiration) + n.stat.size = inSize + n.stat.modTime = now + fh.tmp = &tmpIO{ + fp: tmp, + } + return nil +} + +func (fh *handle) Read(_ context.Context, dest []byte, off int64) (fuse.ReadResult, syscall.Errno) { + n := fh.inode + readDirect := func() (fuse.ReadResult, syscall.Errno) { + d := fh.dr + if d == nil { + return nil, syscall.EINVAL + } + log.Debug.Printf("read %s(fh=%p): off=%d seek start", n.path, fh, off) + newOff, err := d.r.Seek(off, io.SeekStart) + log.Debug.Printf("read %s(fh=%p): off=%d seek end", n.path, fh, off) + if err != nil { + return nil, errToErrno(err) + } + if newOff != off { + log.Panicf("%d <-> %d", newOff, off) + } + + nByte, err := d.r.Read(dest) + log.Debug.Printf("read %s(fh=%p): off=%d, nbyte=%d, err=%v", n.path, fh, off, nByte, err) + if err != nil { + if err != io.EOF { + return nil, errToErrno(err) + } + } + return fuse.ReadResultData(dest[:nByte]), 0 + } + + readTmp := func() (fuse.ReadResult, syscall.Errno) { + t := fh.tmp + nByte, err := t.fp.ReadAt(dest, off) + if err != nil { + if err != io.EOF { + return nil, errToErrno(err) + } + } + return fuse.ReadResultData(dest[:nByte]), 0 + } + + n.mu.Lock() + defer n.mu.Unlock() + if err := 
fh.maybeInitIO(); err != nil { + //return fuse.ReadResult{}, errToErrno(err) + return nil, errToErrno(err) + } + switch { + case fh.dr != nil: + return readDirect() + case fh.tmp != nil: + return readTmp() + default: + log.Error.Printf("read %s: reading unopened or writeonly file", n.path) + return nil, syscall.EBADF + } +} + +func (fh *handle) Lseek(ctx context.Context, off uint64, whence uint32) (uint64, syscall.Errno) { + const ( + // Copied from https://github.com/torvalds/linux/blob/a050a6d2b7e80ca52b2f4141eaf3420d201b72b3/tools/include/uapi/linux/fs.h#L43-L47. + SEEK_DATA = 3 + SEEK_HOLE = 4 + ) + switch whence { + case SEEK_DATA: + return off, 0 // We don't support holes so current offset is correct. + case SEEK_HOLE: + stat, err := fh.inode.getCachedStat(ctx) + if err != nil { + log.Error.Printf("lseek %s: stat: %v", fh.inode.path, err) + return 0, errToErrno(err) + } + return uint64(stat.size), 0 + } + log.Error.Printf("lseek %s: unimplemented whence: %d", fh.inode.path, whence) + return 0, syscall.ENOSYS +} + +func (fh *handle) Write(_ context.Context, dest []byte, off int64) (uint32, syscall.Errno) { + n := fh.inode + tmpWrite := func() (uint32, syscall.Errno) { + nByte, err := fh.tmp.fp.WriteAt(dest, off) + if err != nil { + log.Error.Printf("write %s: size=%d, off=%d: %v", n.path, len(dest), off, err) + return 0, errToErrno(err) + } + return uint32(nByte), 0 + } + + directWrite := func() (uint32, syscall.Errno) { + d := fh.dw + if d.off != off { + log.Error.Printf("write %s: offset mismatch (expect %d, got %d)", n.path, d.off, off) + return 0, syscall.EINVAL + } + if d.w == nil { + // closed already + log.Printf("write %s: already closed", n.path) + return 0, syscall.EBADF + } + nByte, err := d.w.Write(dest) + if err != nil { + if nByte > 0 { + panic(n) + } + return 0, errToErrno(err) + } + d.off += int64(nByte) + log.Debug.Printf("write %s: done %d bytes", n.path, nByte) + return uint32(nByte), 0 + } + + n.mu.Lock() + defer n.mu.Unlock() + 
log.Debug.Printf("write %s: %d bytes, off=%d", n.path, len(dest), off) + if err := fh.maybeInitIO(); err != nil { + return 0, errToErrno(err) + } + switch { + case fh.dw != nil: + return directWrite() + case fh.tmp != nil: + return tmpWrite() + default: + // file descriptor already closed + log.Error.Printf("write %s: writing after close", n.path) + return 0, syscall.EBADF + } +} + +func (fh *handle) Fsync(_ context.Context, _ uint32) syscall.Errno { + n := fh.inode + n.mu.Lock() + defer n.mu.Unlock() + if d := fh.dw; d != nil { + n := fh.inode + // There's not much we can do, but returning ENOSYS breaks too many apps. + now := time.Now() + n.stat.expiration = now.Add(cacheExpiration) + n.stat.size = d.off + n.stat.modTime = now + log.Debug.Printf("fsync %s: update stats: stat=%v", n.path, n.stat) + } + return 0 +} + +// Release is called just before the inode is dropped from the kernel memory. +// Return value is unused. +func (fh *handle) Release(_ context.Context) syscall.Errno { + n := fh.inode + n.mu.Lock() + defer n.mu.Unlock() + switch { + case fh.tmp != nil: + if fh.tmp.fp != nil { + log.Panicf("%s: release called w/o flush", n.path) + } + case fh.dw != nil: + if fh.dw.fp != nil || fh.dw.w != nil { + log.Panicf("%s: release called w/o flush", n.path) + } + default: + if fh.dr != nil { + // Readonly handles are closed on the last release. + _ = fh.dr.fp.Close(n.ctx()) + } + } + return 0 +} + +// Flush is called on close(2). It may be called multiple times when the file +// descriptor is duped. +// +// TODO(saito) We don't support dups now. We close the underlying filestream on +// the first close and subsequent flush calls will do nothing. 
+func (fh *handle) Flush(_ context.Context) syscall.Errno { + n := fh.inode + ctx := n.ctx() + + flushTmpAndUnlock := func() syscall.Errno { + t := fh.tmp + mu := &n.mu + defer func() { + if mu != nil { + mu.Unlock() + } + }() + if t.fp == nil { + mu.Unlock() + return fh.closeErrno + } + out, err := file.Create(ctx, n.path) + if err != nil { + log.Error.Printf("flush %s (create): err=%v", n.path, err) + fh.closeErrno = errToErrno(err) + _ = t.fp.Close() + mu.Unlock() + return fh.closeErrno + } + defer func() { + if out != nil { + _ = out.Close(ctx) + } + if t.fp != nil { + _ = t.fp.Close() + t.fp = nil + } + }() + + newOff, err := t.fp.Seek(0, io.SeekStart) + if err != nil { + log.Error.Printf("flush %s (seek): err=%v", n.path, err) + fh.closeErrno = errToErrno(err) + return fh.closeErrno + } + if newOff != 0 { + log.Panicf("newoff %d", newOff) + } + + nByte, err := io.Copy(out.Writer(ctx), t.fp) + if err != nil { + log.Error.Printf("flush %s (copy): err=%v", n.path, err) + fh.closeErrno = errToErrno(err) + return fh.closeErrno + } + errp := errors.Once{} + errp.Set(t.fp.Close()) + errp.Set(out.Close(ctx)) + out = nil + t.fp = nil + if err := errp.Err(); err != nil { + fh.closeErrno = errToErrno(err) + log.Error.Printf("flush %s (close): err=%v", n.path, err) + return fh.closeErrno + } + + now := time.Now() + n.stat.expiration = now.Add(cacheExpiration) + n.stat.size = nByte + n.stat.modTime = now + + closeErrno := fh.closeErrno + mu.Unlock() + mu = nil + return closeErrno + } + + flushDirectAndUnlock := func() syscall.Errno { + mu := &n.mu + defer func() { + if mu != nil { + mu.Unlock() + } + }() + d := fh.dw + if d.fp == nil { + return fh.closeErrno + } + + err := d.fp.Close(ctx) + fh.closeErrno = errToErrno(err) + log.Debug.Printf("flush %s fh=%p, err=%v", n.path, fh, err) + if d.w != nil { + now := time.Now() + n.stat.expiration = now.Add(cacheExpiration) + n.stat.size = d.off + n.stat.modTime = now + } + d.fp = nil + d.w = nil + closeErrno := fh.closeErrno + 
mu.Unlock() + mu = nil + return closeErrno + } + n.mu.Lock() + switch { + case fh.tmp != nil: + return flushTmpAndUnlock() + case fh.dw != nil: + return flushDirectAndUnlock() + } + n.mu.Unlock() + return 0 +} + +// Create is called to create a new file. +func (n *inode) Create(ctx context.Context, name string, flags uint32, mode uint32, + out *fuse.EntryOut) (*fs.Inode, fs.FileHandle, uint32, syscall.Errno) { + newPath := file.Join(n.path, name) + childNode := &inode{ + path: newPath, + ent: fuse.DirEntry{ + Name: name, + Ino: getIno(newPath), + Mode: getModeBits(false)}} + childInode := n.NewInode(ctx, childNode, fs.StableAttr{ + Mode: childNode.ent.Mode, + Ino: childNode.ent.Ino, + }) + fh := newHandle(childNode, syscall.O_WRONLY|syscall.O_CREAT|syscall.O_TRUNC) + fh.requestedSize = 0 + log.Debug.Printf("create %s: (mode %x)", n.path, mode) + out.Attr = newAttr(n.ent.Ino, n.ent.Mode, 0, time.Time{}) + return childInode, fh, 0, 0 +} + +// Open opens an existing file. +func (n *inode) Open(_ context.Context, mode uint32) (fs.FileHandle, uint32, syscall.Errno) { + n.mu.Lock() + defer n.mu.Unlock() + ctx := n.ctx() + if n.IsRoot() { + // The entries under the root must be buckets, so we can't open it directly. + log.Error.Printf("open %s: cannot open a file under root", n.path) + return nil, 0, syscall.EINVAL + } + _, dirInode := n.Parent() + if dirInode == nil { + log.Panicf("open %s: parent dir does't exist", n.path) + } + if (mode & fuse.O_ANYWRITE) == 0 { + fp, err := file.Open(n.ctx(), n.path) + if err != nil { + log.Error.Printf("open %s (mode %x): %v", n.path, mode, err) + return nil, 0, errToErrno(err) + } + fh := newHandle(n, mode) + fh.dr = &directRead{fp: fp, r: fp.Reader(ctx)} + log.Debug.Printf("open %s: mode %x, fh %p", n.path, mode, fh) + return fh, 0, 0 + } + + fh := newHandle(n, mode) + return fh, 0, 0 +} + +// FsDirStream implements readdir. 
+type fsDirStream struct { + ctx context.Context + dir *inode + lister file.Lister + err error + + seenParent bool // Whether Next has already returned '..'. + seenSelf bool // Whether Next has already returned '.'. + peekedChild bool // Whether HasNext has Scan()-ed a child that Next hasn't returned yet. + + // previousInode is the inode of the previous entry, i.e. the most recent + // entry returned by Next. We hold a reference to service LOOKUP + // operations that go-fuse issues when servicing READDIRPLUS. See + // dirStreamUsage. + previousInode *fs.Inode +} + +// HasNext implements fs.DirStream +func (s *fsDirStream) HasNext() bool { + s.dir.mu.Lock() // TODO: Remove? + defer s.dir.mu.Unlock() + + if s.err != nil || s.lister == nil { + return false + } + if !s.seenParent || !s.seenSelf || s.peekedChild { + return true + } + for s.lister.Scan() { + if getFileName(s.dir, s.lister.Path()) != "" { + s.peekedChild = true + return true + } + // Assume this is a directory marker: + // https://web.archive.org/web/20190424231712/https://docs.aws.amazon.com/AmazonS3/latest/user-guide/using-folders.html + // s3file's List returns these, but empty filenames seem to cause problems for FUSE. + // TODO: Filtering these in s3file, if it's ok for other users. + } + return false +} + +// Next implements fs.DirStream +func (s *fsDirStream) Next() (fuse.DirEntry, syscall.Errno) { + s.dir.mu.Lock() + defer s.dir.mu.Unlock() + + if s.err != nil { + return fuse.DirEntry{}, errToErrno(s.err) + } + if err := s.lister.Err(); err != nil { + if _, canceled := <-s.ctx.Done(); canceled { + s.err = errors.E(errors.Canceled, "list canceled", err) + } else { + s.err = err + } + return fuse.DirEntry{}, errToErrno(s.err) + } + + ent := fuse.DirEntry{} + stat := cachedStat{expiration: time.Now().Add(cacheExpiration)} + + if !s.seenParent { + s.seenParent = true + _, parent := s.dir.Parent() + if parent != nil { + // Not root. 
+ parentDir := downCast(parent) + ent = parentDir.ent + ent.Name = ".." + stat = parentDir.stat + return ent, 0 + } + } + if !s.seenSelf { + s.seenSelf = true + ent = s.dir.ent + ent.Name = "." + stat = s.dir.stat + return ent, 0 + } + s.peekedChild = false + + ent = fuse.DirEntry{ + Name: getFileName(s.dir, s.lister.Path()), + Mode: getModeBits(s.lister.IsDir()), + Ino: getIno(s.lister.Path()), + } + if info := s.lister.Info(); info != nil { + stat.size, stat.modTime = info.Size(), info.ModTime() + } + inode := s.dir.NewInode( + s.ctx, + &inode{path: file.Join(s.dir.path, ent.Name), ent: ent, stat: stat}, + fs.StableAttr{Mode: ent.Mode, Ino: ent.Ino}, + ) + _ = s.dir.AddChild(ent.Name, inode, true) + s.lockedSetPreviousInode(inode) + return ent, 0 +} + +// Close implements fs.DirStream +func (s *fsDirStream) Close() { + s.dir.mu.Lock() + s.lockedClearPreviousInode() + s.dir.mu.Unlock() +} + +func (s *fsDirStream) lockedSetPreviousInode(n *fs.Inode) { + s.lockedClearPreviousInode() + s.previousInode = n + s.previousInode.Operations().(*inode).addDirStreamRef() +} + +func (s *fsDirStream) lockedClearPreviousInode() { + if s.previousInode == nil { + return + } + s.previousInode.Operations().(*inode).dropDirStreamRef() + s.previousInode = nil +} + +func (n *inode) Lookup(ctx context.Context, name string, out *fuse.EntryOut) (*fs.Inode, syscall.Errno) { + log.Debug.Printf("lookup %s: name=%s start", n.path, name) + + childInode := n.GetChild(name) + if childInode != nil && childInode.Operations().(*inode).previousOfAnyDirStream() { + log.Debug.Printf("lookup %s: name=%s using existing child inode", n.path, name) + } else { + var ( + childPath = file.Join(n.path, name) + foundDir bool + foundFile cachedStat + lister = file.List(ctx, childPath, true /* recursive */) + ) + // Look for either a file or a directory at this path. + // If both exist, assume file is a directory marker. 
+ for lister.Scan() { + if lister.IsDir() || // We've found an exact match, and it's a directory. + lister.Path() != childPath { // We're seeing children, so childPath must be a directory. + foundDir = true + break + } + info := lister.Info() + foundFile = cachedStat{time.Now().Add(cacheExpiration), info.Size(), info.ModTime()} + } + if err := lister.Err(); err != nil { + if errors.Is(errors.NotExist, err) || errors.Is(errors.NotAllowed, err) { + // Ignore. + } else { + return nil, errToErrno(err) + } + } + + if !foundDir && foundFile == (cachedStat{}) { + log.Debug.Printf("lookup: %s name='%s' not found", n.path, name) + return nil, syscall.ENOENT + } + + ent := fuse.DirEntry{ + Name: childPath, + Mode: getModeBits(foundDir), + Ino: getIno(childPath), + } + childInode = n.NewInode( + ctx, + &inode{path: childPath, ent: ent, stat: foundFile}, + fs.StableAttr{ + Mode: ent.Mode, + Ino: ent.Ino, + }) + } + ops := childInode.Operations().(*inode) + out.Attr = newAttr(ops.ent.Ino, ops.ent.Mode, uint64(ops.stat.size), ops.stat.modTime) + out.SetEntryTimeout(cacheExpiration) + out.SetAttrTimeout(cacheExpiration) + log.Debug.Printf("lookup %s name='%s' done: mode=%o ino=%d stat=%+v", n.path, name, ops.ent.Mode, ops.ent.Ino, ops.stat) + return childInode, 0 +} + +func (n *inode) Readdir(ctx context.Context) (fs.DirStream, syscall.Errno) { + log.Debug.Printf("readdir %s: start", n.path) + // TODO(josh): Newer Linux kernels (4.20+) can cache the entries from readdir. Make sure this works + // and invalidates reasonably. 
+	// References:
+	// Linux patch series: https://github.com/torvalds/linux/commit/69e345511
+	// go-fuse support: https://github.com/hanwen/go-fuse/commit/fa1304749db6eafd8fe64338f10c9750cf693274
+	// libfuse's documentation (describing some kernel behavior): http://web.archive.org/web/20210118113434/https://libfuse.github.io/doxygen/structfuse__lowlevel__ops.html#afa15612c68f7971cadfe3d3ec0a8b70e
+	return &fsDirStream{
+		ctx:    ctx,
+		dir:    n,
+		lister: file.List(ctx, n.path, false /*nonrecursive*/),
+	}, 0
+}
+
+func (n *inode) Unlink(_ context.Context, name string) syscall.Errno {
+	childPath := file.Join(n.path, name)
+	err := file.Remove(n.ctx(), childPath)
+	log.Debug.Printf("unlink %s: err %v", childPath, err)
+	return errToErrno(err)
+}
+
+func (n *inode) Rmdir(_ context.Context, name string) syscall.Errno {
+	// Nothing to do: directories are implicit in the remote store, so there is no object to delete.
+	return 0
+}
+
+func (n *inode) Mkdir(ctx context.Context, name string, _ uint32, out *fuse.EntryOut) (*fs.Inode, syscall.Errno) {
+	n.mu.Lock()
+	defer n.mu.Unlock()
+	// TODO: Consider creating an S3 "directory" object so this new directory persists for new listings.
+ // https://docs.aws.amazon.com/AmazonS3/latest/userguide/using-folders.html + newPath := file.Join(n.path, name) + childNode := &inode{ + path: newPath, + ent: fuse.DirEntry{ + Name: name, + Ino: getIno(newPath), + Mode: getModeBits(true)}} + childInode := n.NewInode(ctx, childNode, fs.StableAttr{ + Mode: childNode.ent.Mode, + Ino: childNode.ent.Ino, + }) + out.Attr = newAttr(n.ent.Ino, n.ent.Mode, 0, time.Time{}) + out.SetEntryTimeout(cacheExpiration) + out.SetAttrTimeout(cacheExpiration) + return childInode, 0 +} diff --git a/cmd/grail-fuse/gfs/gfs_test.go b/cmd/grail-fuse/gfs/gfs_test.go new file mode 100644 index 00000000..b6abd0d9 --- /dev/null +++ b/cmd/grail-fuse/gfs/gfs_test.go @@ -0,0 +1,246 @@ +//+build !unit + +package gfs_test + +import ( + "context" + "fmt" + "io" + "io/ioutil" + golog "log" + "os" + "os/exec" + "sort" + "syscall" + "testing" + + "github.com/grailbio/base/cmd/grail-fuse/gfs" + "github.com/grailbio/base/log" + "github.com/grailbio/testutil/assert" + "github.com/grailbio/testutil/expect" + "github.com/grailbio/testutil/h" + "github.com/hanwen/go-fuse/v2/fs" + "github.com/hanwen/go-fuse/v2/fuse" + "github.com/hanwen/go-fuse/v2/posixtest" +) + +type tester struct { + t *testing.T + remoteDir string + mountDir string + tempDir string + server *fuse.Server +} + +type logOutputter struct{} + +func (logOutputter) Level() log.Level { return log.Debug } + +func (logOutputter) Output(calldepth int, level log.Level, s string) error { + return golog.Output(calldepth+1, s) +} + +func newTester(t *testing.T, remoteDir string) *tester { + log.SetFlags(log.Lmicroseconds | log.Lshortfile) + log.SetOutputter(logOutputter{}) + if remoteDir == "" { + var err error + remoteDir, err = ioutil.TempDir("", "remote") + assert.NoError(t, err) + } + tempDir, err := ioutil.TempDir("", "temp") + assert.NoError(t, err) + mountDir, err := ioutil.TempDir("", "mount") + assert.NoError(t, err) + root := gfs.NewRoot(context.Background(), remoteDir, tempDir) + server, 
err := fs.Mount(mountDir, root, &fs.Options{ + MountOptions: fuse.MountOptions{ + FsName: "test", + DisableXAttrs: true, + Debug: true}}) + assert.NoError(t, err) + log.Printf("mount remote dir %s on %s, tmp %s", remoteDir, mountDir, tempDir) + return &tester{ + t: t, + remoteDir: remoteDir, + mountDir: mountDir, + tempDir: tempDir, + server: server, + } +} + +func (t *tester) MountDir() string { return t.mountDir } +func (t *tester) RemoteDir() string { return t.remoteDir } + +func (t *tester) Cleanup() { + log.Printf("unmount %s", t.mountDir) + assert.NoError(t.t, t.server.Unmount()) + assert.NoError(t.t, os.RemoveAll(t.mountDir)) + log.Printf("unmount %s done", t.mountDir) +} + +func writeFile(t *testing.T, path string, data string) { + assert.NoError(t, ioutil.WriteFile(path, []byte(data), 0600)) +} + +func readFile(path string) string { + data, err := ioutil.ReadFile(path) + if err != nil { + return fmt.Sprintf("read %s: error %v", path, err) + } + return string(data) +} + +func readdir(t *testing.T, dir string) []string { + fp, err := os.Open(dir) + assert.NoError(t, err) + names, err := fp.Readdirnames(0) + assert.NoError(t, err) + sort.Strings(names) + assert.NoError(t, fp.Close()) + return names +} + +func TestSimple(t *testing.T) { + var ( + err error + remoteDir string + ) + if remoteDir, err = ioutil.TempDir("", "remote"); err != nil { + log.Panic(err) + } + defer func() { _ = os.RemoveAll(remoteDir) }() + + writeFile(t, remoteDir+"/fox.txt", "pink fox") + tc := newTester(t, remoteDir) + defer tc.Cleanup() + + expect.EQ(t, readFile(tc.MountDir()+"/fox.txt"), "pink fox") + expect.That(t, readdir(t, tc.MountDir()), h.ElementsAre("fox.txt")) + assert.NoError(t, os.Remove(tc.MountDir()+"/fox.txt")) + expect.HasSubstr(t, readFile(tc.MountDir()+"/fox.txt"), "no such file") + expect.HasSubstr(t, readFile(remoteDir+"/fox.txt"), "no such file") + expect.That(t, readdir(t, tc.MountDir()), h.ElementsAre()) +} + +func TestOverwrite(t *testing.T) { + tc := 
newTester(t, "") + defer tc.Cleanup() + + path := tc.MountDir() + "/bar.txt" + fp, err := os.Create(path) + assert.NoError(t, err) + _, err = fp.Write([]byte("purple dog")) + assert.NoError(t, err) + assert.NoError(t, fp.Close()) + expect.EQ(t, readFile(tc.RemoteDir()+"/bar.txt"), "purple dog") + expect.EQ(t, readFile(path), "purple dog") + + fp, err = os.Create(path) + assert.NoError(t, err) + _, err = fp.Write([]byte("white giraffe")) + assert.NoError(t, err) + assert.NoError(t, fp.Close()) + expect.EQ(t, readFile(tc.RemoteDir()+"/bar.txt"), "white giraffe") + expect.EQ(t, readFile(path), "white giraffe") +} + +func TestReadWrite(t *testing.T) { + tc := newTester(t, "") + defer tc.Cleanup() + + path := tc.MountDir() + "/baz.txt" + writeFile(t, path, "purple cat") + fp, err := os.OpenFile(path, os.O_RDWR, 0600) + assert.NoError(t, err) + _, err = fp.Write([]byte("yellow")) + assert.NoError(t, err) + assert.NoError(t, fp.Close()) + expect.EQ(t, readFile(path), "yellow cat") + + fp, err = os.OpenFile(path, os.O_RDWR, 0600) + assert.NoError(t, err) + _, err = fp.Seek(7, io.SeekStart) + assert.NoError(t, err) + _, err = fp.Write([]byte("bat")) + assert.NoError(t, fp.Close()) + expect.EQ(t, readFile(path), "yellow bat") +} + +func TestAppend(t *testing.T) { + tc := newTester(t, "") + defer tc.Cleanup() + + path := tc.MountDir() + "/append.txt" + writeFile(t, path, "orange ape") + log.Printf("reopening %s with O_APPEND", path) + fp, err := os.OpenFile(path, os.O_WRONLY|os.O_APPEND, 0600) + assert.NoError(t, err) + log.Printf("writing to %s", path) + _, err = fp.Write([]byte("red donkey")) + assert.NoError(t, err) + assert.NoError(t, fp.Close()) + expect.EQ(t, readFile(path), "orange apered donkey") +} + +func TestMkdir(t *testing.T) { + tc := newTester(t, "") + defer tc.Cleanup() + + path := tc.MountDir() + "/dir0" + assert.NoError(t, os.Mkdir(path, 0755)) + assert.EQ(t, readdir(t, path), []string{}) +} + +func TestDup(t *testing.T) { + tc := newTester(t, "") + defer 
tc.Cleanup() + + path := tc.MountDir() + "/f0.txt" + fp0, err := os.Create(path) + assert.NoError(t, err) + + fd1, err := syscall.Dup(int(fp0.Fd())) + assert.NoError(t, err) + fp1 := os.NewFile(uintptr(fd1), path) + + _, err = fp0.Write([]byte("yellow bug")) + assert.NoError(t, err) + assert.NoError(t, fp0.Close()) + _, err = fp1.Write([]byte("yellow hug")) + assert.HasSubstr(t, err, "bad file descriptor") + assert.NoError(t, fp1.Close()) + expect.EQ(t, readFile(path), "yellow bug") +} + +func TestShell(t *testing.T) { + tc := newTester(t, "") + defer tc.Cleanup() + + path := tc.MountDir() + "/cat.txt" + cmd := exec.Command("sh", "-c", fmt.Sprintf("echo foo >%s", path)) + assert.NoError(t, cmd.Run()) + expect.EQ(t, readFile(path), "foo\n") + + cmd = exec.Command("sh", "-c", fmt.Sprintf("echo bar >>%s", path)) + assert.NoError(t, cmd.Run()) + expect.EQ(t, readFile(path), "foo\nbar\n") + + path2 := tc.MountDir() + "/cat2.txt" + log.Printf("Start cat") + cmd = exec.Command("sh", "-c", fmt.Sprintf("cat <%s >%s", path, path2)) + assert.NoError(t, cmd.Run()) + expect.EQ(t, readFile(path2), "foo\nbar\n") +} + +func TestPosix(t *testing.T) { + tc := newTester(t, "") + defer tc.Cleanup() + + // Regression test for a directory listing bug (erroneously skipping some entries). + t.Run("ReadDir", func(t *testing.T) { + posixtest.ReadDir(t, tc.MountDir()) + }) + + // TODO(josh): Consider running more tests from posixtest. This may require new features. 
+}
diff --git a/cmd/grail-fuse/gfs/main.go b/cmd/grail-fuse/gfs/main.go
new file mode 100644
index 00000000..03cc081a
--- /dev/null
+++ b/cmd/grail-fuse/gfs/main.go
@@ -0,0 +1,118 @@
+package gfs
+
+import (
+	"context"
+	"fmt"
+	"os"
+	"os/exec"
+	"path/filepath"
+	"strings"
+	"syscall"
+	"time"
+
+	"github.com/grailbio/base/log"
+	"github.com/hanwen/go-fuse/v2/fs"
+	"github.com/hanwen/go-fuse/v2/fuse"
+)
+
+const daemonEnv = "_GFS_DAEMON"
+
+func logSuffix() string {
+	return time.Now().Format(time.RFC3339) + ".log"
+}
+
+// Main starts the FUSE server. It arranges so that contents of remoteRootDir
+// can be accessed through mountDir. Arg remoteRootDir is typically
+// "s3://". mountDir must be a directory in the local file system.
+//
+// If daemon==true, this function will fork itself to become a background
+// process, and this process will exit. Otherwise this function blocks until the
+// filesystem is unmounted by a superuser running "umount <mountdir>".
+//
+// Arg tmpDir specifies the directory to store temporary files. If tmpDir=="",
+// it is set to /tmp/gfscache-<euid> (see NewRoot).
+//
+// logDir specifies the directory for storing log files. If "", log messages are
+// sent to stderr.
+func Main(ctx context.Context, remoteRootDir, mountDir string, daemon bool, tmpDir, logDir string) { + if daemon { + daemonize(logDir) + // daemonize will exit the parent process if daemon==true + } else if logDir != "" { + path := filepath.Join(logDir, "gfs."+logSuffix()) + fd, err := os.OpenFile(path, syscall.O_CREAT|syscall.O_WRONLY|syscall.O_APPEND, 0600) + if err != nil { + log.Panicf("create %s: %v", path, err) + } + log.Printf("Storing log files in %s", path) + log.SetOutput(fd) + } + if err := os.MkdirAll(mountDir, 0700); err != nil { + log.Panicf("mkdir %s: %v", mountDir, err) + } + root := NewRoot(ctx, remoteRootDir, tmpDir) + server, err := fs.Mount(mountDir, root, &fs.Options{ + MountOptions: fuse.MountOptions{ + FsName: "grail", + DisableXAttrs: true, + Debug: log.At(log.Debug), + }, + }) + if err != nil { + log.Panicf("mount %s: %v", mountDir, err) + } + server.Wait() +} + +func daemonize(logDir string) { + if os.Getenv(daemonEnv) == "" { + suffix := logSuffix() + if logDir == "" { + logDir = "/tmp" + log.Printf("Storing log files in %s/gfs-*%s", logDir, suffix) + } + if logDir == "" { + log.Panic("-log-dir must set with -daemon") + } + stdinFd, err := os.Open("/dev/null") + if err != nil { + log.Panic(err) + } + stdoutFd, err := os.Create(filepath.Join(logDir, "gfs-stdout."+suffix)) + if err != nil { + log.Panic(err) + } + stderrFd, err := os.Create(filepath.Join(logDir, "gfs-stderr."+suffix)) + if err != nil { + log.Panic(err) + } + os.Stdout.Sync() + os.Stderr.Sync() + cmd := exec.Command(os.Args[0], os.Args[1:]...) + cmd.Stdout = stdoutFd + cmd.Stderr = stderrFd + cmd.Stdin = stdinFd + cmd.Env = append([]string(nil), os.Environ()...) 
+ cmd.Env = append(cmd.Env, daemonEnv+"=1") + cmd.Start() + os.Exit(0) + } +} + +func NewRoot(ctx context.Context, remoteRootDir, tmpDir string) fs.InodeEmbedder { + if tmpDir == "" { + tmpDir = fmt.Sprintf("/tmp/gfscache-%d", os.Geteuid()) + } + if !strings.HasSuffix(remoteRootDir, "/") { + // getFileName misbehaves otherwise. + remoteRootDir += "/" + } + if err := os.MkdirAll(tmpDir, 0700); err != nil { + log.Panic(err) + } + ent := fuse.DirEntry{ + Name: "/", + Ino: getIno(""), + Mode: getModeBits(true)} + return &rootInode{inode: inode{path: remoteRootDir, ent: ent}, ctx: ctx, tmpDir: tmpDir} +} diff --git a/cmd/grail-fuse/main.go b/cmd/grail-fuse/main.go new file mode 100644 index 00000000..924fa8c6 --- /dev/null +++ b/cmd/grail-fuse/main.go @@ -0,0 +1,49 @@ +package main + +import ( + "context" + "flag" + "fmt" + "net/http" + _ "net/http/pprof" + "os" + + "github.com/grailbio/base/cmd/grail-fuse/gfs" + "github.com/grailbio/base/file" + "github.com/grailbio/base/file/s3file" + "github.com/grailbio/base/log" +) + +func main() { + flag.Usage = func() { + fmt.Fprintf(flag.CommandLine.Output(), `Usage: +%s [flags...] MOUNTDIR + +To unmount the file system, run "fusermount -u MOUNTDIR". +`, os.Args[0]) + flag.PrintDefaults() + } + remoteRootDirFlag := flag.String("remote-root-dir", "s3://", `Remote root directory`) + logDirFlag := flag.String("log-dir", "", `Directory to store log files. +If empty, log messages are sent to stderr`) + tmpDirFlag := flag.String("tmp-dir", "", `Tmp directory location. 
If empty, /tmp/gfscache- is used`) + daemonFlag := flag.Bool("daemon", false, "Run in background") + httpFlag := flag.String("http", "localhost:54321", "Run an HTTP status server") + log.AddFlags() + log.SetFlags(log.Lmicroseconds | log.Lshortfile) + flag.Parse() + args := flag.Args() + if len(args) != 1 { + log.Panic("fuse: missing mount point") + } + if len(*httpFlag) > 0 { + log.Printf("starting status server at %s", *httpFlag) + go func() { + log.Print(http.ListenAndServe(*httpFlag, nil)) + }() + } + file.RegisterImplementation("s3", func() file.Implementation { + return s3file.NewImplementation(s3file.NewDefaultProvider(), s3file.Options{}) + }) + gfs.Main(context.Background(), *remoteRootDirFlag, args[0], *daemonFlag, *tmpDirFlag, *logDirFlag) +} diff --git a/cmd/grail-role-group/create.go b/cmd/grail-role-group/create.go index c31c050e..69c546c3 100644 --- a/cmd/grail-role-group/create.go +++ b/cmd/grail-role-group/create.go @@ -4,11 +4,9 @@ import ( "fmt" "strings" - "v.io/x/lib/vlog" - - "google.golang.org/api/admin/directory/v1" - + admin "google.golang.org/api/admin/directory/v1" "v.io/x/lib/cmdline" + "v.io/x/lib/vlog" ) func runCreate(_ *cmdline.Env, args []string) error { @@ -16,7 +14,9 @@ func runCreate(_ *cmdline.Env, args []string) error { return fmt.Errorf("bad number of arguments, expected 1, got %q", args) } groupName := args[0] - if !strings.HasSuffix(groupName, groupSuffix) { + if !Any(groupSuffix, func(v string) bool { + return strings.HasSuffix(groupName, v) + }) { return fmt.Errorf("bad suffix: the group name %q doesn't end in %q", groupName, groupSuffix) } diff --git a/cmd/grail-role-group/description.go b/cmd/grail-role-group/description.go index 4df0cb46..e0bbb75a 100644 --- a/cmd/grail-role-group/description.go +++ b/cmd/grail-role-group/description.go @@ -6,11 +6,18 @@ import ( ) // description generates a standard description for the group. It assumes the -// group name for / is --aws-role@grailbio.com. 
+// group name for / is --@grailbio.com. func description(group string) string { - v := strings.SplitN(strings.TrimSuffix(group, groupSuffix), "-", 2) - if len(v) != 2 { + if strings.HasSuffix(group, "-aws-role@grailbio.com") { + v := strings.SplitN(strings.TrimSuffix(group, "-aws-role@grailbio.com"), "-", 2) + if len(v) != 2 { + return "" + } + return fmt.Sprintf("Please request access to this group if you need access to the %s/%s role account.", v[0], v[1]) + } else if strings.HasSuffix(group, "-ticket@grailbio.com") { + v := strings.TrimSuffix(group, "-ticket@grailbio.com") + return fmt.Sprintf("Please request access to this group if you need access to the ticket %s.", v) + } else { return "" } - return fmt.Sprintf("Please request access to this group if you need access to the %s/%s role account.", v[0], v[1]) } diff --git a/cmd/grail-role-group/description_test.go b/cmd/grail-role-group/description_test.go index f3100d63..b2869465 100644 --- a/cmd/grail-role-group/description_test.go +++ b/cmd/grail-role-group/description_test.go @@ -11,6 +11,10 @@ func TestDescription(t *testing.T) { "eng-dev-aws-role@grailbio.com", "Please request access to this group if you need access to the eng/dev role account.", }, + { + "vendor-procured-samples-ticket@grailbio.com", + "Please request access to this group if you need access to the ticket vendor-procured-samples.", + }, {"eng", ""}, {"", ""}, } diff --git a/cmd/grail-role-group/doc.go b/cmd/grail-role-group/doc.go index 25123814..4c8b943f 100644 --- a/cmd/grail-role-group/doc.go +++ b/cmd/grail-role-group/doc.go @@ -16,12 +16,10 @@ The global flags are: log to standard error as well as files -block-profile= filename prefix for block profiles - -block-profile-rate=1 - rate for runtime. 
SetBlockProfileRate + -block-profile-rate=200 + rate for runtime.SetBlockProfileRate -cpu-profile= filename for cpu profile - -gops=false - enable the gops listener -heap-profile= filename prefix for heap profiles -log_backtrace_at=:0 @@ -36,7 +34,7 @@ The global flags are: Displays metadata for the program and exits. -mutex-profile= filename prefix for mutex profiles - -mutex-profile-rate=1 + -mutex-profile-rate=200 rate for runtime.SetMutexProfileFraction -pprof= address for pprof server @@ -51,35 +49,6 @@ The global flags are: Dump timing information to stderr before exiting the program. -v=0 log level for V logs - -v23.credentials= - directory to use for storing security credentials - -v23.i18n-catalogue= - 18n catalogue files to load, comma separated - -v23.namespace.root=[/(v23.grail.com:internal:mounttabled)@ns.v23.grail.com:8101] - local namespace root; can be repeated to provided multiple roots - -v23.permissions.file=map[] - specify a perms file as : - -v23.permissions.literal= - explicitly specify the runtime perms as a JSON-encoded access.Permissions. - Overrides all --v23.permissions.file flags. - -v23.proxy= - object name of proxy service to use to export services across network - boundaries - -v23.tcp.address= - address to listen on - -v23.tcp.protocol=wsh - protocol to listen with - -v23.vtrace.cache-size=1024 - The number of vtrace traces to store in memory. - -v23.vtrace.collect-regexp= - Spans and annotations that match this regular expression will trigger trace - collection. - -v23.vtrace.dump-on-shutdown=true - If true, dump all stored traces on runtime shutdown. - -v23.vtrace.sample-rate=0 - Rate (from 0.0 to 1.0) to sample vtrace traces. - -v23.vtrace.v=0 - The verbosity level of the log messages to be captured in traces -vmodule= comma-separated list of globpattern=N settings for filename-filtered logging (without the .go suffix). E.g. 
foo/bar/baz.go is matched by patterns baz or diff --git a/cmd/grail-role-group/list.go b/cmd/grail-role-group/list.go index 0de15dcf..d7661a08 100644 --- a/cmd/grail-role-group/list.go +++ b/cmd/grail-role-group/list.go @@ -5,7 +5,7 @@ import ( "fmt" "strings" - "google.golang.org/api/admin/directory/v1" + admin "google.golang.org/api/admin/directory/v1" "google.golang.org/api/groupssettings/v1" goauth2 "google.golang.org/api/oauth2/v1" "v.io/x/lib/cmdline" @@ -27,7 +27,9 @@ func runList(_ *cmdline.Env, args []string) error { ctx := context.Background() return service.Groups.List().Domain(domain).Pages(ctx, func(groups *admin.Groups) error { for _, g := range groups.Groups { - if strings.HasSuffix(g.Email, groupSuffix) { + if Any(groupSuffix, func(v string) bool { + return strings.HasSuffix(g.Email, v) + }) { fmt.Printf("%v\n", g.Email) } } diff --git a/cmd/grail-role-group/main.go b/cmd/grail-role-group/main.go index 07db47ce..e0d78c72 100644 --- a/cmd/grail-role-group/main.go +++ b/cmd/grail-role-group/main.go @@ -3,20 +3,21 @@ // license that can be found in the LICENSE file. // The following enables go generate to generate the doc.go file. -//go:generate go run $GRAIL/go/src/vendor/v.io/x/lib/cmdline/testdata/gendoc.go "--build-cmd=go install" --copyright-notice= . -help +//go:generate go run v.io/x/lib/cmdline/gendoc "--build-cmd=go install" --copyright-notice= . 
-help + package main import ( "net/http" "os" - "google.golang.org/api/groupssettings/v1" - "github.com/grailbio/base/cmd/grail-role-group/googleclient" + "github.com/grailbio/base/cmdutil" _ "github.com/grailbio/base/cmdutil/interactive" "golang.org/x/oauth2" - "google.golang.org/api/admin/directory/v1" + admin "google.golang.org/api/admin/directory/v1" + "google.golang.org/api/groupssettings/v1" "v.io/x/lib/cmdline" ) @@ -31,7 +32,8 @@ const ( ) const domain = "grailbio.com" -const groupSuffix = "-aws-role@grailbio.com" + +var groupSuffix = []string{"-aws-role@grailbio.com", "-ticket@grailbio.com"} var ( browserFlag bool @@ -117,6 +119,16 @@ func newCmdUpdate() *cmdline.Command { return cmd } +// Any return true if any string in list returns true based on the passed comparison method +func Any(vs []string, f func(string) bool) bool { + for _, v := range vs { + if f(v) { + return true + } + } + return false +} + func main() { cmdline.HideGlobalFlagsExcept() cmdline.Main(newCmdRoot()) diff --git a/cmd/grail-role-group/update.go b/cmd/grail-role-group/update.go index 7213a9de..90a85e5b 100644 --- a/cmd/grail-role-group/update.go +++ b/cmd/grail-role-group/update.go @@ -49,7 +49,9 @@ func runUpdate(_ *cmdline.Env, args []string) error { return fmt.Errorf("bad number of arguments, expected 1, got %q", args) } groupName := args[0] - if !strings.HasSuffix(groupName, groupSuffix) { + if !Any(groupSuffix, func(v string) bool { + return strings.HasSuffix(groupName, v) + }) { return fmt.Errorf("bad suffix: the group name %q doesn't end in %q", groupName, groupSuffix) } diff --git a/cmd/grail-role/doc.go b/cmd/grail-role/doc.go index bd5b7177..3ddb9883 100644 --- a/cmd/grail-role/doc.go +++ b/cmd/grail-role/doc.go @@ -42,31 +42,57 @@ The global flags are: log level for V logs -v23.credentials= directory to use for storing security credentials - -v23.i18n-catalogue= - 18n catalogue files to load, comma separated - 
-v23.namespace.root=[/(v23.grail.com:internal:mounttabled)@ns.v23.grail.com:8101] + -v23.namespace.root=[/(v23.grail.com:internal:mounttabled)@ns-0.v23.grail.com:8101,/(v23.grail.com:internal:mounttabled)@ns-1.v23.grail.com:8101,/(v23.grail.com:internal:mounttabled)@ns-2.v23.grail.com:8101] local namespace root; can be repeated to provided multiple roots - -v23.permissions.file=map[] + -v23.permissions.file= specify a perms file as : -v23.permissions.literal= explicitly specify the runtime perms as a JSON-encoded access.Permissions. - Overrides all --v23.permissions.file flags. + Overrides all --v23.permissions.file flags -v23.proxy= object name of proxy service to use to export services across network boundaries + -v23.proxy.limit=0 + max number of proxies to connect to when the policy is to connect to all + proxies; 0 implies all proxies + -v23.proxy.policy= + policy for choosing from a set of available proxy instances -v23.tcp.address= address to listen on - -v23.tcp.protocol=wsh + -v23.tcp.protocol= protocol to listen with + -v23.virtualized.advertise-private-addresses= + if set the process will also advertise its private addresses + -v23.virtualized.disallow-native-fallback=false + if set, a failure to detect the requested virtualization provider will result + in an error, otherwise, native mode is used + -v23.virtualized.dns.public-name= + if set the process will use the supplied dns name (and port) without + resolution for its entry in the mounttable + -v23.virtualized.docker= + set if the process is running in a docker container and needs to configure + itself differently therein + -v23.virtualized.provider= + the name of the virtualization/cloud provider hosting this process if the + process needs to configure itself differently therein + -v23.virtualized.tcp.public-address= + if set the process will use this address (resolving via dns if appropriate) + for its entry in the mounttable + -v23.virtualized.tcp.public-protocol= + if set the process will use 
this protocol for its entry in the mounttable -v23.vtrace.cache-size=1024 - The number of vtrace traces to store in memory. + The number of vtrace traces to store in memory -v23.vtrace.collect-regexp= Spans and annotations that match this regular expression will trigger trace - collection. + collection -v23.vtrace.dump-on-shutdown=true - If true, dump all stored traces on runtime shutdown. + If true, dump all stored traces on runtime shutdown + -v23.vtrace.enable-aws-xray=false + Enable the use of AWS x-ray integration with vtrace + -v23.vtrace.root-span-name= + Set the name of the root vtrace span created by the runtime at startup -v23.vtrace.sample-rate=0 - Rate (from 0.0 to 1.0) to sample vtrace traces. + Rate (from 0.0 to 1.0) to sample vtrace traces -v23.vtrace.v=0 The verbosity level of the log messages to be captured in traces -vmodule= diff --git a/cmd/grail-role/main.go b/cmd/grail-role/main.go index 9b1def7e..1c74e13e 100644 --- a/cmd/grail-role/main.go +++ b/cmd/grail-role/main.go @@ -1,5 +1,6 @@ // The following enables go generate to generate the doc.go file. -//go:generate go run $GRAIL/go/src/vendor/v.io/x/lib/cmdline/testdata/gendoc.go "--build-cmd=go install" --copyright-notice= . -help +//go:generate go run v.io/x/lib/cmdline/gendoc "--build-cmd=go install" --copyright-notice= . 
-help + package main import ( @@ -9,8 +10,8 @@ import ( "time" _ "github.com/grailbio/base/cmdutil/interactive" - "github.com/grailbio/base/grail/data/v23data" "github.com/grailbio/base/security/ticket" + _ "github.com/grailbio/v23/factories/grail" // Needed to initialize v23 "v.io/v23" "v.io/v23/context" "v.io/v23/security" @@ -19,7 +20,6 @@ import ( "v.io/x/lib/vlog" libsecurity "v.io/x/ref/lib/security" "v.io/x/ref/lib/v23cmd" - _ "v.io/x/ref/runtime/factories/grail" ) const blessingSuffix = "_role" @@ -53,17 +53,6 @@ Example: return cmd } -func decodeBlessings(s string) (security.Blessings, error) { - b, err := base64.URLEncoding.DecodeString(s) - if err != nil { - return security.Blessings{}, err - } - - dec := vom.NewDecoder(bytes.NewBuffer(b)) - var blessings security.Blessings - return blessings, dec.Decode(&blessings) -} - func run(ctx *context.T, env *cmdline.Env, args []string) error { if len(args) != 2 { return fmt.Errorf("Exactly two arguments are required: ") @@ -98,7 +87,7 @@ func run(ctx *context.T, env *cmdline.Env, args []string) error { } client := ticket.TicketServiceClient(ticketPath) - ctx, cancel := context.WithTimeout(roleCtx, timeoutFlag) + _, cancel := context.WithTimeout(roleCtx, timeoutFlag) defer cancel() t, err := client.Get(roleCtx) @@ -110,7 +99,7 @@ func run(ctx *context.T, env *cmdline.Env, args []string) error { vanadiumTicket, ok := t.(ticket.TicketVanadiumTicket) if !ok { - return fmt.Errorf("Not a VanadiumTicket: %#s", t) + return fmt.Errorf("Not a VanadiumTicket: %#v", t) } var blessings security.Blessings @@ -130,16 +119,11 @@ func run(ctx *context.T, env *cmdline.Env, args []string) error { return fmt.Errorf("failed to add blessings to recognized roots: %v", err) } - if err := v23data.InjectPipelineBlessings(ctx); err != nil { - vlog.Error(err) - return fmt.Errorf("failed to add the pipeline roots") - } - fmt.Printf("Public key: %s\n", principal.PublicKey()) fmt.Println("---------------- BlessingStore ----------------") - 
fmt.Printf(principal.BlessingStore().DebugString()) + fmt.Print(principal.BlessingStore().DebugString()) fmt.Println("---------------- BlessingRoots ----------------") - fmt.Printf(principal.Roots().DebugString()) + fmt.Print(principal.Roots().DebugString()) return nil } diff --git a/cmd/grail-ssh/doc.go b/cmd/grail-ssh/doc.go new file mode 100644 index 00000000..3617eb17 --- /dev/null +++ b/cmd/grail-ssh/doc.go @@ -0,0 +1,122 @@ +// This file was auto-generated via go generate. +// DO NOT UPDATE MANUALLY + +/* +Command that simplifies connecting to GRAIL systems using ssh with ssh +certificates for authentication. + +Usage: + ssh [flags] + +The ssh flags are: + -i=/mnt/home/jjc/.ssh/id_rsa + Path to the SSH private key that will be used for the connection + -l= + Username to provide to the remote host. If not provided selects the first + principal defined as part of the ticket definition. + -ssh=ssh + What ssh client to use + +The global flags are: + -alsologtostderr=false + log to standard error as well as files + -block-profile= + filename prefix for block profiles + -block-profile-rate=200 + rate for runtime.SetBlockProfileRate + -cpu-profile= + filename for cpu profile + -heap-profile= + filename prefix for heap profiles + -log_backtrace_at=:0 + when logging hits line file:N, emit a stack trace + -log_dir= + if non-empty, write log files to this directory + -logtostderr=false + log to standard error instead of files + -max_stack_buf_size=4292608 + max size in bytes of the buffer to use for logging stack traces + -metadata= + Displays metadata for the program and exits. + -mutex-profile= + filename prefix for mutex profiles + -mutex-profile-rate=200 + rate for runtime.SetMutexProfileFraction + -pprof= + address for pprof server + -profile-interval-s=0 + If >0, output new profiles at this interval (seconds). 
If <=0, profiles are + written only when Write() is called + -stderrthreshold=2 + logs at or above this threshold go to stderr + -thread-create-profile= + filename prefix for thread create profiles + -time=false + Dump timing information to stderr before exiting the program. + -v=0 + log level for V logs + -v23.credentials= + directory to use for storing security credentials + -v23.namespace.root=[/(v23.grail.com:internal:mounttabled)@ns-0.v23.grail.com:8101,/(v23.grail.com:internal:mounttabled)@ns-1.v23.grail.com:8101,/(v23.grail.com:internal:mounttabled)@ns-2.v23.grail.com:8101] + local namespace root; can be repeated to provided multiple roots + -v23.permissions.file= + specify a perms file as : + -v23.permissions.literal= + explicitly specify the runtime perms as a JSON-encoded access.Permissions. + Overrides all --v23.permissions.file flags + -v23.proxy= + object name of proxy service to use to export services across network + boundaries + -v23.proxy.limit=0 + max number of proxies to connect to when the policy is to connect to all + proxies; 0 implies all proxies + -v23.proxy.policy= + policy for choosing from a set of available proxy instances + -v23.tcp.address= + address to listen on + -v23.tcp.protocol= + protocol to listen with + -v23.virtualized.advertise-private-addresses= + if set the process will also advertise its private addresses + -v23.virtualized.disallow-native-fallback=false + if set, a failure to detect the requested virtualization provider will result + in an error, otherwise, native mode is used + -v23.virtualized.dns.public-name= + if set the process will use the supplied dns name (and port) without + resolution for its entry in the mounttable + -v23.virtualized.docker= + set if the process is running in a docker container and needs to configure + itself differently therein + -v23.virtualized.provider= + the name of the virtualization/cloud provider hosting this process if the + process needs to configure itself differently therein + 
-v23.virtualized.tcp.public-address= + if set the process will use this address (resolving via dns if appropriate) + for its entry in the mounttable + -v23.virtualized.tcp.public-protocol= + if set the process will use this protocol for its entry in the mounttable + -v23.vtrace.cache-size=1024 + The number of vtrace traces to store in memory + -v23.vtrace.collect-regexp= + Spans and annotations that match this regular expression will trigger trace + collection + -v23.vtrace.dump-on-shutdown=true + If true, dump all stored traces on runtime shutdown + -v23.vtrace.enable-aws-xray=false + Enable the use of AWS x-ray integration with vtrace + -v23.vtrace.root-span-name= + Set the name of the root vtrace span created by the runtime at startup + -v23.vtrace.sample-rate=0 + Rate (from 0.0 to 1.0) to sample vtrace traces + -v23.vtrace.v=0 + The verbosity level of the log messages to be captured in traces + -vmodule= + comma-separated list of globpattern=N settings for filename-filtered logging + (without the .go suffix). E.g. foo/bar/baz.go is matched by patterns baz or + *az or b* but not by bar/baz or baz.go or az or b.* + -vpath= + comma-separated list of regexppattern=N settings for file pathname-filtered + logging (without the .go suffix). E.g. foo/bar/baz.go is matched by patterns + foo/bar/baz or fo.*az or oo/ba or b.z but not by foo/bar/baz.go or fo*az +*/ +package main diff --git a/cmd/grail-ssh/main.go b/cmd/grail-ssh/main.go new file mode 100644 index 00000000..1096463f --- /dev/null +++ b/cmd/grail-ssh/main.go @@ -0,0 +1,60 @@ +// The following enables go generate to generate the doc.go file. +//go:generate go run v.io/x/lib/cmdline/gendoc "--build-cmd=go install" --copyright-notice= . 
-help + +package main + +import ( + "flag" + "io" + "os" + + "github.com/grailbio/base/cmdutil" + _ "github.com/grailbio/base/cmdutil/interactive" // print output to console + "github.com/grailbio/base/vcontext" + _ "github.com/grailbio/v23/factories/grail" + "v.io/v23/context" + "v.io/x/lib/cmdline" +) + +var ( + sshFlag string + idRsaFlag string + userFlag string +) + +func newCmdRoot() *cmdline.Command { + root := &cmdline.Command{ + Runner: cmdutil.RunnerFuncWithAccessCheck(vcontext.Background, runner(runSsh)), + Name: "ssh", + Short: "ssh to a VM", + ArgsName: "", + Long: ` +Command that simplifies connecting to GRAIL systems using ssh with ssh certificates for authentication. +`, + LookPath: false, + } + root.Flags.StringVar(&sshFlag, "ssh", "ssh", "What ssh client to use") + root.Flags.StringVar(&idRsaFlag, "i", os.ExpandEnv("${HOME}/.ssh/id_rsa"), "Path to the SSH private key that will be used for the connection") + root.Flags.StringVar(&userFlag, "l", "", "Username to provide to the remote host. If not provided selects the first principal defined as part of the ticket definition.") + + return root +} + +type runnerFunc func(*context.T, io.Writer, *cmdline.Env, []string) error +type v23RunnerFunc func(*context.T, *cmdline.Env, []string) error + +// runner wraps a runnerFunc to produce a cmdline.RunnerFunc. +func runner(f runnerFunc) v23RunnerFunc { + return func(ctx *context.T, env *cmdline.Env, args []string) error { + // No special actions needed that applies to all runners. + return f(ctx, os.Stdout, env, args) + } +} + +func main() { + // We suppress 'alsologtosterr' because this is a user tool. 
+ _ = flag.Set("alsologtostderr", "false") + cmdRoot := newCmdRoot() + cmdline.HideGlobalFlagsExcept() + cmdline.Main(cmdRoot) +} diff --git a/cmd/grail-ssh/ssh.go b/cmd/grail-ssh/ssh.go new file mode 100644 index 00000000..650db9f3 --- /dev/null +++ b/cmd/grail-ssh/ssh.go @@ -0,0 +1,254 @@ +package main + +import ( + "fmt" + "io" + "io/ioutil" + "os" + "os/exec" + "regexp" + "strings" + "syscall" + "time" + + "github.com/grailbio/base/security/ticket" + sshLib "golang.org/x/crypto/ssh" + terminal "golang.org/x/crypto/ssh/terminal" + "v.io/v23/context" + "v.io/x/lib/cmdline" + "v.io/x/lib/vlog" +) + +const ( + timeout = 10 * time.Second +) + +func runSsh(ctx *context.T, out io.Writer, env *cmdline.Env, args []string) error { + if len(args) == 0 { + return env.UsageErrorf("At least one argument () is required.") + } + + ticketPath := args[0] + args = args[1:] // remove the ticket from the arguments + + client := ticket.TicketServiceClient(ticketPath) + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + // Read in the private key + privateKey, err := ioutil.ReadFile(idRsaFlag) + if err != nil { + return fmt.Errorf("Failed to read private key - %v", err) + } + + // Load the private key + privateSigner, err := sshLib.ParsePrivateKey(privateKey) + if err != nil { + switch err.(type) { + case *sshLib.PassphraseMissingError: + // try to load the key with a passphrase + fmt.Print("Enter SSH Key Passphrase: ") + bytePassword, _ := terminal.ReadPassword(int(syscall.Stdin)) + privateSigner, err = sshLib.ParsePrivateKeyWithPassphrase(privateKey, bytePassword) + if err != nil { + return fmt.Errorf("Failed to read private key - %v", err) + } + fmt.Println("\nSSH Key decoded") + default: + return fmt.Errorf("Failed to parse private key - %v", err) + } + } + + if err != nil { + return fmt.Errorf("Failed to parse private key - %v", err) + } + + var parameters = []ticket.Parameter{ + ticket.Parameter{ + Key: "PublicKey", + Value: 
string(sshLib.MarshalAuthorizedKey(privateSigner.PublicKey())), + }, + } + + t, err := client.GetWithParameters(ctx, parameters) + if err != nil { + return fmt.Errorf("Failed to communicate with the ticket-server - %v", err) + } + + switch t.Index() { + case (ticket.TicketSshCertificateTicket{}).Index(): + { + creds := t.(ticket.TicketSshCertificateTicket).Value.Credentials + // pull the public certificate out and write to the id_rsa cert path location + if err = ioutil.WriteFile(idRsaFlag+"-cert.pub", []byte(creds.Cert), 0644); err != nil { + return fmt.Errorf("Failed to write ssh public key "+idRsaFlag+"-cert.pub"+" - %v", err) + } + } + default: + { + return fmt.Errorf("Provided ticket is not a SSHCertificateTicket") + } + } + + var computeInstances []ticket.ComputeInstance = t.(ticket.TicketSshCertificateTicket).Value.ComputeInstances + var username = t.(ticket.TicketSshCertificateTicket).Value.Username + // Use the environment provided username if specified + if userFlag != "" { + username = userFlag + } + // Throw an error if no username is set + if username == "" { + vlog.Errorf("Username was not provided in ticket or via command line") + // TODO: return the exit code from the cmd. + os.Exit(1) + } + + var host string + instanceMatch := regexp.MustCompile("^i-[a-zA-Z0-9]+$") + // Not the best regex (e.g. doesn't match IPV6) to use here ... better regexs are available at + // https://stackoverflow.com/questions/106179/regular-expression-to-match-dns-hostname-or-ip-address + hostIpMatch := regexp.MustCompile("^([a-zA-Z0-9]+\\.)+[a-zA-Z0-9]+$") + dnsMatch := regexp.MustCompile("^(([a-zA-Z]|[a-zA-Z][a-zA-Z0-9\\-]*[a-zA-Z0-9])\\.)*([A-Za-z]|[A-Za-z][A-Za-z0-9\\-]*[A-Za-z0-9])$") + stopMatch := regexp.MustCompile("^--$") + + // Loop through the arguments provided to the CLI tool - and try to match to a hostname or an instanceID. + // Stop processing if -- is found. + // host is the last match found. 
+ for i, arg := range args { + match := instanceMatch.MatchString(arg) + if err != nil { + return fmt.Errorf("Failed to check if input %s matched an instanceId - %v", arg, err) + } + + // Find matching instanceId in list + if match { + // Remove the matched element from the list + args = append(args[:i], args[i+1:]...) + for _, instance := range computeInstances { + if instance.InstanceId == arg { + vlog.Errorf("Matched InstanceID %s - %s", instance.InstanceId, instance.PublicIp) + fmt.Printf("Matched InstanceID %s - %s \n", instance.InstanceId, instance.PublicIp) + host = instance.PublicIp + break + } + } + if host == "" { + return fmt.Errorf("Failed to find a match for InstanceId provided %s", arg) + } + break + } + + // // check for a dns name to stop processing + match = dnsMatch.MatchString(arg) + if err != nil { + return fmt.Errorf("Failed to check if input %s matched a DNS name' - %v", arg, err) + } + if match { + host = arg + args = append(args[:i], args[i+1:]...) + fmt.Printf("Matched DNS %s \n", host) + break + } + + // check for a dns/ip host name to stop processing + match = hostIpMatch.MatchString(arg) + if err != nil { + return fmt.Errorf("Failed to check if input %s matched an '^[a-zA-Z0-9]+\\.[a-zA-Z0-9]+' - %v", arg, err) + } + if match { + host = arg + args = append(args[:i], args[i+1:]...) 
+ fmt.Printf("Matched Host IP %s \n", host) + break + } + + // check for a -- to stop processing + match = stopMatch.MatchString(arg) + if err != nil { + return fmt.Errorf("Failed to check if input %s matched an '--' - %v", arg, err) + } + if match { + break + } + } + + // If no host has been found present a list + if host == "" { + fmt.Printf("No host or InstanceId provided - please select from list provided by the ticket") + // prompt for which instance to connect too + for index, instance := range computeInstances { + fmt.Printf("[%d] %s:%s - %s\n", index, instance.InstanceId, getTagValueFromKey(instance, "Name"), instance.PublicIp) + } + var instanceSelection int = -1 // initialize to negative value + fmt.Printf("Enter number for corresponding system to connect to?") + if _, err := fmt.Scanf("%d", &instanceSelection); err != nil { + return err + } + + if instanceSelection < 0 || instanceSelection > len(computeInstances) { + return fmt.Errorf("Selected index (%d) was not in the list", instanceSelection) + } + if computeInstances[instanceSelection].PublicIp != "" { + host = computeInstances[instanceSelection].PublicIp + } else { + host = computeInstances[instanceSelection].PrivateIp + } + } + + if host == "" { + return fmt.Errorf("Host selection failed - please provide an ip, DNS name, or select host from list with no input") + } + + var sshArgs = []string{ + // Forward the ssh agent. + "-A", + // Forward the X11 connections. + "-X", + // Don't check the identity of the remote host. + "-o", "StrictHostKeyChecking no", + // Don't store the identity of the remote host. 
+ "-o", "UserKnownHostsFile /dev/null", + // Pass the private key to the ssh command + "-i", idRsaFlag, + } + + // When using MOSH, SSH connection commands need to be passed like + // $ mosh --ssh="ssh -i ./identity" username@host + if sshFlag == "mosh" { + var moshSshArg = strings.Join(sshArgs, " ") + sshArgs = []string{ + "--ssh", moshSshArg, + } + } + + sshArgs = append(sshArgs, + username+"@"+host, + ) + + sshArgs = append(sshArgs, args...) + + vlog.Infof("exec: %q %q", sshFlag, sshArgs) + cmd := exec.Command(sshFlag, sshArgs...) + cmd.Stdin = os.Stdin + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + if err := cmd.Run(); err != nil { + vlog.Errorf("ssh error: %s", err) + // TODO: return the exit code from the cmd. + os.Exit(1) + } + + return nil +} + +// Return the key value from the list of Tag Parameters +func getTagValueFromKey(instance ticket.ComputeInstance, key string) string { + for _, param := range instance.Tags { + if param.Key == key { + return param.Value + } + } + + // Throwing a NoSuchKey value is overkill for cases where tag is not added + return "" +} diff --git a/cmd/grail-ticket/doc.go b/cmd/grail-ticket/doc.go index ea8f505e..153a3c95 100644 --- a/cmd/grail-ticket/doc.go +++ b/cmd/grail-ticket/doc.go @@ -7,13 +7,15 @@ identified using a Vanadium name. Examples: - grail ticket ticket/reflow/gdc/aws - grail ticket /127.0.0.1:8000/reflow/gdc/aws + grail-ticket tickets/eng/dev/aws + grail-ticket /127.0.0.1:8000/eng/dev/aws Note that tickets can be enumerated using the 'namespace' Vanadium tool: + namespace glob tickets/... + namespace glob tickets/eng/... namespace glob /127.0.0.1:8000/... - namespace glob /127.0.0.1:8000/reflow/... + namespace glob /127.0.0.1:8000/eng/... 
Usage: grail-ticket [flags] @@ -27,7 +29,11 @@ The grail-ticket flags are: Force a JSON output even for the tickets that have special handling -key= PEM file to store the private key for a TLS-based ticket - -timeout=10s + -list=false + List accessible tickets + -rationale= + Rationale for accessing ticket + -timeout=1m30s Timeout for the requests to the ticket-server The global flags are: @@ -51,31 +57,57 @@ The global flags are: log level for V logs -v23.credentials= directory to use for storing security credentials - -v23.i18n-catalogue= - 18n catalogue files to load, comma separated - -v23.namespace.root=[/(v23.grail.com:internal:mounttabled)@ns.v23.grail.com:8101] + -v23.namespace.root=[/(v23.grail.com:internal:mounttabled)@ns-0.v23.grail.com:8101,/(v23.grail.com:internal:mounttabled)@ns-1.v23.grail.com:8101,/(v23.grail.com:internal:mounttabled)@ns-2.v23.grail.com:8101] local namespace root; can be repeated to provided multiple roots - -v23.permissions.file=map[] + -v23.permissions.file= specify a perms file as : -v23.permissions.literal= explicitly specify the runtime perms as a JSON-encoded access.Permissions. - Overrides all --v23.permissions.file flags. 
+ Overrides all --v23.permissions.file flags -v23.proxy= object name of proxy service to use to export services across network boundaries + -v23.proxy.limit=0 + max number of proxies to connect to when the policy is to connect to all + proxies; 0 implies all proxies + -v23.proxy.policy= + policy for choosing from a set of available proxy instances -v23.tcp.address= address to listen on - -v23.tcp.protocol=wsh + -v23.tcp.protocol= protocol to listen with + -v23.virtualized.advertise-private-addresses= + if set the process will also advertise its private addresses + -v23.virtualized.disallow-native-fallback=false + if set, a failure to detect the requested virtualization provider will result + in an error, otherwise, native mode is used + -v23.virtualized.dns.public-name= + if set the process will use the supplied dns name (and port) without + resolution for its entry in the mounttable + -v23.virtualized.docker= + set if the process is running in a docker container and needs to configure + itself differently therein + -v23.virtualized.provider= + the name of the virtualization/cloud provider hosting this process if the + process needs to configure itself differently therein + -v23.virtualized.tcp.public-address= + if set the process will use this address (resolving via dns if appropriate) + for its entry in the mounttable + -v23.virtualized.tcp.public-protocol= + if set the process will use this protocol for its entry in the mounttable -v23.vtrace.cache-size=1024 - The number of vtrace traces to store in memory. + The number of vtrace traces to store in memory -v23.vtrace.collect-regexp= Spans and annotations that match this regular expression will trigger trace - collection. + collection -v23.vtrace.dump-on-shutdown=true - If true, dump all stored traces on runtime shutdown. 
+ If true, dump all stored traces on runtime shutdown + -v23.vtrace.enable-aws-xray=false + Enable the use of AWS x-ray integration with vtrace + -v23.vtrace.root-span-name= + Set the name of the root vtrace span created by the runtime at startup -v23.vtrace.sample-rate=0 - Rate (from 0.0 to 1.0) to sample vtrace traces. + Rate (from 0.0 to 1.0) to sample vtrace traces -v23.vtrace.v=0 The verbosity level of the log messages to be captured in traces -vmodule= diff --git a/cmd/grail-ticket/main.go b/cmd/grail-ticket/main.go index 14affbbd..51ffcba4 100644 --- a/cmd/grail-ticket/main.go +++ b/cmd/grail-ticket/main.go @@ -3,22 +3,27 @@ // license that can be found in the LICENSE file. // The following enables go generate to generate the doc.go file. -//go:generate go run $GRAIL/go/src/vendor/v.io/x/lib/cmdline/testdata/gendoc.go "--build-cmd=go install" --copyright-notice= . -help +//go:generate go run v.io/x/lib/cmdline/gendoc "--build-cmd=go install" --copyright-notice= . -help + package main import ( "fmt" "io/ioutil" + "log" + "os" + "os/exec" + "syscall" "time" _ "github.com/grailbio/base/cmdutil/interactive" "github.com/grailbio/base/security/ticket" + _ "github.com/grailbio/v23/factories/grail" "v.io/v23/context" "v.io/v23/vdl" "v.io/x/lib/cmdline" "v.io/x/ref/lib/v23cmd" "v.io/x/ref/lib/vdl/codegen/json" - _ "v.io/x/ref/runtime/factories/grail" ) var ( @@ -26,7 +31,9 @@ var ( authorityCertFlag string certFlag string keyFlag string + rationaleFlag string jsonOnlyFlag bool + listFlag bool ) func newCmdRoot() *cmdline.Command { @@ -40,22 +47,26 @@ identified using a Vanadium name. Examples: - grail ticket ticket/reflow/gdc/aws - grail ticket /127.0.0.1:8000/reflow/gdc/aws + grail-ticket tickets/eng/dev/aws + grail-ticket /127.0.0.1:8000/eng/dev/aws Note that tickets can be enumerated using the 'namespace' Vanadium tool: + namespace glob tickets/... + namespace glob tickets/eng/... namespace glob /127.0.0.1:8000/... - namespace glob /127.0.0.1:8000/reflow/... 
+ namespace glob /127.0.0.1:8000/eng/... `, ArgsName: "", LookPath: false, } - root.Flags.DurationVar(&timeoutFlag, "timeout", 10*time.Second, "Timeout for the requests to the ticket-server") + root.Flags.DurationVar(&timeoutFlag, "timeout", 90*time.Second, "Timeout for the requests to the ticket-server") root.Flags.BoolVar(&jsonOnlyFlag, "json-only", false, "Force a JSON output even for the tickets that have special handling") + root.Flags.BoolVar(&listFlag, "list", false, "List accessible tickets") root.Flags.StringVar(&authorityCertFlag, "authority-cert", "", "PEM file to store the CA cert for a TLS-based ticket") root.Flags.StringVar(&certFlag, "cert", "", "PEM file to store the cert for a TLS-based ticket") root.Flags.StringVar(&keyFlag, "key", "", "PEM file to store the private key for a TLS-based ticket") + root.Flags.StringVar(&rationaleFlag, "rationale", "", "Rationale for accessing ticket") return root } @@ -70,23 +81,43 @@ func saveCredentials(creds ticket.TlsCredentials) error { } func run(ctx *context.T, env *cmdline.Env, args []string) error { - if len(args) != 1 { - return env.UsageErrorf("Exactly one arguments () is required.") + if len(args) == 0 { + return env.UsageErrorf("At least one arguments () is required.") } ticketPath := args[0] + if listFlag { + fmt.Println("Listing all accessible tickets (this may take up to 90 seconds)...") + client := ticket.ListServiceClient(ticketPath + "/list") + tickets, err := client.List(ctx) + if err != nil { + return err + } + for _, t := range tickets { + fmt.Println(t) + } + return nil + } + client := ticket.TicketServiceClient(ticketPath) ctx, cancel := context.WithTimeout(ctx, timeoutFlag) defer cancel() - t, err := client.Get(ctx) + var t ticket.Ticket + var err error + if rationaleFlag != "" { + t, err = client.GetWithArgs(ctx, map[string]string{ + ticket.ControlRationale.String(): rationaleFlag, + }) + } else { + t, err = client.Get(ctx) + } if err != nil { return err } - jsonOutput := 
json.Const(vdl.ValueOf(t.Interface()), "", nil) - if jsonOnlyFlag { + jsonOutput := json.Const(vdl.ValueOf(t.Interface()), "", nil) fmt.Println(jsonOutput) return nil } @@ -115,6 +146,34 @@ func run(ctx *context.T, env *cmdline.Env, args []string) error { } } + if t.Index() == (ticket.TicketAwsTicket{}).Index() && len(args) > 1 { + creds := t.(ticket.TicketAwsTicket).Value.AwsCredentials + awsEnv := map[string]string{ + "AWS_ACCESS_KEY_ID": creds.AccessKeyId, + "AWS_SECRET_ACCESS_KEY": creds.SecretAccessKey, + "AWS_SESSION_TOKEN": creds.SessionToken, + } + + args = args[1:] + path, err := exec.LookPath(args[0]) + if err != nil { + log.Fatal(err) + } + for k := range awsEnv { + os.Unsetenv(k) + } + env := os.Environ() + for k, v := range awsEnv { + env = append(env, fmt.Sprintf("%s=%s", k, v)) + } + + // run runs a program with certain arguments and certain environment + // variables. This function never returns. The arguments list contains + // the name of the program. + return syscall.Exec(path, args, env) + } + + jsonOutput := json.Const(vdl.ValueOf(t.Interface()), "", nil) fmt.Println(jsonOutput) return nil } diff --git a/cmd/grail-ticket/terraform.tfvars b/cmd/grail-ticket/terraform.tfvars new file mode 100644 index 00000000..9f772ab5 --- /dev/null +++ b/cmd/grail-ticket/terraform.tfvars @@ -0,0 +1 @@ +image_tag="639575724980.dkr.ecr.us-west-2.amazonaws.com/ror-dev-grail-ticket:2020-06-11.bbentson-103649.ccf89dd4c6b-dirty-3" diff --git a/cmd/oom/main.go b/cmd/oom/main.go new file mode 100644 index 00000000..b3448c6c --- /dev/null +++ b/cmd/oom/main.go @@ -0,0 +1,35 @@ +// Copyright 2019 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. 
+ +package main + +import ( + "flag" + "fmt" + "log" + "os" + + "github.com/grailbio/base/stress/oom" +) + +func main() { + log.SetFlags(0) + log.SetPrefix("oom: ") + size := flag.Int("size", 0, "amount of memory to allocate; automatically determined if zero") + + flag.Usage = func() { + fmt.Fprintf(os.Stderr, `usage: oom [-size N] + +OOM attempts to OOM the system by allocating up +N bytes of memory. If size is not specified, oom +automatically determines how much memory to allocate. +`) + os.Exit(2) + } + flag.Parse() + if *size != 0 { + oom.Do(*size) + } + oom.Try() +} diff --git a/cmd/ticket-server/config/load.go b/cmd/ticket-server/config/load.go index 60699cdf..ec2e0b4c 100644 --- a/cmd/ticket-server/config/load.go +++ b/cmd/ticket-server/config/load.go @@ -22,9 +22,10 @@ import ( const configSuffix = ".vdlconfig" type ticketConfig struct { - Kind string - Ticket ticket.Ticket - Perms access.Permissions + Kind string + Ticket ticket.Ticket + Perms access.Permissions + Controls map[ticket.Control]bool } type Config map[string]ticketConfig @@ -53,7 +54,8 @@ func Load(dir string) (map[string]ticketConfig, error) { } errors := vdlutil.Errors{} - packages := build.TransitivePackagesForConfig(path, f, build.Opts{}, &errors) + warnings := vdlutil.Errors{} + packages := build.TransitivePackagesForConfig(path, f, build.Opts{}, &errors, &warnings) env := compile.NewEnv(100) for _, p := range packages { vlog.VI(1).Infof("building package: %+v", p.Path) @@ -76,9 +78,14 @@ func Load(dir string) (map[string]ticketConfig, error) { if t.Permissions != nil { perms = t.Permissions } + // TODO(noah): Remove this check after PagerDutyId and TicketId controls are supported. 
+ if _, ok := t.Controls[ticket.ControlRationale]; len(t.Controls) != 0 && (len(t.Controls) != 1 || !ok) { + return fmt.Errorf("only rationale control is supported: %+v", t.Controls) + } all[naming.Join(prefix, name)] = ticketConfig{ - Ticket: t.Ticket, - Perms: perms, + Ticket: t.Ticket, + Perms: perms, + Controls: t.Controls, } } diff --git a/cmd/ticket-server/doc.go b/cmd/ticket-server/doc.go index 246d2830..eab3cddb 100644 --- a/cmd/ticket-server/doc.go +++ b/cmd/ticket-server/doc.go @@ -1,7 +1,3 @@ -// Copyright 2018 GRAIL, Inc. All rights reserved. -// Use of this source code is governed by the Apache-2.0 -// license that can be found in the LICENSE file. - // This file was auto-generated via go generate. // DO NOT UPDATE MANUALLY @@ -17,6 +13,12 @@ Usage: ticket-server [flags] The ticket-server flags are: + -aws-account-ids= + Commma-separated list of AWS account IDs used to populate allow-list of k8s + clusters. + -aws-regions=us-west-2 + Commma-separated list of AWS regions used to populate allow-list of k8s + clusters. -config-dir= Directory with tickets in VDL format. Must be provided. -danger-danger-danger-ec2-disable-address-check=false @@ -28,6 +30,8 @@ The ticket-server flags are: -danger-danger-danger-ec2-disable-uniqueness-check=false Disable the uniqueness check for the EC2-based blessings requests. Only useful for local tests. + -dry-run=false + Don't run, just check the configs. -ec2-blesser-role= What role to use for the blesser/ec2 endpoint. The role needs to exist in all the accounts. @@ -36,10 +40,22 @@ The ticket-server flags are: requests. -ec2-expiration=8760h0m0s Expiration caveat for the EC2-based blessings. + -google-admin=admin@grailbio.com + Google Admin that can read all group memberships - NOTE: all groups will need + to match the admin user's domain. -google-expiration=168h0m0s Expiration caveat for the Google-based blessings. + -google-user-domain=grailbio.com + Comma-separated list of email domains used for validating users. 
+ -k8s-blesser-role=ticket-server + What role to use to lookup EKS cluster information on all authorized + accounts. The role needs to exist in all the accounts. + -k8s-expiration=8760h0m0s + Expiration caveat for the K8s-based blessings. -name= Name to mount the server under. If empty, don't mount. + -region=us-west-2 + AWS region to use for cached AWS session. -service-account= JSON file with a Google service account credentials. @@ -64,31 +80,57 @@ The global flags are: log level for V logs -v23.credentials= directory to use for storing security credentials - -v23.i18n-catalogue= - 18n catalogue files to load, comma separated - -v23.namespace.root=[/(v23.grail.com:internal:mounttabled)@ns.v23.grail.com:8101] + -v23.namespace.root=[/(v23.grail.com:internal:mounttabled)@ns-0.v23.grail.com:8101,/(v23.grail.com:internal:mounttabled)@ns-1.v23.grail.com:8101,/(v23.grail.com:internal:mounttabled)@ns-2.v23.grail.com:8101] local namespace root; can be repeated to provided multiple roots - -v23.permissions.file=map[] + -v23.permissions.file= specify a perms file as : -v23.permissions.literal= explicitly specify the runtime perms as a JSON-encoded access.Permissions. - Overrides all --v23.permissions.file flags. 
+ Overrides all --v23.permissions.file flags -v23.proxy= object name of proxy service to use to export services across network boundaries + -v23.proxy.limit=0 + max number of proxies to connect to when the policy is to connect to all + proxies; 0 implies all proxies + -v23.proxy.policy= + policy for choosing from a set of available proxy instances -v23.tcp.address= address to listen on - -v23.tcp.protocol=wsh + -v23.tcp.protocol= protocol to listen with + -v23.virtualized.advertise-private-addresses= + if set the process will also advertise its private addresses + -v23.virtualized.disallow-native-fallback=false + if set, a failure to detect the requested virtualization provider will result + in an error, otherwise, native mode is used + -v23.virtualized.dns.public-name= + if set the process will use the supplied dns name (and port) without + resolution for its entry in the mounttable + -v23.virtualized.docker= + set if the process is running in a docker container and needs to configure + itself differently therein + -v23.virtualized.provider= + the name of the virtualization/cloud provider hosting this process if the + process needs to configure itself differently therein + -v23.virtualized.tcp.public-address= + if set the process will use this address (resolving via dns if appropriate) + for its entry in the mounttable + -v23.virtualized.tcp.public-protocol= + if set the process will use this protocol for its entry in the mounttable -v23.vtrace.cache-size=1024 - The number of vtrace traces to store in memory. + The number of vtrace traces to store in memory -v23.vtrace.collect-regexp= Spans and annotations that match this regular expression will trigger trace - collection. + collection -v23.vtrace.dump-on-shutdown=true - If true, dump all stored traces on runtime shutdown. 
+ If true, dump all stored traces on runtime shutdown + -v23.vtrace.enable-aws-xray=false + Enable the use of AWS x-ray integration with vtrace + -v23.vtrace.root-span-name= + Set the name of the root vtrace span created by the runtime at startup -v23.vtrace.sample-rate=0 - Rate (from 0.0 to 1.0) to sample vtrace traces. + Rate (from 0.0 to 1.0) to sample vtrace traces -v23.vtrace.v=0 The verbosity level of the log messages to be captured in traces -vmodule= diff --git a/cmd/ticket-server/ec2blesser.go b/cmd/ticket-server/ec2blesser.go index d1af2cfd..96a1cce7 100644 --- a/cmd/ticket-server/ec2blesser.go +++ b/cmd/ticket-server/ec2blesser.go @@ -7,6 +7,7 @@ package main import ( "fmt" "net" + "os" "strings" "time" @@ -18,10 +19,10 @@ import ( "github.com/aws/aws-sdk-go/service/dynamodb" "github.com/aws/aws-sdk-go/service/ec2" "github.com/grailbio/base/cloud/ec2util" + "github.com/grailbio/base/common/log" "v.io/v23/context" "v.io/v23/rpc" "v.io/v23/security" - "v.io/x/lib/vlog" ) const pendingTimeWindow = time.Hour @@ -37,7 +38,7 @@ const pendingTimeWindow = time.Hour // IdentityDocument: (string) JSON of the IdentityDocument from the request // DescribeInstance: (string) JSON response for the DescribeInstance call // Timestamp: (string) Timestamp in RFC3339Nano when the record was created -func setupEc2Blesser(s *session.Session, table string) { +func setupEc2Blesser(ctx *context.T, s *session.Session, table string) { if table == "" { return } @@ -48,13 +49,14 @@ func setupEc2Blesser(s *session.Session, table string) { }) if err == nil { - vlog.Infof("DynamoDB table already exists:\n%+v", out) + log.Error(ctx, "DynamoDB table already exists", "table", out) return } want := dynamodb.ErrCodeResourceNotFoundException if aerr, ok := err.(awserr.Error); !ok || aerr.Code() != want { - vlog.Fatal("unexpected error: got %+v, want %+v", err, want) + log.Error(ctx, "unexpected DynamoDB error", "got", err, "want", want) + os.Exit(255) } _, err = 
client.CreateTable(&dynamodb.CreateTableInput{ @@ -77,14 +79,16 @@ func setupEc2Blesser(s *session.Session, table string) { }, }) if err != nil { - vlog.Fatal(err) + log.Error(ctx, err.Error()) + os.Exit(255) } - vlog.Infof("%q DynamoDB table was created", table) + log.Debug(ctx, "created DynamoDB table", "table", table) // TODO(razvanm): wait for the table to reach ACTIVE state? // TODO(razvanm): enable the auto scaling? } type ec2Blesser struct { + ctx context.T expirationInterval time.Duration role string table string @@ -92,8 +96,9 @@ type ec2Blesser struct { } func newEc2Blesser(ctx *context.T, s *session.Session, expiration time.Duration, role string, table string) *ec2Blesser { - setupEc2Blesser(s, ec2DynamoDBTableFlag) + setupEc2Blesser(ctx, s, ec2DynamoDBTableFlag) return &ec2Blesser{ + ctx: *ctx, expirationInterval: expiration, role: role, table: table, @@ -101,7 +106,7 @@ func newEc2Blesser(ctx *context.T, s *session.Session, expiration time.Duration, } } -func (blesser *ec2Blesser) checkUniqueness(doc *ec2util.IdentityDocument, remoteAddr string, jsonDoc string, jsonInstance string) error { +func (blesser *ec2Blesser) checkUniqueness(ctx *context.T, doc *ec2util.IdentityDocument, remoteAddr string, jsonDoc string, jsonInstance string) error { if blesser.table == "" { return nil } @@ -110,7 +115,7 @@ func (blesser *ec2Blesser) checkUniqueness(doc *ec2util.IdentityDocument, remote return err } key := strings.Join([]string{doc.AccountID, doc.Region, doc.InstanceID, ipAddr}, "/") - vlog.Infof("DynamoDB key(%s): %q", remoteAddr, key) + log.Debug(ctx, "DynamoDB info", "key", key, "remoteAddr", remoteAddr) cond := aws.String("attribute_not_exists(ID)") if ec2DisableUniquenessCheckFlag { cond = nil @@ -140,18 +145,17 @@ func (blesser *ec2Blesser) BlessEc2(ctx *context.T, call rpc.ServerCall, pkcs7b6 var empty security.Blessings remoteAddress := call.RemoteAddr().String() - vlog.Infof("remote endpoint: %+v", call.RemoteEndpoint().Addr()) - vlog.Infof("pkcs7(%s): 
%d bytes", remoteAddress, len(pkcs7b64)) doc, jsonDoc, err := ec2util.ParseAndVerifyIdentityDocument(pkcs7b64) - vlog.Infof("doc(%s): %+v", remoteAddress, doc) + log.Info(ctx, "bless EC2 request", "remoteAddr", remoteAddress, "remoteEndpoint", call.RemoteEndpoint().Addr(), + "pkcs7b64Bytes", len(pkcs7b64), "doc", doc) if err != nil { - vlog.Infof("error(%s): %+v", remoteAddress, err) + log.Error(ctx, "Error parsing and verifying identity document.", "err", err) return empty, err } if !ec2DisablePendingTimeCheckFlag { if err := checkPendingTime(doc); err != nil { - vlog.Infof("error(%s): %+v", remoteAddress, err) + log.Error(ctx, err.Error()) return empty, err } } @@ -161,6 +165,7 @@ func (blesser *ec2Blesser) BlessEc2(ctx *context.T, call rpc.ServerCall, pkcs7b6 Retryer: client.DefaultRetryer{ NumMaxRetries: 100, }, + Region: aws.String(doc.Region), } validateRemoteAddr := remoteAddress if ec2DisableAddrCheckFlag { @@ -172,18 +177,18 @@ func (blesser *ec2Blesser) BlessEc2(ctx *context.T, call rpc.ServerCall, pkcs7b6 }) if err != nil { - vlog.Infof("error(%s): %+v", remoteAddress, err) + log.Error(ctx, err.Error()) return empty, err } role, err := ec2util.ValidateInstance(output, *doc, validateRemoteAddr) if err != nil { - vlog.Infof("error(%s): %+v", remoteAddress, err) + log.Error(ctx, err.Error()) return empty, err } - if err := blesser.checkUniqueness(doc, remoteAddress, jsonDoc, output.String()); err != nil { - vlog.Infof("error(%s): %+v", remoteAddress, err) + if err = blesser.checkUniqueness(ctx, doc, remoteAddress, jsonDoc, output.String()); err != nil { + log.Error(ctx, err.Error()) return empty, err } diff --git a/cmd/ticket-server/googleblesser.go b/cmd/ticket-server/googleblesser.go index 5fbc0ee0..25fd0969 100644 --- a/cmd/ticket-server/googleblesser.go +++ b/cmd/ticket-server/googleblesser.go @@ -7,43 +7,43 @@ package main import ( "context" "fmt" - "regexp" "strings" "time" - oidc "github.com/coreos/go-oidc" + "github.com/coreos/go-oidc" + 
"github.com/grailbio/base/common/log" v23context "v.io/v23/context" "v.io/v23/rpc" "v.io/v23/security" - "v.io/x/lib/vlog" ) const ( issuer = "https://accounts.google.com" audience = "27162366543-edih9cqc3t8p5hn9ord1k1n7h4oajfhm.apps.googleusercontent.com" - // TODO(razvanm): add support for 'grail.com'. - hostedDomain = "grailbio.com" - emailSuffix = "@grailbio.com" - extensionPrefix = "google" ) -var extensionRE = regexp.MustCompile(strings.Join([]string{extensionPrefix, fmt.Sprintf("([a-z0-9]+%s)", emailSuffix)}, security.ChainSeparator)) +var ( + hostedDomains []string +) + +func googleBlesserInit(googleUserDomainList []string) { + hostedDomains = googleUserDomainList +} func (c *claims) checkClaims() error { - if c.EmailVerified != true { + if !c.EmailVerified { return fmt.Errorf("ID token doesn't have a verified email") } - if got, want := c.HostedDomain, hostedDomain; got != want { - return fmt.Errorf("ID token has a wrong hosted domain: got %q, want %q", got, want) + if !stringInSlice(hostedDomains, c.HostedDomain) { + return fmt.Errorf("ID token has a wrong hosted domain: got %q, want %q", c.HostedDomain, strings.Join(hostedDomains, ",")) } - if !strings.HasSuffix(c.Email, emailSuffix) { - return fmt.Errorf("ID token does not have the right email suffix (%q): %q", emailSuffix, c.Email) + if !stringInSlice(hostedDomains, emailDomain(c.Email)) { + return fmt.Errorf("ID token does not have a sufix with a authorized email domain (%q): %q", strings.Join(hostedDomains, ","), c.Email) } - return nil } @@ -58,10 +58,12 @@ type googleBlesser struct { expirationInterval time.Duration } -func newGoogleBlesser(expiration time.Duration) *googleBlesser { +func newGoogleBlesser(ctx *v23context.T, expiration time.Duration, domains []string) *googleBlesser { + googleBlesserInit(domains) + provider, err := oidc.NewProvider(context.Background(), issuer) if err != nil { - vlog.Fatal(err) + log.Error(ctx, err.Error()) } return &googleBlesser{ verifier: 
provider.Verifier(&oidc.Config{ClientID: audience}), @@ -71,20 +73,18 @@ func newGoogleBlesser(expiration time.Duration) *googleBlesser { func (blesser *googleBlesser) BlessGoogle(ctx *v23context.T, call rpc.ServerCall, idToken string) (security.Blessings, error) { remoteAddress := call.RemoteEndpoint().Address - vlog.Infof("idtoken(%s): %d bytes", remoteAddress, len(idToken)) - vlog.VI(1).Infof("idtoken(%s): %v", remoteAddress, idToken) + log.Info(ctx, "bless Google request", "remoteAddr", remoteAddress, "idToken", idToken, "idTokenLen", len(idToken)) var empty security.Blessings oidcIDToken, err := blesser.verifier.Verify(ctx, idToken) if err != nil { return empty, err } - vlog.VI(1).Infof("oidcIDToken: %+v", oidcIDToken) var claims claims if err := oidcIDToken.Claims(&claims); err != nil { return empty, nil } - vlog.VI(1).Infof("claims: %+v", claims) + log.Debug(ctx, "", "oidcIDToken", oidcIDToken, "claims", claims) if err := claims.checkClaims(); err != nil { return empty, err diff --git a/cmd/ticket-server/googleblesser_test.go b/cmd/ticket-server/googleblesser_test.go index ba327eaa..5c7b601c 100644 --- a/cmd/ticket-server/googleblesser_test.go +++ b/cmd/ticket-server/googleblesser_test.go @@ -10,15 +10,19 @@ import ( ) func TestCheckClaims(t *testing.T) { + googleBlesserInit([]string{"grailbio.com", "contractors.grail.com"}) + cases := []struct { claims claims errPrefix string }{ {claims{HostedDomain: "grailbio.com", EmailVerified: true, Email: "user@grailbio.com"}, ""}, + {claims{HostedDomain: "grailbio.com", EmailVerified: true, Email: "user@contractors.grail.com"}, ""}, {claims{}, "ID token doesn't have a verified email"}, {claims{EmailVerified: false}, "ID token doesn't have a verified email"}, {claims{EmailVerified: true, Email: "user@grailbio.com"}, "ID token has a wrong hosted domain:"}, - {claims{HostedDomain: "grailbio.com", EmailVerified: true, Email: "user@gmail.com"}, "ID token does not have the right email suffix"}, + {claims{HostedDomain: 
"grailbio.com", EmailVerified: true, Email: "user@gmail.com"}, "ID token does not have a sufix with a authorized email domain"}, + {claims{HostedDomain: "grailbio.com", EmailVerified: true, Email: "user@gmail@.com"}, "ID token does not have a sufix with a authorized email domain"}, } for _, c := range cases { diff --git a/cmd/ticket-server/googlegroups.go b/cmd/ticket-server/googlegroups.go index 51e4aa08..f0ae6171 100644 --- a/cmd/ticket-server/googlegroups.go +++ b/cmd/ticket-server/googlegroups.go @@ -10,19 +10,17 @@ import ( "strings" "time" + "github.com/grailbio/base/common/log" + "github.com/grailbio/base/ttlcache" "golang.org/x/net/context" "golang.org/x/oauth2/jwt" - "google.golang.org/api/admin/directory/v1" - "github.com/grailbio/base/ttlcache" + admin "google.golang.org/api/admin/directory/v1" v23context "v.io/v23/context" "v.io/v23/security" "v.io/v23/security/access" "v.io/v23/vdl" - "v.io/x/lib/vlog" ) -var groupRE = regexp.MustCompile(strings.Join([]string{"googlegroups", fmt.Sprintf("([a-z0-9-]+%s)", emailSuffix)}, security.ChainSeparator)) - type cacheKey struct { user string group string @@ -35,27 +33,49 @@ var cache = ttlcache.New(cacheTTL) // email returns the user email from a Vanadium blessing that was produced via // a BlessGoogle call. 
+var ( + groupRE *regexp.Regexp + userRE *regexp.Regexp + adminLookupDomain string +) + +func googleGroupsInit(ctx *v23context.T, groupLookupName string) { + if hostedDomains == nil || len(hostedDomains) == 0 { + log.Error(ctx, "hostedDomains not initialized.") + panic("hostedDomains not initialized") + } + + // Extract the domain of the admin account to filter users in the same Google Domain + adminLookupDomain = emailDomain(groupLookupName) + groupRE = regexp.MustCompile(strings.Join([]string{"^" + "googlegroups", fmt.Sprintf("([a-z0-9-_+.]+@[a-z0-9-_+.]+)$")}, security.ChainSeparator)) + // NOTE This is a non terminating string, because the user validation can be terminated by the ChainSeparator (`:`) + userRE = regexp.MustCompile(strings.Join([]string{"^" + extensionPrefix, fmt.Sprintf("([a-z0-9-_+.]+@[a-z0-9-_+.]+)")}, security.ChainSeparator)) +} + +//verifyAndExtractEmailFromBlessing returns the email address defined in a v23 principal/blessing // // For example, for 'v23.grail.com:google:razvanm@grailbio.com' the return // string should be 'razvanm@grailbio.com'. -func email(blessing string, prefix string) string { +func verifyAndExtractEmailFromBlessing(blessing string, prefix string) string { if strings.HasPrefix(blessing, prefix) && blessing != prefix { - m := extensionRE.FindStringSubmatch(blessing[len(prefix)+1:]) - if m != nil { + m := userRE.FindStringSubmatch(blessing[len(prefix)+1:]) + if m != nil && stringInSlice(hostedDomains, emailDomain(m[1])) { return m[1] } } return "" } -// group returns the Google Groups name from a Vanadium blessing. +// extractGroupEmailFromBlessing returns the Google Groups name from a Vanadium blessing. // // For example, for 'v23.grail.com:googlegroups:eng@grailbio.com' the return // string should be 'eng@grailbio.com'. 
-func group(blessing string, prefix string) string { +func extractGroupEmailFromBlessing(ctx *v23context.T, blessing string, prefix string) string { + log.Debug(ctx, "extracting group email from blessing", "blessing", blessing, "prefix", prefix) if strings.HasPrefix(blessing, prefix) { m := groupRE.FindStringSubmatch(blessing[len(prefix)+1:]) - if m != nil { + + if m != nil && stringInSlice(hostedDomains, emailDomain(m[1])) { return m[1] } } @@ -69,35 +89,55 @@ type authorizer struct { isMember func(user, group string) bool } -func googleGroupsAuthorizer(perms access.Permissions, jwtConfig *jwt.Config) security.Authorizer { +func (a *authorizer) String() string { + return fmt.Sprintf("%+v", a.perms) +} + +func googleGroupsAuthorizer(ctx *v23context.T, perms access.Permissions, jwtConfig *jwt.Config, + groupLookupName string) security.Authorizer { + googleGroupsInit(ctx, groupLookupName) return &authorizer{ perms: perms, tagType: access.TypicalTagType(), isMember: func(user, group string) bool { key := cacheKey{user, group} if v, ok := cache.Get(key); ok { - vlog.VI(1).Infof("cache hit for %+v", key) + log.Debug(ctx, "Google groups lookup cache hit", "key", key) return v.(bool) } - vlog.VI(1).Infof("cache miss for %+v", key) + log.Debug(ctx, "Google groups lookup cache miss", "key", key) config := *jwtConfig // This needs to be a Super Admin of the domain. 
- config.Subject = "admin@grailbio.com" + config.Subject = groupLookupName + service, err := admin.New(config.Client(context.Background())) if err != nil { - vlog.Info(err) + log.Error(ctx, err.Error()) return false } + // If the group is in a different domain, perform a user based group membership check + // This loses the ability to check for nested groups - see https://phabricator.grailbio.com/D13275 + // and https://github.com/googleapis/google-api-java-client/issues/1082 + if adminLookupDomain != emailDomain(user) { + member, member_err := admin.NewMembersService(service).Get(group, user).Do() + if member_err != nil { + log.Error(ctx, member_err.Error()) + return false + } + log.Debug(ctx, "adding member to cache", "member", member, "key", key) + isMember := member.Status == "ACTIVE" + cache.Set(key, isMember) + return isMember + } + result, err := admin.NewMembersService(service).HasMember(group, user).Do() if err != nil { - vlog.Info(err) + log.Error(ctx, err.Error()) return false } - vlog.Infof("hasMember: %+v", result) - - vlog.VI(1).Infof("add to cache %+v", key) + log.Debug(ctx, "adding member to cache", "hasMember", result, "key", key) cache.Set(key, result.IsMember) return result.IsMember @@ -105,47 +145,52 @@ func googleGroupsAuthorizer(perms access.Permissions, jwtConfig *jwt.Config) sec } } -func (a *authorizer) pruneBlacklisted(acl access.AccessList, blessings []string, localBlessings string) []string { +func (a *authorizer) pruneBlessingslist(ctx *v23context.T, acl access.AccessList, blessings []string, localBlessings string) []string { if len(acl.NotIn) == 0 { return blessings } var filtered []string for _, b := range blessings { - blacklisted := false + inDenyList := false for _, bp := range acl.NotIn { if security.BlessingPattern(bp).MatchedBy(b) { - blacklisted = true + inDenyList = true break } - userEmail := email(b, localBlessings) - groupEmail := group(bp, localBlessings) + userEmail := verifyAndExtractEmailFromBlessing(b, localBlessings) 
+ groupEmail := extractGroupEmailFromBlessing(ctx, bp, localBlessings) + log.Debug(ctx, "pruning blessings list", "userEmail", userEmail, "groupEmail", groupEmail) if userEmail != "" && groupEmail != "" { if a.isMember(userEmail, groupEmail) { - vlog.Infof("%q is a member of %q (NotIn blessing pattern %q)", userEmail, groupEmail, bp) - blacklisted = true + log.Debug(ctx, "user is a member of group", "userEmail", userEmail, "groupEmail", groupEmail, + "blessingPattern", bp) + inDenyList = true break } } } - if !blacklisted { + if !inDenyList { filtered = append(filtered, b) } } return filtered } -func (a *authorizer) aclIncludes(acl access.AccessList, blessings []string, localBlessings string) bool { - blessings = a.pruneBlacklisted(acl, blessings, localBlessings) - for _, pattern := range acl.In { - if pattern.MatchedBy(blessings...) { +func (a *authorizer) aclIncludes(ctx *v23context.T, acl access.AccessList, blessings []string, + localBlessings string) bool { + blessings = a.pruneBlessingslist(ctx, acl, blessings, localBlessings) + for _, bp := range acl.In { + if bp.MatchedBy(blessings...) 
{ return true } for _, b := range blessings { - userEmail := email(b, localBlessings) - groupEmail := group(string(pattern), localBlessings) + userEmail := verifyAndExtractEmailFromBlessing(b, localBlessings) + groupEmail := extractGroupEmailFromBlessing(ctx, string(bp), localBlessings) + log.Debug(ctx, "checking access list", "userEmail", userEmail, "groupEmail", groupEmail) if userEmail != "" && groupEmail != "" { if a.isMember(userEmail, groupEmail) { - vlog.Infof("%q is a member of %q (In blessing pattern %q)", userEmail, groupEmail, pattern) + log.Debug(ctx, "user is a member of group", "userEmail", userEmail, "groupEmail", groupEmail, + "blessingPattern", bp) return true } } @@ -156,12 +201,13 @@ func (a *authorizer) aclIncludes(acl access.AccessList, blessings []string, loca func (a *authorizer) Authorize(ctx *v23context.T, call security.Call) error { blessings, invalid := security.RemoteBlessingNames(ctx, call) - vlog.Infof("RemoteBlessingNames: %q Tags: %q", blessings, call.MethodTags()) + log.Debug(ctx, "authorizing via Google flow", "blessings", blessings, "tags", call.MethodTags()) for _, tag := range call.MethodTags() { if tag.Type() == a.tagType { - if acl, exists := a.perms[tag.RawString()]; !exists || !a.aclIncludes(acl, blessings, call.LocalBlessings().String()) { - return access.NewErrNoPermissions(ctx, blessings, invalid, tag.RawString()) + if acl, exists := a.perms[tag.RawString()]; !exists || !a.aclIncludes(ctx, acl, blessings, + call.LocalBlessings().String()) { + return access.ErrorfNoPermissions(ctx, "%v %v %v", blessings, invalid, tag.RawString()) } } } diff --git a/cmd/ticket-server/googlegroups_test.go b/cmd/ticket-server/googlegroups_test.go index 5131a035..e92e1a37 100644 --- a/cmd/ticket-server/googlegroups_test.go +++ b/cmd/ticket-server/googlegroups_test.go @@ -7,27 +7,57 @@ package main import ( "testing" + "github.com/grailbio/base/vcontext" + "github.com/stretchr/testify/assert" "v.io/v23/security" "v.io/v23/security/access" ) 
+var ( + testDomainList = []string{"grailbio.com", "contractors.grail.com"} +) + +func TestInit(t *testing.T) { + ctx := vcontext.Background() + f := func() { + hostedDomains = nil + googleGroupsInit(ctx, "admin@grailbio.com") + } + assert.PanicsWithValue(t, "hostedDomains not initialized", f) + + f = func() { + googleBlesserInit([]string{}) + googleGroupsInit(ctx, "admin@grailbio.com") + } + assert.PanicsWithValue(t, "hostedDomains not initialized", f) +} + func TestEmail(t *testing.T) { + ctx := vcontext.Background() + googleBlesserInit(testDomainList) + googleGroupsInit(ctx, "admin@grailbio.com") + cases := []struct { blessing string email string }{ {"v23.grail.com:google:razvanm@grailbio.com", "razvanm@grailbio.com"}, + {"v23.grail.com:google:razvanm@grailbio.com:_role", "razvanm@grailbio.com"}, + {"v23.grail.com:google:complex_+.email@grailbio.com:_role", "complex_+.email@grailbio.com"}, + {"v23.grail.com:google:razvanm@grailbioacom", ""}, {"v23.grail.com:google:razvanm@gmail.com", ""}, {"v23.grail.com:google:razvanm@", ""}, {"v23.grail.com:google:razvanm", ""}, {"v23.grail.com:google", ""}, {"v23.grail.com:xxx:razvanm@grailbio.com", ""}, {"v23.grail.com:googlegroups:eng@grailbio.com", ""}, + {"v23.grail.com:googlegroups:golang-nuts@googlegroups.com:google:razvanm@grailbio.com", ""}, + {"v23.grail.com:googlegroups:eng@grailbio.com:google:razvanm@grailbio.com", ""}, } prefix := "v23.grail.com" for _, c := range cases { - got, want := email(c.blessing, prefix), c.email + got, want := verifyAndExtractEmailFromBlessing(c.blessing, prefix), c.email if got != want { t.Errorf("email(%q, %q): got %q, want %q", c.blessing, prefix, got, want) } @@ -35,22 +65,29 @@ func TestEmail(t *testing.T) { } func TestGroup(t *testing.T) { + ctx := vcontext.Background() + googleBlesserInit(testDomainList) + googleGroupsInit(ctx, "admin@grailbio.com") + cases := []struct { blessing string email string }{ {"v23.grail.com:googlegroups:eng-dev-access@grailbio.com", 
"eng-dev-access@grailbio.com"}, {"v23.grail.com:googlegroups:golang-nuts@googlegroups.com", ""}, + {"v23.grail.com:googlegroups:golang-_+.nuts@grailbio.com", "golang-_+.nuts@grailbio.com"}, {"v23.grail.com:googlegroups:eng@", ""}, {"v23.grail.com:googlegroups:eng", ""}, {"v23.grail.com:googlegroups", ""}, {"v23.grail.com:xxx:eng@grailbio.com", ""}, {"v23.grail.com:google:razvanm@grailbio.com", ""}, + {"v23.grail.com:google:razvanm@grailbio.com:googlegroups:golang-nuts@googlegroups.com", ""}, + {"v23.grail.com:google:razvanm@grailbio.com:googlegroups:eng@grailbio.com", ""}, } prefix := "v23.grail.com" for _, c := range cases { - got, want := group(c.blessing, prefix), c.email + got, want := extractGroupEmailFromBlessing(ctx, c.blessing, prefix), c.email if got != want { t.Errorf("email(%q, %q): got %q, want %q", c.blessing, prefix, got, want) } @@ -58,6 +95,10 @@ func TestGroup(t *testing.T) { } func TestAclIncludes(t *testing.T) { + ctx := vcontext.Background() + googleBlesserInit(testDomainList) + googleGroupsInit(ctx, "admin@grailbio.com") + cases := []struct { acl access.AccessList want bool @@ -128,7 +169,7 @@ func TestAclIncludes(t *testing.T) { }, } for _, c := range cases { - got := a.aclIncludes(c.acl, blessings, prefix) + got := a.aclIncludes(ctx, c.acl, blessings, prefix) if got != c.want { t.Errorf("aclIncludes(%+v, %v): got %v, want %v", c.acl, blessings, got, c.want) } diff --git a/cmd/ticket-server/k8sblesser.go b/cmd/ticket-server/k8sblesser.go new file mode 100644 index 00000000..388e596b --- /dev/null +++ b/cmd/ticket-server/k8sblesser.go @@ -0,0 +1,379 @@ +package main + +import ( + "context" + "encoding/base64" + "fmt" + "net/http" + "strings" + "time" + + auth "k8s.io/api/authentication/v1" + client "k8s.io/client-go/kubernetes/typed/authentication/v1" + rest "k8s.io/client-go/rest" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/arn" + awsclient "github.com/aws/aws-sdk-go/aws/client" + 
"github.com/aws/aws-sdk-go/aws/credentials/stscreds" + "github.com/aws/aws-sdk-go/aws/endpoints" + awssession "github.com/aws/aws-sdk-go/aws/session" + awssigner "github.com/aws/aws-sdk-go/aws/signer/v4" + "github.com/aws/aws-sdk-go/service/eks" + "github.com/aws/aws-sdk-go/service/sts" + + "github.com/grailbio/base/common/log" + "github.com/grailbio/base/errors" + "github.com/grailbio/base/security/identity" + + v23context "v.io/v23/context" + "v.io/v23/rpc" + "v.io/v23/security" +) + +// AWSSessionWrapper is a composition struture that wraps the +// aws session element and returns values that can be more +// easily mocked for testing purposes. SessionI interface +// is passed into the wrapper so that functionality within the +// session function can +type AWSSessionWrapper struct { + session SessionI +} + +// AWSSessionWrapperI provides a means to mock an aws client session. +type AWSSessionWrapperI interface { + GetAuthV1Client(ctx context.Context, headers map[string]string, caCrt string, region string, endpoint string) (client.AuthenticationV1Interface, error) + ListEKSClusters(input *eks.ListClustersInput, roleARN string, region string) (*eks.ListClustersOutput, error) + DescribeEKSCluster(input *eks.DescribeClusterInput, roleARN string, region string) (*eks.DescribeClusterOutput, error) +} + +// newSessionWrapper generates an AWSSessionWrapper that contains +// an awsSession.Session struct and provides multiple mockable interfaces +// for interacting with aws and its remote data. 
+func newSessionWrapper(session SessionI) *AWSSessionWrapper { + // in order to update the sessionI config we must cast it as an awssession.Session struct + newSession := session.(*awssession.Session) + newSession.Config.STSRegionalEndpoint = endpoints.RegionalSTSEndpoint + + return &AWSSessionWrapper{session: newSession} +} + +// ListEKSClusters provides a mockable interface for AWS sessions to +// obtain and iterate over a list of available EKS clusters with the +// provided input configuration +func (w *AWSSessionWrapper) ListEKSClusters(input *eks.ListClustersInput, roleARN string, region string) (*eks.ListClustersOutput, error) { + config := aws.Config{ + Credentials: stscreds.NewCredentials(w.session, roleARN), // w.session.GetStsCreds(roleARN), + Region: ®ion, + } + + svc := eks.New(w.session, &config) + return svc.ListClusters(input) +} + +// DescribeEKSCluster provides a mockable interface for AWS sessions to +// obtain information regarding a specific EKS cluster with the +// provided input configuration +func (w *AWSSessionWrapper) DescribeEKSCluster(input *eks.DescribeClusterInput, roleARN string, region string) (*eks.DescribeClusterOutput, error) { + config := aws.Config{ + Credentials: stscreds.NewCredentials(w.session, roleARN), // w.session.GetStsCreds(roleARN), + Region: ®ion, + } + svc := eks.New(w.session, &config) + return svc.DescribeCluster(input) +} + +// GetAuthV1Client provides a mockable interface for returning an AWS auth client +func (w *AWSSessionWrapper) GetAuthV1Client(ctx context.Context, headers map[string]string, caCrt string, region string, endpoint string) (client.AuthenticationV1Interface, error) { + var ( + err error + authV1Client *client.AuthenticationV1Client + ) + svc := sts.New(w.session, aws.NewConfig().WithRegion(region)) + req, _ := http.NewRequest("GET", fmt.Sprintf("%s/?Action=GetCallerIdentity&Version=2011-06-15", svc.Client.Endpoint), nil) + + for key, header := range headers { + req.Header.Add(key, header) + } + + 
var sessionInterface = w.session + var credentials = sessionInterface.(*awssession.Session).Config.Credentials + + signer := awssigner.NewSigner(credentials) + emptyBody := strings.NewReader("") + _, err = signer.Presign(req, emptyBody, "sts", region, 60*time.Second, time.Now()) + + log.Debug(ctx, "Request was built and presigned", "req", req) + + if err != nil { + return authV1Client, errors.E(err, "unable to presign request for STS credentials") + } + + bearerToken := fmt.Sprintf("k8s-aws-v1.%s", strings.TrimRight(base64.StdEncoding.EncodeToString([]byte(req.URL.String())), "=")) + + log.Debug(ctx, "Bearer token generated", "bearerToken", bearerToken, "url", req.URL.String()) + + tlsConfig := rest.TLSClientConfig{CAData: []byte(caCrt)} + config := rest.Config{ + Host: endpoint, + BearerToken: bearerToken, + TLSClientConfig: tlsConfig, + } + + return client.NewForConfigOrDie(&config), err +} + +// SessionI interface provides a mockable interface for session data +type SessionI interface { + awsclient.ConfigProvider +} + +// V23 Blesser utility for generating blessings for k8s cluster principals. Implements +// interface K8sBlesserServerStubMethods, which requires a BlessK8s method. +// Stores awsConn information in addition to the v23 session and blessing expiration intervals. +// Mock this by creating a separate implementation of K8sBlesserServerStubMethods interface. 
+type k8sBlesser struct {
+	identity.K8sBlesserServerMethods
+	sessionWrapper     AWSSessionWrapperI
+	expirationInterval time.Duration
+	awsConn            *awsConn
+}
+
+func newK8sBlesser(sessionWrapper AWSSessionWrapperI, expiration time.Duration, role string, awsAccountIDs []string, awsRegions []string) *k8sBlesser {
+	return &k8sBlesser{
+		sessionWrapper:     sessionWrapper,
+		expirationInterval: expiration,
+		awsConn:            newAwsConn(sessionWrapper, role, awsRegions, awsAccountIDs),
+	}
+}
+
+// BlessK8s uses the awsConn and k8sConn structs as well as the CreateK8sExtension func
+// in order to create a blessing for a k8s principal. It acts as an entrypoint that does not
+// perform any important logic on its own.
+func (blesser *k8sBlesser) BlessK8s(ctx *v23context.T, call rpc.ServerCall, caCrt string, namespace string, k8sSvcAcctToken string, region string) (security.Blessings, error) {
+	log.Info(ctx, "bless K8s request", "namespace", namespace, "region", region, "remoteAddr", call.RemoteEndpoint().Address)
+	var (
+		nullBlessings security.Blessings = security.Blessings{}
+		cluster       *eks.Cluster
+		err           error
+	)
+
+	// establish security call
+	securityCall := call.Security()
+	if securityCall.LocalPrincipal() == nil {
+		return nullBlessings, errors.New("server misconfiguration: no authentication happened")
+	}
+
+	// establish caveat
+	caveat, err := security.NewExpiryCaveat(time.Now().Add(blesser.expirationInterval))
+	if err != nil {
+		return nullBlessings, errors.E(err, "unable to presign request for STS credentials")
+	}
+
+	// next, we are ready to isolate a desired cluster by enumerating existing eks clusters in a region and matching the caCrt
+	cluster, err = blesser.awsConn.GetEKSCluster(ctx, region, caCrt)
+	if err != nil {
+		return nullBlessings, err
+	}
+
+	// now we can establish the k8s cluster obj because we know the cluster and can connect to it.
+	k8sConn := newK8sConn(blesser.sessionWrapper, cluster, region, caCrt, k8sSvcAcctToken)
+
+	// obtain username from cluster connection
+	username, err := k8sConn.GetK8sUsername(ctx)
+	if err != nil {
+		return nullBlessings, err
+	}
+
+	// create an extension based on the namespace and username
+	extension, err := CreateK8sExtension(ctx, cluster, username, namespace)
+	if err != nil {
+		return nullBlessings, err
+	}
+
+	// lastly we perform the blessing using the generated k8s extension
+	return call.Security().LocalPrincipal().Bless(securityCall.RemoteBlessings().PublicKey(), securityCall.LocalBlessings(), extension, caveat)
+}
+
+// Provides an interface for gathering AWS data using the context, region, and caCrt, which can be mocked for testing.
+type awsConn struct {
+	role           string
+	regions        []string
+	accountIDs     []string
+	sessionWrapper AWSSessionWrapperI
+}
+
+// Creates a new AWS Connect object that can be used to obtain data about AWS, EKS, etc.
+func newAwsConn(sessionWrapper AWSSessionWrapperI, role string, regions []string, accountIDs []string) *awsConn {
+	return &awsConn{
+		sessionWrapper: sessionWrapper,
+		role:           role,
+		regions:        regions,
+		accountIDs:     accountIDs,
+	}
+}
+
+// Interface for mocking awsConn.
+type awsConnI interface {
+	GetEKSCluster(caCrt string) (*eks.Cluster, error)
+	GetClusters(ctx *v23context.T, region string) []*eks.Cluster
+}
+
+// Gets an EKS cluster with matching AWS region and caCrt.
+func (conn *awsConn) GetEKSCluster(ctx *v23context.T, region string, caCrt string) (*eks.Cluster, error) {
+	var (
+		cluster *eks.Cluster
+		err     error
+	)
+
+	caCrtData := base64.StdEncoding.EncodeToString([]byte(caCrt))
+	// TODO(noah): If performance becomes an issue, populate allow-list of clusters on ticket-server startup.
+ for _, c := range conn.GetClusters(ctx, region) { + if caCrtData == *c.CertificateAuthority.Data { + cluster = c + break + } + } + if cluster == nil { + err = errors.New("CA certificate does not match any cluster") + } + return cluster, err +} + +// Gets all EKS clusters in a given AWS region. +func (conn *awsConn) GetClusters(ctx *v23context.T, region string) []*eks.Cluster { + var clusters []*eks.Cluster + for _, r := range conn.regions { + if r == region { + for _, id := range conn.accountIDs { + roleARN := fmt.Sprintf("arn:aws:iam::%s:role/%s", id, conn.role) + listClusterOutput, err := conn.sessionWrapper.ListEKSClusters(&eks.ListClustersInput{}, roleARN, region) + if err != nil { + log.Error(ctx, "Unable to fetch list of clusters.", "roleARN", roleARN, "region", region) + } + for _, name := range listClusterOutput.Clusters { + describeClusterOutput, err := conn.sessionWrapper.DescribeEKSCluster(&eks.DescribeClusterInput{Name: name}, roleARN, region) + if err != nil { + log.Error(ctx, "Unable to describe cluster.", "clusterName", *name) + } + clusters = append(clusters, describeClusterOutput.Cluster) + } + } + } + } + return clusters +} + +// Defines connection parameters to a k8s cluster and can connect and return data. Isolated as an interface so complex http calls can be mocked for testing. +type k8sConn struct { + cluster *eks.Cluster + namespace string + region string + caCrt string + svcAcctToken string + sessionWrapper AWSSessionWrapperI +} + +// Creates a new k8s connection object that can be used to connect to the k8s cluster and obtain relevant data. +func newK8sConn(sessionWrapper AWSSessionWrapperI, cluster *eks.Cluster, region string, caCrt string, svcAcctToken string) *k8sConn { + return &k8sConn{ + sessionWrapper: sessionWrapper, + cluster: cluster, + region: region, + caCrt: caCrt, + svcAcctToken: svcAcctToken, + } +} + +// An interface for mocking k8sConn struct. 
+type k8sConnI interface { + GetK8sUsername(ctx context.Context) (string, error) +} + +func (conn *k8sConn) GetK8sUsername(ctx context.Context) (string, error) { + var ( + username string + err error + ) + + //svc := sts.New(conn.session) + + var headers = make(map[string]string) + headers["x-k8s-aws-id"] = *conn.cluster.Name + + authV1Client, err := conn.sessionWrapper.GetAuthV1Client(ctx, headers, conn.caCrt, conn.region, *conn.cluster.Endpoint) + if err != nil { + return username, err + } + + log.Debug(ctx, "AuthV1Client retrieved", "caCrt", conn.caCrt, "region", conn.region, "endpoint", *conn.cluster.Endpoint) + + tr := auth.TokenReview{ + Spec: auth.TokenReviewSpec{ + Token: conn.svcAcctToken, + }, + } + + log.Debug(ctx, "K8s Service account token configured for tokenReview request", "token", conn.svcAcctToken) + + trResp, err := authV1Client.TokenReviews().Create(&tr) + username = trResp.Status.User.Username + + if err != nil { + err = errors.E(err, "unable to create tokenreview") + } else if !trResp.Status.Authenticated { + err = errors.New("requestToken authentication failed") + } + + return username, err +} + +// CreateK8sExtension evaluates EKS Cluster configuration and tagging to produce a v23 Blessing extension. 
+func CreateK8sExtension(ctx context.Context, cluster *eks.Cluster, username, namespace string) (string, error) { + var ( + extension string + err error + clusterNameFromTag string + ) + + arn, err := arn.Parse(*cluster.Arn) + if err != nil { + return extension, err + } + + // Username is of format: system:serviceaccount:(NAMESPACE):(SERVICEACCOUNT) + usernameSet := strings.Split(username, ":") + + if len(usernameSet) != 4 { + return extension, errors.New("username does not match format system:serviceaccount:(NAMESPACE):(SERVICEACCOUNT)") + } else if namespace != usernameSet[2] { + return extension, errors.New("namespace does not match") + } + + if val, ok := cluster.Tags["ClusterName"]; ok { + clusterNameFromTag = *val + } + + /* + // leaving this implementation here, commented, as a relic of understanding; the intent was to use this to generate unique blessings + // for each cluster but allow the clusters blessing for vRPC calls to be authorized by parent blessing extensions, such as k8s:dev:svc:a + // authorizing when k8s:dev:svc is allowed. However, we cannot do this because Grailbook, and perhaps other services as well + // expect exact matches for blessings in some circumstances. In order to prevent a complicated rewrite of the code, we are removing + // the cluster mode specific blessings in order to keep implementation simple and plan to move out of vanadium sooner to something + // that better matches the authorization method desired. 
+ + if val, ok := cluster.Tags["ClusterMode"]; ok { + clusterModeFromTag = strings.ToLower(*val) + } + */ + + if clusterNameFromTag != "" { + extension = fmt.Sprintf("k8s:%s:%s:%s", arn.AccountID, clusterNameFromTag, usernameSet[3]) + log.Debug(ctx, "Using k8s cluster a/b extension generation.", "extension", extension) + } else { + extension = fmt.Sprintf("k8s:%s:%s:%s", arn.AccountID, *cluster.Name, usernameSet[3]) + log.Debug(ctx, "Using standard k8s extension generation.", "extension", extension) + } + + return extension, err +} diff --git a/cmd/ticket-server/k8sblesser_test.go b/cmd/ticket-server/k8sblesser_test.go new file mode 100644 index 00000000..c93584a3 --- /dev/null +++ b/cmd/ticket-server/k8sblesser_test.go @@ -0,0 +1,399 @@ +package main + +import ( + "context" + "encoding/base64" + "fmt" + "io/ioutil" + "os" + "os/exec" + "path" + "testing" + "time" + + "github.com/aws/aws-sdk-go/service/eks" + ticketServerUtil "github.com/grailbio/base/cmd/ticket-server/testutil" + "github.com/grailbio/base/errors" + "github.com/grailbio/base/security/identity" + "github.com/grailbio/base/vcontext" + "github.com/grailbio/testutil" + + assert "github.com/stretchr/testify/assert" + + auth "k8s.io/api/authentication/v1" + client "k8s.io/client-go/kubernetes/typed/authentication/v1" + rest "k8s.io/client-go/rest" + + "v.io/v23/naming" + "v.io/x/ref" +) + +type FakeAWSSession struct { +} + +type FakeAuthV1Client struct { + client.AuthenticationV1Interface + RESTClientReturn rest.Interface + TokenReviewsReturn client.TokenReviewInterface +} + +func (w *FakeAuthV1Client) RESTClient() rest.Interface { + return w.RESTClientReturn +} + +func (w *FakeAuthV1Client) TokenReviews() client.TokenReviewInterface { + return w.TokenReviewsReturn +} + +type FakeTokenReviews struct { + client.TokenReviewInterface + TokenReviewReturn *auth.TokenReview +} + +func (t *FakeTokenReviews) Create(*auth.TokenReview) (*auth.TokenReview, error) { + var err error + return t.TokenReviewReturn, err 
+} + +// FakeAWSSessionWrapper mocks the session wrapper used to isolate +type FakeAWSSessionWrapper struct { + session *FakeAWSSession + GetAuthV1ClientReturn client.AuthenticationV1Interface + ListEKSClustersReturn *eks.ListClustersOutput + AllEKSClusters map[string]*eks.DescribeClusterOutput +} + +func (w *FakeAWSSessionWrapper) DescribeEKSCluster(input *eks.DescribeClusterInput, roleARN string, region string) (*eks.DescribeClusterOutput, error) { + var err error + return w.AllEKSClusters[*input.Name], err +} + +func (w *FakeAWSSessionWrapper) GetAuthV1Client(ctx context.Context, headers map[string]string, caCrt string, region string, endpoint string) (client.AuthenticationV1Interface, error) { + var err error + return w.GetAuthV1ClientReturn, err +} + +func (w *FakeAWSSessionWrapper) ListEKSClusters(input *eks.ListClustersInput, roleARN string, region string) (*eks.ListClustersOutput, error) { + var err error + return w.ListEKSClustersReturn, err +} + +// FakeContext mocks contexts so that we can pass them in to simulate logging, etc +type FakeContext struct { + context.Context +} + +// required to simulate logging. +func (c *FakeContext) Value(key interface{}) interface{} { + return nil +} + +// ClusterHelper generates all the cluster attributes used in a test +type ClusterHelper struct { + Name string + Arn string + Crt string + CrtEnc string + RoleARN string + Endpoint string + Cluster *eks.Cluster + ClusterOutput *eks.DescribeClusterOutput +} + +func newClusterHelper(name, acctNum, crt, roleARN, region string, tags map[string]*string) *ClusterHelper { + fakeAccountName := "ACCTNAMEFOR" + name + + ch := ClusterHelper{ + Name: name, + Arn: "arn:aws:iam::" + acctNum + ":role/" + name, + Crt: crt, + CrtEnc: base64.StdEncoding.EncodeToString([]byte(crt)), + RoleARN: roleARN, + Endpoint: "https://" + fakeAccountName + ".sk1." 
+ region + ".eks.amazonaws.com", + } + + ch.Cluster = &eks.Cluster{ + Name: &ch.Name, + RoleArn: &ch.RoleARN, + Endpoint: &ch.Endpoint, + Tags: tags, + Arn: &ch.Arn, + CertificateAuthority: &eks.Certificate{ + Data: &ch.CrtEnc, + }, + } + + ch.ClusterOutput = &eks.DescribeClusterOutput{ + Cluster: ch.Cluster, + } + + return &ch +} + +// Note: we cannot test +func TestK8sBlesser(t *testing.T) { + emptyTags := make(map[string]*string) + randomTag := "test" + emptyTags["RandomTag"] = &randomTag + acctNum := "111111111111" + + ctx := vcontext.Background() + assert.NoError(t, ref.EnvClearCredentials()) + + t.Run("init", func(t *testing.T) { + fakeSessionWrapper := &FakeAWSSessionWrapper{session: &FakeAWSSession{}} + accountIDs := []string{"abc123456"} + awsRegions := []string{"us-west-2"} + testRole := "test-role" + compareAWSConn := newAwsConn(fakeSessionWrapper, testRole, awsRegions, accountIDs) + blesser := newK8sBlesser(fakeSessionWrapper, time.Hour, testRole, accountIDs, awsRegions) + + // test that awsConn was configured + assert.Equal(t, blesser.awsConn, compareAWSConn) + }) + + t.Run("awsConn", func(t *testing.T) { + fakeSessionWrapper := &FakeAWSSessionWrapper{session: &FakeAWSSession{}} + accountIDs := []string{acctNum} + awsRegions := []string{"us-west-2"} + testRole := "test-role" + testRegion := "us-west-2" + wantCluster := newClusterHelper("test-cluster", acctNum, "fake-crt", testRole, testRegion, emptyTags) + otherCluster1 := newClusterHelper("other-cluster1", acctNum, "other-crt1", testRole, testRegion, emptyTags) + otherCluster2 := newClusterHelper("other-cluster2", acctNum, "other-crt2", "another-role", testRegion, emptyTags) + + clusters := []string{wantCluster.Name, otherCluster1.Name, otherCluster2.Name} + var clusterOutputs = make(map[string]*eks.DescribeClusterOutput) + clusterOutputs[wantCluster.Name] = wantCluster.ClusterOutput + clusterOutputs[otherCluster1.Name] = otherCluster1.ClusterOutput + clusterOutputs[otherCluster2.Name] = 
otherCluster2.ClusterOutput + + clusterPtrs := []*string{} + for i := range clusters { + clusterPtrs = append(clusterPtrs, &clusters[i]) + } + + fakeSessionWrapper.ListEKSClustersReturn = &eks.ListClustersOutput{ + Clusters: clusterPtrs, + } + + fakeSessionWrapper.AllEKSClusters = clusterOutputs + + assert.NoError(t, ref.EnvClearCredentials()) + + blesser := newK8sBlesser(fakeSessionWrapper, time.Hour, testRole, accountIDs, awsRegions) + + clustersOutput := blesser.awsConn.GetClusters(ctx, testRegion) + assert.Equal(t, clustersOutput, []*eks.Cluster{wantCluster.Cluster, otherCluster1.Cluster, otherCluster2.Cluster}) + + foundEksCluster, _ := blesser.awsConn.GetEKSCluster(ctx, testRegion, wantCluster.Crt) + assert.NotNil(t, foundEksCluster) + }) + + t.Run("k8sConn", func(t *testing.T) { + var ( + foundUsername string + k8sConn *k8sConn + err error + ) + fakeSessionWrapper := &FakeAWSSessionWrapper{session: &FakeAWSSession{}} + testRole := "test-role" + testToken := "test-token" + testRegion := "us-west-2" + testUsername := "system:serviceaccount:default:someService" + cluster := newClusterHelper("test-cluster", acctNum, "fake-crt", testRole, testRegion, emptyTags) + + fakeTokenReviews := &FakeTokenReviews{} + fakeTokenReviews.TokenReviewReturn = &auth.TokenReview{ + Status: auth.TokenReviewStatus{ + User: auth.UserInfo{ + Username: testUsername, + }, + Authenticated: true, + }, + } + fakeContext := &FakeContext{} + fakeAuthV1Client := &FakeAuthV1Client{} + fakeAuthV1Client.TokenReviewsReturn = fakeTokenReviews + + fakeSessionWrapper.GetAuthV1ClientReturn = fakeAuthV1Client + k8sConn = newK8sConn(fakeSessionWrapper, cluster.Cluster, testRegion, cluster.Crt, testToken) + + foundUsername, err = k8sConn.GetK8sUsername(fakeContext) + assert.NoError(t, err) + assert.NotNil(t, foundUsername) + assert.Equal(t, testUsername, foundUsername) + + // test failure outputs + fakeTokenReviews.TokenReviewReturn = &auth.TokenReview{ + Status: auth.TokenReviewStatus{ + User: 
auth.UserInfo{ + Username: "", + }, + Authenticated: false, + }, + } + k8sConn = newK8sConn(fakeSessionWrapper, cluster.Cluster, testRegion, cluster.Crt, testToken) + foundUsername, err = k8sConn.GetK8sUsername(fakeContext) + assert.NotNil(t, err) + assert.Empty(t, foundUsername) + assert.Equal(t, err, errors.New("requestToken authentication failed")) + }) + + t.Run("CreateK8sExtension", func(t *testing.T) { + var ( + err error + cluster *ClusterHelper + extension string + ) + testRole := "test-role" + testRegion := "us-west-2" + testNamespace := "default" + clusterName := "test-cluster" + serviceAccountName := "someService" + testUsername := "system:serviceaccount:" + testNamespace + ":" + serviceAccountName + fakeContext := &FakeContext{} + + // test default cluster naming + cluster = newClusterHelper(clusterName, acctNum, "fake-crt", testRole, testRegion, emptyTags) + extension, err = CreateK8sExtension(fakeContext, cluster.Cluster, testUsername, testNamespace) + assert.NoError(t, err) + assert.Equal(t, "k8s:"+acctNum+":test-cluster:someService", extension) + + // test cluster a/b + tags := make(map[string]*string) + clusterMode := "A" + tags["ClusterName"] = &clusterName + tags["ClusterMode"] = &clusterMode + cluster = newClusterHelper(clusterName+"-a", acctNum, "fake-crt", testRole, testRegion, tags) + extension, err = CreateK8sExtension(fakeContext, cluster.Cluster, testUsername, testNamespace) + assert.Nil(t, err) + assert.Equal(t, "k8s:"+acctNum+":"+clusterName+":"+serviceAccountName, extension) + }) + + t.Run("BlessK8s", func(t *testing.T) { + testRole := "test-role" + testToken := "test-token" + testRegion := "us-west-2" + testNamespace := "default" + clusterName := "test-cluster" + serviceAccountName := "someService" + testUsername := "system:serviceaccount:" + testNamespace + ":" + serviceAccountName + accountIDs := []string{acctNum} + awsRegions := []string{testRegion} + + // tags for the ab Cluster + tags := make(map[string]*string) + clusterMode := 
"A" + tags["ClusterName"] = &clusterName + tags["ClusterMode"] = &clusterMode + + // setup fake clusters, lg = legacy, ab = with cluster a/b + lgCluster := newClusterHelper(clusterName, acctNum, "lg-crt", testRole, testRegion, emptyTags) + abCluster := newClusterHelper(clusterName+"-a", acctNum, "ab-crt", testRole, testRegion, tags) + + // creating clusters list + clusters := []string{lgCluster.Name, abCluster.Name} + + // outputs list for the desired client output + var clusterOutputs = make(map[string]*eks.DescribeClusterOutput) + clusterOutputs[lgCluster.Name] = lgCluster.ClusterOutput + clusterOutputs[abCluster.Name] = abCluster.ClusterOutput + + // assigning pointer to cluster names to cluster ptrs list + clusterPtrs := []*string{} + for i := range clusters { + clusterPtrs = append(clusterPtrs, &clusters[i]) + } + // setup fake token reviews + fakeTokenReviews := &FakeTokenReviews{} + fakeTokenReviews.TokenReviewReturn = &auth.TokenReview{ + Status: auth.TokenReviewStatus{ + User: auth.UserInfo{ + Username: testUsername, + }, + Authenticated: true, + }, + } + + // setup fake authv1 client + fakeAuthV1Client := &FakeAuthV1Client{} + fakeAuthV1Client.TokenReviewsReturn = fakeTokenReviews + + // setup fake session wrapper + fakeSessionWrapper := &FakeAWSSessionWrapper{session: &FakeAWSSession{}} + fakeSessionWrapper.ListEKSClustersReturn = &eks.ListClustersOutput{ + Clusters: clusterPtrs, + } + fakeSessionWrapper.AllEKSClusters = clusterOutputs + fakeSessionWrapper.GetAuthV1ClientReturn = fakeAuthV1Client + + assert.NoError(t, ref.EnvClearCredentials()) + + // setup fake blessings server + pathEnv := "PATH=" + os.Getenv("PATH") + exe := testutil.GoExecutable(t, "//go/src/github.com/grailbio/base/cmd/grail-access/grail-access") + + var blesserEndpoint naming.Endpoint + ctx, blesserEndpoint = ticketServerUtil.RunBlesserServer( + ctx, + t, + identity.K8sBlesserServer(newK8sBlesser(fakeSessionWrapper, time.Hour, testRole, accountIDs, awsRegions)), + ) + + var ( + 
tmpDir string + cleanUp func() + stdout string + principalDir string + principalCleanUp func() + cmd *exec.Cmd + ) + + // create local crt, namespace, and tokens for the legacy cluster + tmpDir, cleanUp = testutil.TempDir(t, "", "") + defer cleanUp() + + assert.NoError(t, ioutil.WriteFile(path.Join(tmpDir, "caCrt"), []byte(lgCluster.Crt), 0644)) + assert.NoError(t, ioutil.WriteFile(path.Join(tmpDir, "namespace"), []byte(testNamespace), 0644)) + assert.NoError(t, ioutil.WriteFile(path.Join(tmpDir, "token"), []byte(testToken), 0644)) + + // Run grail-access to create a principal and bless it with the k8s flow. + principalDir, principalCleanUp = testutil.TempDir(t, "", "") + defer principalCleanUp() + cmd = exec.Command(exe, + "-dir", principalDir, + "-blesser", fmt.Sprintf("/%s", blesserEndpoint.Address), + "-k8s", + "-ca-crt", path.Join(tmpDir, "caCrt"), + "-namespace", path.Join(tmpDir, "namespace"), + "-token", path.Join(tmpDir, "token"), + ) + cmd.Env = []string{pathEnv} + stdout, _ = ticketServerUtil.RunAndCapture(t, cmd) + assert.Contains(t, stdout, "k8s:111111111111:test-cluster:someService") + + // create local crt, namespace, and tokens for the a/b cluster + tmpDir, cleanUp = testutil.TempDir(t, "", "") + defer cleanUp() + + assert.NoError(t, ioutil.WriteFile(path.Join(tmpDir, "caCrt"), []byte(abCluster.Crt), 0644)) + assert.NoError(t, ioutil.WriteFile(path.Join(tmpDir, "namespace"), []byte(testNamespace), 0644)) + assert.NoError(t, ioutil.WriteFile(path.Join(tmpDir, "token"), []byte(testToken), 0644)) + + // Run grail-access to create a principal and bless it with the k8s flow. 
+ principalDir, principalCleanUp = testutil.TempDir(t, "", "") + defer principalCleanUp() + cmd = exec.Command(exe, + "-dir", principalDir, + "-blesser", fmt.Sprintf("/%s", blesserEndpoint.Address), + "-k8s", + "-ca-crt", path.Join(tmpDir, "caCrt"), + "-namespace", path.Join(tmpDir, "namespace"), + "-token", path.Join(tmpDir, "token"), + ) + cmd.Env = []string{pathEnv} + stdout, _ = ticketServerUtil.RunAndCapture(t, cmd) + assert.Contains(t, stdout, "k8s:111111111111:test-cluster:someService") + }) +} diff --git a/cmd/ticket-server/list.go b/cmd/ticket-server/list.go new file mode 100644 index 00000000..5092bf19 --- /dev/null +++ b/cmd/ticket-server/list.go @@ -0,0 +1,31 @@ +package main + +import ( + "regexp" + "sort" + + "github.com/grailbio/base/common/log" + "v.io/v23/context" + "v.io/v23/rpc" +) + +type list struct{} + +func newList(ctx *context.T) *list { + return &list{} +} + +func (l *list) List(ctx *context.T, call rpc.ServerCall) ([]string, error) { + log.Info(ctx, "list request", "endpoint", "get", "blessing", call.Security().RemoteBlessings(), "ticket", call.Suffix()) + var result []string + ignored := regexp.MustCompile("blesser/*|list") + for t, e := range d.registry { + if ignore := ignored.MatchString(t); !ignore { + if err := e.auth.Authorize(ctx, call.Security()); err == nil { + result = append(result, t) + } + } + } + sort.Strings(result) + return result, nil +} diff --git a/cmd/ticket-server/main.go b/cmd/ticket-server/main.go index 45982bef..0308febb 100644 --- a/cmd/ticket-server/main.go +++ b/cmd/ticket-server/main.go @@ -3,7 +3,8 @@ // license that can be found in the LICENSE file. // The following enables go generate to generate the doc.go file. -//go:generate go run $GRAIL/go/src/vendor/v.io/x/lib/cmdline/testdata/gendoc.go "--build-cmd=go install" --copyright-notice= . -help +//go:generate go run v.io/x/lib/cmdline/gendoc "--build-cmd=go install" --copyright-notice= . 
-help + package main import ( @@ -15,32 +16,34 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/session" "github.com/grailbio/base/cmd/ticket-server/config" + "github.com/grailbio/base/common/log" "github.com/grailbio/base/security/identity" _ "github.com/grailbio/base/security/keycrypt/file" _ "github.com/grailbio/base/security/keycrypt/keychain" _ "github.com/grailbio/base/security/keycrypt/kms" "github.com/grailbio/base/security/ticket" + _ "github.com/grailbio/v23/factories/grail" "golang.org/x/oauth2/google" "golang.org/x/oauth2/jwt" admin "google.golang.org/api/admin/directory/v1" - "v.io/v23" + v23 "v.io/v23" "v.io/v23/context" "v.io/v23/glob" "v.io/v23/naming" "v.io/v23/rpc" "v.io/v23/security" "v.io/x/lib/cmdline" - "v.io/x/lib/vlog" "v.io/x/ref/lib/security/securityflag" "v.io/x/ref/lib/signals" "v.io/x/ref/lib/v23cmd" - _ "v.io/x/ref/runtime/factories/grail" ) var ( - nameFlag string - configDirFlag string - regionFlag string + nameFlag string + configDirFlag string + regionFlag string + googleUserSufixFlag string + googleAdminNameFlag string dryrunFlag bool @@ -53,6 +56,11 @@ var ( ec2DisableAddrCheckFlag bool ec2DisableUniquenessCheckFlag bool ec2DisablePendingTimeCheckFlag bool + + k8sBlesserRoleFlag string + k8sExpirationIntervalFlag time.Duration + awsAccountsFlag string + awsRegionsFlag string ) func newCmdRoot() *cmdline.Command { @@ -81,6 +89,15 @@ certificate + the private key and the URL to reach the Docker daemon. root.Flags.BoolVar(&ec2DisableAddrCheckFlag, "danger-danger-danger-ec2-disable-address-check", false, "Disable the IP address check for the EC2-based blessings requests. Only useful for local tests.") root.Flags.BoolVar(&ec2DisableUniquenessCheckFlag, "danger-danger-danger-ec2-disable-uniqueness-check", false, "Disable the uniqueness check for the EC2-based blessings requests. 
Only useful for local tests.") root.Flags.BoolVar(&ec2DisablePendingTimeCheckFlag, "danger-danger-danger-ec2-disable-pending-time-check", false, "Disable the pendint time check for the EC2-based blessings requests. Only useful for local tests.") + + root.Flags.StringVar(&googleUserSufixFlag, "google-user-domain", "grailbio.com", "Comma-separated list of email domains used for validating users.") + root.Flags.StringVar(&googleAdminNameFlag, "google-admin", "admin@grailbio.com", "Google Admin that can read all group memberships - NOTE: all groups will need to match the admin user's domain.") + + root.Flags.DurationVar(&k8sExpirationIntervalFlag, "k8s-expiration", 365*24*time.Hour, "Expiration caveat for the K8s-based blessings.") + root.Flags.StringVar(&k8sBlesserRoleFlag, "k8s-blesser-role", "ticket-server", "What role to use to lookup EKS cluster information on all authorized accounts. The role needs to exist in all the accounts.") + root.Flags.StringVar(&awsAccountsFlag, "aws-account-ids", "", "Commma-separated list of AWS account IDs used to populate allow-list of k8s clusters.") + root.Flags.StringVar(&awsRegionsFlag, "aws-regions", "us-west-2", "Commma-separated list of AWS regions used to populate allow-list of k8s clusters.") + return root } @@ -93,30 +110,26 @@ type node struct { var _ rpc.AllGlobber = (*node)(nil) -func (n *node) Glob__(ctx *context.T, call rpc.GlobServerCall, g *glob.Glob) error { - vlog.Infof("Glob: %+v len: %d tail: %+v recursive: %+v restricted: %+v", g, g.Len(), g.Tail(), g.Recursive(), g.Restricted()) +func (n *node) Glob__(ctx *context.T, call rpc.GlobServerCall, g *glob.Glob) error { // nolint: golint + log.Info(ctx, "glob request", "glob", g, "blessing", call.Security().RemoteBlessings(), "ticket", call.Suffix()) sender := call.SendStream() element := g.Head() // The key is the path to a node. 
children := map[string]interface{}{"": n} - vlog.VI(1).Infof("children: %+v", children) for g.Len() != 0 { children = descent(children) matches := map[string]interface{}{} for k, v := range children { v := v.(*node) - vlog.VI(1).Infof("k: %+v name: %+v", k, v.name) if element.Match(v.name) { matches[k] = v } } - vlog.VI(1).Infof("matches: %+v", matches) children = matches g = g.Tail() element = g.Head() - vlog.VI(1).Infof("glob: %+v len: %d tail: %+v recursive: %+v restricted: %+v", g, g.Len(), g.Tail(), g.Recursive(), g.Restricted()) } if g.String() == "..." { @@ -138,7 +151,6 @@ func (n *node) Glob__(ctx *context.T, call rpc.GlobServerCall, g *glob.Glob) err case *entry: isLeaf = true } - vlog.VI(1).Infof("send: %q isLeaf: %+v", k, isLeaf) sender.Send(naming.GlobReplyEntry{ Value: naming.MountEntry{ Name: strings.TrimLeft(k, "/"), @@ -189,10 +201,10 @@ type dispatcher struct { root *node } -var _ rpc.Dispatcher = (*dispatcher)(nil) +var d *dispatcher func newDispatcher(ctx *context.T, awsSession *session.Session, cfg config.Config, jwtConfig *jwt.Config) rpc.Dispatcher { - d := &dispatcher{ + d = &dispatcher{ registry: make(map[string]entry), root: &node{}, } @@ -200,19 +212,28 @@ func newDispatcher(ctx *context.T, awsSession *session.Session, cfg config.Confi // Note that the blesser/ endpoints are not exposed via Glob__ and the // permissions are governed by the -v23.permissions.{file,literal} flags. 
d.registry["blesser/google"] = entry{ - service: identity.GoogleBlesserServer(newGoogleBlesser(googleExpirationIntervalFlag)), - auth: securityflag.NewAuthorizerOrDie(), + service: identity.GoogleBlesserServer(newGoogleBlesser(ctx, googleExpirationIntervalFlag, + strings.Split(googleUserSufixFlag, ","))), + auth: securityflag.NewAuthorizerOrDie(ctx), + } + d.registry["blesser/k8s"] = entry{ + service: identity.K8sBlesserServer(newK8sBlesser(newSessionWrapper(awsSession), k8sExpirationIntervalFlag, k8sBlesserRoleFlag, strings.Split(awsAccountsFlag, ","), strings.Split(awsRegionsFlag, ","))), + auth: securityflag.NewAuthorizerOrDie(ctx), } if ec2BlesserRoleFlag != "" { d.registry["blesser/ec2"] = entry{ service: identity.Ec2BlesserServer(newEc2Blesser(ctx, awsSession, ec2ExpirationIntervalFlag, ec2BlesserRoleFlag, ec2DynamoDBTableFlag)), - auth: securityflag.NewAuthorizerOrDie(), + auth: securityflag.NewAuthorizerOrDie(ctx), } } + d.registry["list"] = entry{ + service: ticket.ListServiceServer(newList(ctx)), + auth: securityflag.NewAuthorizerOrDie(ctx), + } for k, v := range cfg { - auth := googleGroupsAuthorizer(v.Perms, jwtConfig) - vlog.Infof("registry add: %q perms: %+v", k, auth) + auth := googleGroupsAuthorizer(ctx, v.Perms, jwtConfig, googleAdminNameFlag) + log.Debug(ctx, "adding service to dispatcher registry", "name", k, "perms", v.Perms) parts := strings.Split(k, "/") n := d.root for _, p := range parts { @@ -226,6 +247,7 @@ func newDispatcher(ctx *context.T, awsSession *session.Session, cfg config.Confi n = n.children[p].(*node) } } + d.registry[k] = entry{ service: ticket.TicketServiceServer(&service{ name: parts[len(parts)-1], @@ -233,6 +255,7 @@ func newDispatcher(ctx *context.T, awsSession *session.Session, cfg config.Confi ticket: v.Ticket, perms: v.Perms, awsSession: awsSession, + controls: v.Controls, }), auth: auth, } @@ -242,7 +265,7 @@ func newDispatcher(ctx *context.T, awsSession *session.Session, cfg config.Confi // Lookup implements the 
Dispatcher interface from v.io/v23/rpc. func (d *dispatcher) Lookup(ctx *context.T, suffix string) (interface{}, security.Authorizer, error) { - vlog.Infof("suffix: %q ctx: %+v", suffix, ctx) + log.Debug(ctx, "performing service looking", "name", suffix) if s, ok := d.registry[suffix]; ok { return s.service, s.auth, nil } @@ -258,7 +281,6 @@ func run(ctx *context.T, env *cmdline.Env, args []string) error { if err != nil { return err } - vlog.Infof("ticketConfig:\n%#v", ticketConfig) if dryrunFlag { return nil @@ -269,7 +291,7 @@ func run(ctx *context.T, env *cmdline.Env, args []string) error { } blessings, _ := v23.GetPrincipal(ctx).BlessingStore().Default() - vlog.Infof("default blessings: %+v", blessings) + log.Debug(ctx, "using default blessing", "blessing", blessings) awsSession, err := session.NewSession(aws.NewConfig().WithRegion(regionFlag)) if err != nil { @@ -280,7 +302,7 @@ func run(ctx *context.T, env *cmdline.Env, args []string) error { if err != nil { return err } - jwtConfig, err := google.JWTConfigFromJSON(serviceAccountJSON, admin.AdminDirectoryGroupMemberReadonlyScope) + jwtConfig, err := google.JWTConfigFromJSON(serviceAccountJSON, admin.AdminDirectoryGroupMemberReadonlyScope+" "+admin.AdminDirectoryGroupReadonlyScope) if err != nil { return err } @@ -292,7 +314,7 @@ func run(ctx *context.T, env *cmdline.Env, args []string) error { } for _, endpoint := range s.Status().Endpoints { - vlog.Infof("ENDPOINT=%s\n", endpoint.Name()) + log.Info(ctx, "server endpoint", "addr", endpoint.Name()) } <-signals.ShutdownOnSignals(ctx) // Wait forever. 
return nil diff --git a/cmd/ticket-server/service.go b/cmd/ticket-server/service.go index 7d0a779c..c5b9f78a 100644 --- a/cmd/ticket-server/service.go +++ b/cmd/ticket-server/service.go @@ -5,14 +5,15 @@ package main import ( + "errors" "fmt" "github.com/aws/aws-sdk-go/aws/session" + "github.com/grailbio/base/common/log" "github.com/grailbio/base/security/ticket" "v.io/v23/context" "v.io/v23/rpc" "v.io/v23/security/access" - "v.io/x/lib/vlog" ) type service struct { @@ -21,39 +22,74 @@ type service struct { ticket ticket.Ticket perms access.Permissions awsSession *session.Session + controls map[ticket.Control]bool } -func (s *service) Get(ctx *context.T, call rpc.ServerCall) (ticket.Ticket, error) { - vlog.Infof("Get: ctx: %+v call: %+v", ctx, call) +func (s *service) log(ctx *context.T, call rpc.ServerCall, parameters []ticket.Parameter, args map[string]string) { + logArgs := make([]interface{}, len(parameters)+len(args)*2+2) + i := 0 + for _, p := range parameters { + logArgs[i] = p.Key + logArgs[i+1] = p.Value + i += 2 + } + for k, v := range args { + logArgs[i] = k + logArgs[i+1] = v + i += 2 + } + log.Info(ctx, "Fetching ticket.", logArgs...) 
+} + +func (s *service) get(ctx *context.T, call rpc.ServerCall, parameters []ticket.Parameter, args map[string]string) (ticket.Ticket, error) { + s.log(ctx, call, parameters, args) + if ok, err := s.checkControls(ctx, call, args); !ok { + return nil, err + } remoteBlessings := call.Security().RemoteBlessings() ticketCtx := ticket.NewTicketContext(ctx, s.awsSession, remoteBlessings) switch t := s.ticket.(type) { case ticket.TicketAwsTicket: - return t.Build(ticketCtx) + return t.Build(ticketCtx, parameters) case ticket.TicketS3Ticket: - return t.Build(ticketCtx) + return t.Build(ticketCtx, parameters) + case ticket.TicketSshCertificateTicket: + return t.Build(ticketCtx, parameters) case ticket.TicketEcrTicket: - return t.Build(ticketCtx) + return t.Build(ticketCtx, parameters) case ticket.TicketTlsServerTicket: - return t.Build(ticketCtx) + return t.Build(ticketCtx, parameters) case ticket.TicketTlsClientTicket: - return t.Build(ticketCtx) + return t.Build(ticketCtx, parameters) case ticket.TicketDockerTicket: - return t.Build(ticketCtx) + return t.Build(ticketCtx, parameters) case ticket.TicketDockerServerTicket: - return t.Build(ticketCtx) + return t.Build(ticketCtx, parameters) case ticket.TicketDockerClientTicket: - return t.Build(ticketCtx) + return t.Build(ticketCtx, parameters) case ticket.TicketB2Ticket: - return t.Build(ticketCtx) + return t.Build(ticketCtx, parameters) case ticket.TicketVanadiumTicket: - return t.Build(ticketCtx) + return t.Build(ticketCtx, parameters) case ticket.TicketGenericTicket: - return t.Build(ticketCtx) + return t.Build(ticketCtx, parameters) } return nil, fmt.Errorf("not implemented") } +func (s *service) Get(ctx *context.T, call rpc.ServerCall) (ticket.Ticket, error) { + log.Info(ctx, "get request", "blessing", call.Security().RemoteBlessings(), "ticket", call.Suffix()) + return s.get(ctx, call, nil, nil) +} + +func (s *service) GetWithArgs(ctx *context.T, call rpc.ServerCall, args map[string]string) (ticket.Ticket, error) { + 
return s.get(ctx, call, nil, args) +} + +func (s *service) GetWithParameters(ctx *context.T, call rpc.ServerCall, parameters []ticket.Parameter) (ticket.Ticket, error) { + return s.get(ctx, call, parameters, nil) +} + // GetPermissions implements the Object interface from // v.io/v23/services/permissions. func (s *service) GetPermissions(ctx *context.T, call rpc.ServerCall) (perms access.Permissions, version string, _ error) { @@ -67,3 +103,12 @@ func (s *service) GetPermissions(ctx *context.T, call rpc.ServerCall) (perms acc func (s *service) SetPermissions(ctx *context.T, call rpc.ServerCall, perms access.Permissions, version string) error { return fmt.Errorf("not implemented") } + +func (s *service) checkControls(_ *context.T, _ rpc.ServerCall, args map[string]string) (bool, error) { + for control, required := range s.controls { + if required && args[control.String()] == "" { + return false, errors.New("missing required argument: " + control.String()) + } + } + return true, nil +} diff --git a/cmd/ticket-server/service_test.go b/cmd/ticket-server/service_test.go new file mode 100644 index 00000000..3a63571e --- /dev/null +++ b/cmd/ticket-server/service_test.go @@ -0,0 +1,55 @@ +package main + +import ( + "testing" + + "github.com/grailbio/base/security/ticket" +) + +func TestCheckControls(t *testing.T) { + cases := []struct { + s service + args map[string]string + wantOk bool + wantErrMsg string + }{ + { + service{}, + map[string]string{}, + true, + "", + }, + { + service{ + controls: map[ticket.Control]bool{ + ticket.ControlRationale: true, + }, + }, + map[string]string{ + "Rationale": "rationale", + }, + true, + "", + }, + { + service{ + controls: map[ticket.Control]bool{ + ticket.ControlRationale: true, + }, + }, + map[string]string{}, + false, + "missing required argument: Rationale", + }, + } + for _, c := range cases { + ok, err := c.s.checkControls(nil, nil, c.args) + if ok != c.wantOk { + t.Errorf("unexpected ok value: got: %t, want: %t", ok, c.wantOk) + 
} + if c.wantErrMsg != "" && err.Error() != c.wantErrMsg { + t.Errorf("unexpected err value: got: %s, want: %s", err.Error(), c.wantErrMsg) + } + } + +} diff --git a/cmd/ticket-server/support.go b/cmd/ticket-server/support.go new file mode 100644 index 00000000..7804554d --- /dev/null +++ b/cmd/ticket-server/support.go @@ -0,0 +1,39 @@ +// Copyright 2020 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache-2.0 +// license that can be found in the LICENSE file. + +package main + +import ( + "strings" +) + +// Return true if a string matches a value in a list +func stringInSlice(haystack []string, needle string) bool { + for _, s := range haystack { + if needle == s { + return true + } + } + return false +} + +// Return a new list after applying a function to the provided list +func fmap(stringList []string, f func(string) string) []string { + resultList := make([]string, len(stringList)) + for i, v := range stringList { + resultList[i] = f(v) + } + return resultList +} + +// Returns the domain part of an email, or "" if it did not split correctly +func emailDomain(email string) string { + components := strings.Split(email, "@") + // Email should have 2 parts. 
+ if len(components) != 2 { + return "" + } else { + return components[1] // domain part + } +} diff --git a/cmd/ticket-server/support_test.go b/cmd/ticket-server/support_test.go new file mode 100644 index 00000000..6ddc2a06 --- /dev/null +++ b/cmd/ticket-server/support_test.go @@ -0,0 +1,60 @@ +package main + +import ( + "reflect" + "strings" + "testing" +) + +func TestStringInSlice(t *testing.T) { + cases := []struct { + haystack []string + match string + result bool + }{ + {[]string{"a", "b", "c"}, "a", true}, + {[]string{"a", "b", "c"}, "d", false}, + } + + for _, c := range cases { + got, want := stringInSlice(c.haystack, c.match), c.result + if got != want { + t.Errorf("stringInSlice(%+v, %q): got %t, want %t", c.haystack, c.match, got, want) + } + } +} + +func TestFmap(t *testing.T) { + cases := []struct { + stringList []string + f func(string) string + result []string + }{ + {[]string{"a", "b", "c"}, strings.ToUpper, []string{"A", "B", "C"}}, + } + + for _, c := range cases { + got, want := fmap(c.stringList, c.f), c.result + if !reflect.DeepEqual(got, want) { + t.Errorf("Map(%+v, ...): got %+v, want %+v", c.stringList, got, want) + } + } +} + +func TestEmailDomain(t *testing.T) { + cases := []struct { + email string + domain string + }{ + {"aeiser@grailbio.com", "grailbio.com"}, + {"aeisergrailbio.com", ""}, + {"aeiser@grail@bio.com", ""}, + } + + for _, c := range cases { + got, want := emailDomain(c.email), c.domain + if got != want { + t.Errorf("emailDomain(%q): got %q, want %q", c.email, got, want) + } + } +} diff --git a/cmd/ticket-server/testutil/BUILD.bazel b/cmd/ticket-server/testutil/BUILD.bazel new file mode 100644 index 00000000..42bf4c6a --- /dev/null +++ b/cmd/ticket-server/testutil/BUILD.bazel @@ -0,0 +1,16 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library") + +go_library( + name = "go_default_library", + srcs = ["util.go"], + importpath = "github.com/grailbio/base/cmd/ticket-server/testutil", + visibility = ["//visibility:public"], + 
deps = [ + "@com_github_stretchr_testify//assert:go_default_library", + "@io_v//v23:go_default_library", + "@io_v//v23/context:go_default_library", + "@io_v//v23/naming:go_default_library", + "@io_v//v23/rpc:go_default_library", + "@io_v//v23/security:go_default_library", + ], +) diff --git a/cmd/ticket-server/testutil/README.md b/cmd/ticket-server/testutil/README.md new file mode 100644 index 00000000..480e97d2 --- /dev/null +++ b/cmd/ticket-server/testutil/README.md @@ -0,0 +1,13 @@ +# Ticket-Server/TestUtil + +The following functions and future functionality appeared to be needed across multiple packages and in multiple tests and relevant to v23 functionality. I opted to create a smaller package for them to be imported in a slim, DRY fashion. + +The general use of these functions is to mock and test ticket-server functionality when verifying blessings and other v23 artifacts. + +## func RunBlesserServer + +Creates a generic ticket server to use in a test with all permissions. + +## func RunAndCapture + +Runs a command and captures the output, somewhat specifically designed for running and capturing grail-access output against a mock ticket server. 
\ No newline at end of file diff --git a/cmd/ticket-server/testutil/util.go b/cmd/ticket-server/testutil/util.go new file mode 100644 index 00000000..401570a3 --- /dev/null +++ b/cmd/ticket-server/testutil/util.go @@ -0,0 +1,39 @@ +package testutil + +import ( + "fmt" + "os/exec" + "strings" + "testing" + + assert "github.com/stretchr/testify/assert" + + v23 "v.io/v23" + v23context "v.io/v23/context" + "v.io/v23/naming" + "v.io/v23/rpc" + "v.io/v23/security" +) + +// RunBlesserServer runs a test v23 server, returns a context and an endpoint +func RunBlesserServer(ctx *v23context.T, t *testing.T, stub interface{}) (*v23context.T, naming.Endpoint) { + ctx = v23.WithListenSpec(ctx, rpc.ListenSpec{ + Addrs: rpc.ListenAddrs{{"tcp", "localhost:0"}}, + }) + ctx, blesserServer, err := v23.WithNewServer(ctx, "", stub, security.AllowEveryone()) + assert.NoError(t, err) + blesserEndpoints := blesserServer.Status().Endpoints + assert.Equal(t, 1, len(blesserEndpoints)) + return ctx, blesserEndpoints[0] +} + +// RunAndCapture runs a command and captures the stdout and stderr +func RunAndCapture(t *testing.T, cmd *exec.Cmd) (stdout, stderr string) { + var stdoutBuf, stderrBuf strings.Builder + cmd.Stdout = &stdoutBuf + cmd.Stderr = &stderrBuf + err := cmd.Run() + stdout, stderr = stdoutBuf.String(), stderrBuf.String() + assert.NoError(t, err, fmt.Sprintf("stdout: '%s', stderr: '%s'", stdout, stderr)) + return +} diff --git a/cmdutil/access.go b/cmdutil/access.go index df6fc3d4..98231497 100644 --- a/cmdutil/access.go +++ b/cmdutil/access.go @@ -10,8 +10,7 @@ import ( "os" "time" - "github.com/grailbio/base/vcontext" - "v.io/v23" + v23 "v.io/v23" "v.io/v23/context" "v.io/x/lib/cmdline" ) @@ -23,9 +22,9 @@ func WriteBlessings(ctx *context.T, out io.Writer) { principal := v23.GetPrincipal(ctx) fmt.Fprintf(out, "Public key: %s\n", principal.PublicKey()) fmt.Fprintf(out, "---------------- BlessingStore ----------------") - fmt.Fprintf(out, principal.BlessingStore().DebugString()) + 
fmt.Fprint(out, principal.BlessingStore().DebugString()) fmt.Fprintf(out, "---------------- BlessingRoots ----------------") - fmt.Fprintf(out, principal.Roots().DebugString()) + fmt.Fprint(out, principal.Roots().DebugString()) } // CheckAccess checkes that the current process has credentials that @@ -54,12 +53,13 @@ func CheckAccess(ctx *context.T) (time.Duration, error) { type runner struct { access bool + ctxfn func() *context.T run func(*context.T, *cmdline.Env, []string) error } // Run implements cmdline.Runner. func (r runner) Run(env *cmdline.Env, args []string) error { - ctx := vcontext.Background() + ctx := r.ctxfn() if os.Getenv("GRAIL_CMDUTIL_NO_ACCESS_CHECK") != "1" && r.access { if _, err := CheckAccess(ctx); err != nil { return err @@ -68,15 +68,14 @@ func (r runner) Run(env *cmdline.Env, args []string) error { return r.run(ctx, env, args) } -// RunnerFuncWithAccessCheck is like cmdutil.RunnerFunc, but allows for -// a context.T parameter and calls CheckAccess to test for credential -// existence/expiry. -func RunnerFuncWithAccessCheck(run func(*context.T, *cmdline.Env, []string) error) cmdline.Runner { - return RunnerFunc(runner{true, run}.Run) +// V23RunnerFunc is like cmdutil.RunnerFunc, but allows for a context.T +// parameter that is given the context as obtained from ctxfn. +func V23RunnerFunc(ctxfn func() *context.T, run func(*context.T, *cmdline.Env, []string) error) cmdline.Runner { + return RunnerFunc(runner{false, ctxfn, run}.Run) } -// V23RunnerFunc is like cmdutil.RunnerFunc, but allows for a context.T -// parameter. -func V23RunnerFunc(run func(*context.T, *cmdline.Env, []string) error) cmdline.Runner { - return RunnerFunc(runner{false, run}.Run) +// RunnerFuncWithAccessCheck is like V23RunnerFunc, but also calls CheckAccess +// to test for credential existence/expiry. 
+func RunnerFuncWithAccessCheck(ctxfn func() *context.T, run func(*context.T, *cmdline.Env, []string) error) cmdline.Runner { + return RunnerFunc(runner{true, ctxfn, run}.Run) } diff --git a/cmdutil/access_test.go b/cmdutil/access_test.go index 72d12297..6cd35fa5 100644 --- a/cmdutil/access_test.go +++ b/cmdutil/access_test.go @@ -2,6 +2,9 @@ // Use of this source code is governed by the Apache-2.0 // license that can be found in the LICENSE file. +//go:build !unit +// +build !unit + package cmdutil_test import ( @@ -13,11 +16,12 @@ import ( "github.com/grailbio/base/cmdutil" "github.com/grailbio/testutil" "v.io/x/ref/lib/security" + _ "v.io/x/ref/runtime/factories/library" "v.io/x/ref/test" ) func TestCheckAccess(t *testing.T) { - flag.Set("v23.credentials", "") + _ = flag.Set("v23.credentials", "") dir, cleanup := testutil.TempDir(t, "", "check-access") defer cleanup() cdir := filepath.Join(dir, "creds") diff --git a/cmdutil/cmdline-test/cmdline-test.go b/cmdutil/cmdline-test/cmdline-test.go index 65ce05c5..282c92fe 100644 --- a/cmdutil/cmdline-test/cmdline-test.go +++ b/cmdutil/cmdline-test/cmdline-test.go @@ -9,6 +9,7 @@ import ( "github.com/grailbio/base/cmdutil" _ "github.com/grailbio/base/cmdutil/interactive" + "github.com/grailbio/base/vcontext" "v.io/v23/context" "v.io/x/lib/cmdline" "v.io/x/lib/vlog" @@ -31,7 +32,7 @@ var logging = &cmdline.Command{ var access = &cmdline.Command{ Name: "access", ArgsName: "args", - Runner: cmdutil.RunnerFuncWithAccessCheck(runnerWithRPC), + Runner: cmdutil.RunnerFuncWithAccessCheck(vcontext.Background, runnerWithRPC), } func main() { @@ -43,7 +44,6 @@ func runner(_ *cmdline.Env, args []string) error { vlog.Infof("-----") for i, a := range args { vlog.Infof("T: %d: %v", i, a) - vlog.VI(2).Infof("V2: %d: %v", i, a) } return nil } @@ -53,7 +53,6 @@ func runnerWithRPC(ctx *context.T, _ *cmdline.Env, args []string) error { vlog.Infof("-----") for i, a := range args { vlog.Infof("T: %d: %v", i, a) - vlog.VI(2).Infof("V2: %d: 
%v", i, a) } return nil } diff --git a/cmdutil/logging_test.go b/cmdutil/logging_test.go index 1be933bf..5085e97d 100644 --- a/cmdutil/logging_test.go +++ b/cmdutil/logging_test.go @@ -2,6 +2,8 @@ // Use of this source code is governed by the Apache-2.0 // license that can be found in the LICENSE file. +//+build !unit + package cmdutil_test import ( @@ -36,8 +38,6 @@ func testLogging(t *testing.T, sh *gosh.Shell, tempdir, logger, naked string) { }{ {logger, []string{"logging"}, "T"}, {logger, []string{"access"}, "T"}, - {logger, []string{"-v=2", "logging"}, "V2"}, - {logger, []string{"-v=2", "access"}, "V2"}, {naked, nil, "T"}, } { args := append(tc.args, "--log_dir="+tempdir, "a", "b") diff --git a/cmdutil/naked-test/naked-test.go b/cmdutil/naked-test/naked-test.go index 188b027e..72b7d452 100644 --- a/cmdutil/naked-test/naked-test.go +++ b/cmdutil/naked-test/naked-test.go @@ -8,8 +8,8 @@ import ( "flag" "fmt" - "github.com/grailbio/base/grail" _ "github.com/grailbio/base/cmdutil/interactive" + "github.com/grailbio/base/grail" "v.io/x/lib/vlog" ) diff --git a/cmdutil/network_flags.go b/cmdutil/network_flags.go new file mode 100644 index 00000000..35d8437f --- /dev/null +++ b/cmdutil/network_flags.go @@ -0,0 +1,36 @@ +package cmdutil + +import "net" + +// NetworkAddressFlag represents a network address in host:port format. 
+type NetworkAddressFlag struct { + Address string + Host string + Port string + Specified bool +} + +// Set implements flag.Value.Set +func (na *NetworkAddressFlag) Set(v string) error { + host, port, err := net.SplitHostPort(v) + if err != nil { + na.Host = v + na.Port = "0" + } else { + na.Host = host + na.Port = port + } + na.Address = v + na.Specified = true + return nil +} + +// String implements flag.Value.String +func (na *NetworkAddressFlag) String() string { + return na.Address +} + +// Get implements flag.Value.Get +func (na *NetworkAddressFlag) Get() interface{} { + return na.String() +} diff --git a/cmdutil/runner.go b/cmdutil/runner.go index 13515cc9..f7a0d53b 100644 --- a/cmdutil/runner.go +++ b/cmdutil/runner.go @@ -7,8 +7,8 @@ package cmdutil import ( "sync" - "github.com/grailbio/base/grail" "github.com/grailbio/base/pprof" + "github.com/grailbio/base/shutdown" "v.io/x/lib/cmdline" "v.io/x/lib/vlog" ) @@ -23,12 +23,12 @@ type RunnerFunc func(*cmdline.Env, []string) error // at the end. func (f RunnerFunc) Run(env *cmdline.Env, args []string) error { runnerOnce.Do(func() { - vlog.ConfigureLibraryLoggerFromFlags() + _ = vlog.ConfigureLibraryLoggerFromFlags() pprof.Start() }) err := f(env, args) - grail.RunShutdownCallbacks() + shutdown.Run() vlog.FlushLog() pprof.Write(1) return err diff --git a/cmdutil/ticket_flags.go b/cmdutil/ticket_flags.go index 4daa2b46..2c41e3b2 100644 --- a/cmdutil/ticket_flags.go +++ b/cmdutil/ticket_flags.go @@ -22,14 +22,14 @@ import ( // and/or // --flag=t1,t2 --flag=t3 type TicketFlags struct { - set bool - dedup map[string]bool - fs *flag.FlagSet - rcFlag string - Tickets []string - TicketRCFile string - ticketRCFlag stringFlag - Timeout time.Duration + set bool + dedup map[string]bool + fs *flag.FlagSet + ticketFlag, rcFlag string + Tickets []string + TicketRCFile string + ticketRCFlag stringFlag + Timeout time.Duration } // wrapper to catch explicit setting of a flag. 
@@ -57,10 +57,15 @@ func (sf *stringFlag) String() string { // Set implements flag.Value. func (tf *TicketFlags) Set(v string) error { - if tf.dedup == nil { + if !tf.set { + // Clear any defaults if setting for the first time. + tf.Tickets = nil tf.dedup = map[string]bool{} } for _, ps := range strings.Split(v, ",") { + if ps == "" { + continue + } if !tf.dedup[ps] { tf.Tickets = append(tf.Tickets, ps) } @@ -70,6 +75,17 @@ func (tf *TicketFlags) Set(v string) error { return nil } +// setDefaults sets default ticket paths for the flag. These values are cleared +// the first time the flag is explicitly parsed in the flag set. +func (tf *TicketFlags) setDefaults(tickets []string) { + tf.Tickets = tickets + tf.dedup = map[string]bool{} + for _, t := range tickets { + tf.dedup[t] = true + } + tf.fs.Lookup(tf.ticketFlag).DefValue = strings.Join(tickets, ",") +} + // String implements flag.Value. func (tf *TicketFlags) String() string { return strings.Join(tf.Tickets, ",") @@ -99,7 +115,7 @@ func (tf *TicketFlags) ReadEnvOrFile() error { sc := bufio.NewScanner(f) for sc.Scan() { if l := strings.TrimSpace(sc.Text()); len(l) > 0 { - tf.Set(l) + _ = tf.Set(l) } } return sc.Err() @@ -109,8 +125,9 @@ func (tf *TicketFlags) ReadEnvOrFile() error { // supplied FlagSet. 
The flags are: // --ticket // --ticket-timeout -// --tickerc -func RegisterTicketFlags(fs *flag.FlagSet, prefix string, flags *TicketFlags) { +// --ticketrc +func RegisterTicketFlags(fs *flag.FlagSet, prefix string, defaultTickets []string, flags *TicketFlags) { + flags.fs = fs desc := "Comma separated list of GRAIL security tickets, and/or the flag may be repeated" fs.Var(flags, prefix+"ticket", desc) fs.DurationVar(&flags.Timeout, prefix+"ticket-timeout", time.Minute, "specifies the timeout duration for obtaining any single GRAIL security ticket") @@ -119,5 +136,7 @@ func RegisterTicketFlags(fs *flag.FlagSet, prefix string, flags *TicketFlags) { flags.TicketRCFile = filepath.Join(os.Getenv("HOME"), ".ticketrc") fs.Var(&flags.ticketRCFlag, flags.ticketRCFlag.name, "a file containing the tickets to use") fs.Lookup(prefix + "ticketrc").DefValue = "$HOME/.ticketrc" + flags.ticketFlag = prefix + "ticket" flags.rcFlag = prefix + "ticketrc" + flags.setDefaults(defaultTickets) } diff --git a/cmdutil/version.go b/cmdutil/version.go index 1a266e78..0ed6eb83 100644 --- a/cmdutil/version.go +++ b/cmdutil/version.go @@ -11,7 +11,6 @@ import ( "v.io/v23/context" "v.io/x/lib/cmdline" "v.io/x/ref/lib/v23cmd" - _ "v.io/x/ref/runtime/factories/grail" ) var ( diff --git a/common/log/context.go b/common/log/context.go new file mode 100644 index 00000000..cd146875 --- /dev/null +++ b/common/log/context.go @@ -0,0 +1,33 @@ +package log + +import ( + "context" + + "github.com/gin-gonic/gin" + "github.com/google/uuid" +) + +const ( + // Best practices for avoiding key collisions in context say this shouldn't be a string type. + // Gin forces us to use a string for their context, so choose a name that is unlikely to collide with anything else. + RequestIDContextKey = "grail_logger_request_id" +) + +// WithRequestID sets the uuid value for the RequestIDContextKey key in the context. 
+func WithRequestID(ctx context.Context, requestID uuid.UUID) context.Context { + return context.WithValue(ctx, RequestIDContextKey, requestID) +} + +// WithGinRequestID creates a uuid that is set as a string on the gin Context and as +// a uuid on the regular-flavor Request context that it wraps. The context should +// be passed to the methods in this package to prefix logs with the identifier. +func WithGinRequestID(ctx *gin.Context) { + requuid := uuid.New() + uuidStr := requuid.String() + if _, ok := ctx.Get(RequestIDContextKey); ok { + return // Avoid overwriting the original value in case this middleware is invoked twice + } + ctx.Set(RequestIDContextKey, uuidStr) + // TODO: ideally we'd pass the X-Amzn-Trace-Id header from our ALB, but we're not using ALBs yet. + ctx.Request = ctx.Request.WithContext(WithRequestID(ctx.Request.Context(), requuid)) +} diff --git a/common/log/log.go b/common/log/log.go new file mode 100644 index 00000000..b4397d37 --- /dev/null +++ b/common/log/log.go @@ -0,0 +1,231 @@ +package log + +import ( + "context" + "fmt" + + "go.uber.org/zap/zapcore" + + "go.uber.org/zap" +) + +var logger = NewLogger(Config{Level: DebugLevel}) + +// SetLoggerConfig sets the logging config. +func SetLoggerConfig(config Config) { + logger = NewLogger(config) +} + +// SetLoggerLevel sets the logging level. +func SetLoggerLevel(level zapcore.Level) { + SetLoggerConfig(Config{Level: level}) +} + +// Instantiate a new logger and assign any key-value pair to addedInfo field in logger to log additional +// information specific to service +func GetNewLoggerWithDefaultFields(addedInfo ...interface{}) *Logger { + return NewLoggerWithDefaultFields(Config{Level: DebugLevel}, addedInfo) +} + +// Debug logs a message, the key-value pairs defined in contextFields from ctx, and variadic key-value pairs. +// If ctx is nil, all fields from contextFields will be omitted. +// If ctx does not contain a key in contextFields, that field will be omitted. 
+func Debug(ctx context.Context, msg string, keysAndValues ...interface{}) { + Debugv(ctx, 1, msg, keysAndValues...) +} + +// Debugf uses fmt.Sprintf to log a templated message and the key-value pairs defined in contextFields from ctx. +// If ctx is nil, all fields from contextFields will be omitted. +// If ctx does not contain a key in contextFields, that field will be omitted. +func Debugf(ctx context.Context, fs string, args ...interface{}) { + Debugv(ctx, 1, fmt.Sprintf(fs, args...)) +} + +// Debugv logs a message, the key-value pairs defined in contextFields from ctx, and variadic key-value pairs. +// Caller is skipped by skip. +// If ctx is nil, all fields from contextFields will be omitted. +// If ctx does not contain a key in contextFields, that field will be omitted. +func Debugv(ctx context.Context, skip int, msg string, keysAndValues ...interface{}) { + logger.Debugv(ctx, skip+1, msg, keysAndValues...) +} + +// DebugNoCtx logs a message and variadic key-value pairs. +func DebugNoCtx(msg string, keysAndValues ...interface{}) { + // context.Background() is a singleton and gets initialized once + Debugv(context.Background(), 1, msg, keysAndValues...) +} + +// DebugfNoCtx uses fmt.Sprintf to log a templated message. +func DebugfNoCtx(fs string, args ...interface{}) { + Debugv(context.Background(), 1, fmt.Sprintf(fs, args...)) +} + +// Info logs a message, the key-value pairs defined in contextFields from ctx, and variadic key-value pairs. +// If ctx is nil, all fields from contextFields will be omitted. +// If ctx does not contain a key in contextFields, that field will be omitted. +func Info(ctx context.Context, msg string, keysAndValues ...interface{}) { + Infov(ctx, 1, msg, keysAndValues...) +} + +// Infof uses fmt.Sprintf to log a templated message and the key-value pairs defined in contextFields from ctx. +// If ctx is nil, all fields from contextFields will be omitted. +// If ctx does not contain a key in contextFields, that field will be omitted. 
+func Infof(ctx context.Context, fs string, args ...interface{}) { + Infov(ctx, 1, fmt.Sprintf(fs, args...)) +} + +// Infov logs a message, the key-value pairs defined in contextFields from ctx, and variadic key-value pairs. +// Caller is skipped by skip. +// If ctx is nil, all fields from contextFields will be omitted. +// If ctx does not contain a key in contextFields, that field will be omitted. +func Infov(ctx context.Context, skip int, msg string, keysAndValues ...interface{}) { + logger.Infov(ctx, skip+1, msg, keysAndValues...) +} + +// InfoNoCtx logs a message and variadic key-value pairs. +func InfoNoCtx(msg string, keysAndValues ...interface{}) { + // context.Background() is a singleton and gets initialized once + Infov(context.Background(), 1, msg, keysAndValues...) +} + +// InfofNoCtx uses fmt.Sprintf to log a templated message. +func InfofNoCtx(fs string, args ...interface{}) { + Infov(context.Background(), 1, fmt.Sprintf(fs, args...)) +} + +// Warn logs a message, the key-value pairs defined in contextFields from ctx, and variadic key-value pairs. +// If ctx is nil, all fields from contextFields will be omitted. +// If ctx does not contain a key in contextFields, that field will be omitted. +func Warn(ctx context.Context, msg string, keysAndValues ...interface{}) { + Warnv(ctx, 1, msg, keysAndValues...) +} + +// Warnf uses fmt.Sprintf to log a templated message and the key-value pairs defined in contextFields from ctx. +// If ctx is nil, all fields from contextFields will be omitted. +// If ctx does not contain a key in contextFields, that field will be omitted. +func Warnf(ctx context.Context, fs string, args ...interface{}) { + Warnv(ctx, 1, fmt.Sprintf(fs, args...)) +} + +// Warnv logs a message, the key-value pairs defined in contextFields from ctx, and variadic key-value pairs. +// Caller is skipped by skip. +// If ctx is nil, all fields from contextFields will be omitted. 
+// If ctx does not contain a key in contextFields, that field will be omitted. +func Warnv(ctx context.Context, skip int, msg string, keysAndValues ...interface{}) { + logger.Warnv(ctx, skip+1, msg, keysAndValues...) +} + +// WarnNoCtx logs a message and variadic key-value pairs. +func WarnNoCtx(msg string, keysAndValues ...interface{}) { + // context.Background() is a singleton and gets initialized once + Warnv(context.Background(), 1, msg, keysAndValues...) +} + +// WarnfNoCtx uses fmt.Sprintf to log a templated message. +func WarnfNoCtx(fs string, args ...interface{}) { + Warnv(context.Background(), 1, fmt.Sprintf(fs, args...)) +} + +// Fatal logs a message, the key-value pairs defined in contextFields from ctx, and variadic key-value pairs. +// If ctx is nil, all fields from contextFields will be omitted. +// If ctx does not contain a key in contextFields, that field will be omitted. +func Fatal(ctx context.Context, msg string, keysAndValues ...interface{}) { + Fatalv(ctx, 1, msg, keysAndValues...) +} + +// Fatalf uses fmt.Sprintf to log a templated message and the key-value pairs defined in contextFields from ctx. +// If ctx is nil, all fields from contextFields will be omitted. +// If ctx does not contain a key in contextFields, that field will be omitted. +func Fatalf(ctx context.Context, fs string, args ...interface{}) { + Fatalv(ctx, 1, fmt.Sprintf(fs, args...)) +} + +// Fatalv logs a message, the key-value pairs defined in contextFields from ctx, and variadic key-value pairs. +// Caller is skipped by skip. +// If ctx is nil, all fields from contextFields will be omitted. +// If ctx does not contain a key in contextFields, that field will be omitted. +func Fatalv(ctx context.Context, skip int, msg string, keysAndValues ...interface{}) { + logger.Fatalv(ctx, skip+1, msg, keysAndValues...) +} + +// FatalNoCtx logs a message and variadic key-value pairs. 
+func FatalNoCtx(msg string, keysAndValues ...interface{}) { + // context.Background() is a singleton and gets initialized once + Fatalv(context.Background(), 1, msg, keysAndValues...) +} + +// FatalfNoCtx uses fmt.Sprintf to log a templated message. +func FatalfNoCtx(fs string, args ...interface{}) { + Fatalv(context.Background(), 1, fmt.Sprintf(fs, args...)) +} + +// Error logs a message, the key-value pairs defined in contextFields from ctx, and variadic key-value pairs. +// If ctx is nil, all fields from contextFields will be omitted. +// If ctx does not contain a key in contextFields, that field will be omitted. +func Error(ctx context.Context, msg string, keysAndValues ...interface{}) { + Errorv(ctx, 1, msg, keysAndValues...) +} + +// Errorf uses fmt.Sprintf to log a templated message and the key-value pairs defined in contextFields from ctx. +// If ctx is nil, all fields from contextFields will be omitted. +// If ctx does not contain a key in contextFields, that field will be omitted. +func Errorf(ctx context.Context, fs string, args ...interface{}) { + Errorv(ctx, 1, fmt.Sprintf(fs, args...)) +} + +// Errorv logs a message, the key-value pairs defined in contextFields from ctx, and variadic key-value pairs. +// Caller is skipped by skip. +// If ctx is nil, all fields from contextFields will be omitted. +// If ctx does not contain a key in contextFields, that field will be omitted. +func Errorv(ctx context.Context, skip int, msg string, keysAndValues ...interface{}) { + logger.Errorv(ctx, skip+1, msg, keysAndValues...) +} + +// ErrorNoCtx logs a message and variadic key-value pairs. +func ErrorNoCtx(msg string, keysAndValues ...interface{}) { + // context.Background() is a singleton and gets initialized once + Errorv(context.Background(), 1, msg, keysAndValues...) +} + +// ErrorfNoCtx uses fmt.Sprintf to log a templated message. 
+func ErrorfNoCtx(fs string, args ...interface{}) {
+ Errorv(context.Background(), 1, fmt.Sprintf(fs, args...))
+}
+
+// ErrorAndReturn logs a message, the key-value pairs defined in contextFields from ctx, and variadic key-value pairs.
+// If ctx is nil, all fields from contextFields will be omitted.
+// If ctx does not contain a key in contextFields, that field will be omitted.
+// Returns an error with the given message for convenience
+func ErrorAndReturn(ctx context.Context, msg string, keysAndValues ...interface{}) error {
+ return ErrorvAndReturn(ctx, 1, msg, keysAndValues...)
+}
+
+// ErrorfAndReturn uses fmt.Errorf to construct an error and log its message
+// If ctx is nil, all fields from contextFields will be omitted.
+// If ctx does not contain a key in contextFields, that field will be omitted.
+// Returns the error for convenience
+func ErrorfAndReturn(ctx context.Context, fs string, args ...interface{}) error {
+ err := fmt.Errorf(fs, args...)
+ Errorv(ctx, 1, err.Error())
+ return err
+}
+
+// ErrorvAndReturn logs a message, the key-value pairs defined in contextFields from ctx, and variadic key-value pairs.
+// Caller is skipped by skip.
+// If ctx is nil, all fields from contextFields will be omitted.
+// If ctx does not contain a key in contextFields, that field will be omitted.
+// Returns an error with the given message for convenience
+func ErrorvAndReturn(ctx context.Context, skip int, msg string, keysAndValues ...interface{}) error {
+ return logger.ErrorvAndReturn(ctx, skip+1, msg, keysAndValues...)
+}
+
+// ErrorNoCtxAndReturn logs a message and variadic key-value pairs.
+// Returns an error with the given message for convenience
+func ErrorNoCtxAndReturn(msg string, keysAndValues ...interface{}) error {
+ // context.Background() is a singleton and gets initialized once
+ return ErrorvAndReturn(context.Background(), 1, msg, keysAndValues...)
+} + +func InjectTestLogger(testLogger *zap.SugaredLogger) { + logger = NewLoggerFromCore(testLogger) +} diff --git a/common/log/log_test.go b/common/log/log_test.go new file mode 100644 index 00000000..2258deda --- /dev/null +++ b/common/log/log_test.go @@ -0,0 +1,421 @@ +package log + +import ( + "context" + "os" + "time" + + "github.com/google/uuid" +) + +var ( + requestID uuid.UUID + ctx context.Context + TestConfig = Config{ + OutputPaths: []string{"stdout"}, + Level: DebugLevel, + } +) + +func setup() { + requestID, _ = uuid.Parse("aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee") + ctx = WithRequestID(context.Background(), requestID) + logger = NewLogger(TestConfig) + logger.now = func() time.Time { + return time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC) + } +} + +func ExampleDebug() { + setup() + Debug(ctx, "Hello, world!") + Debug( + ctx, + "Hello, world!", + "foo", "bar", + "abc", 123, + "time", time.Date(2000, 1, 2, 0, 0, 0, 0, time.UTC), + ) + // Output: + // {"level":"debug","msg":"Hello, world!","caller":"log_test.go:31","ts":"2000-01-01T00:00:00.000000000Z","requestID":"aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"} + // {"level":"debug","msg":"Hello, world!","caller":"log_test.go:32","ts":"2000-01-01T00:00:00.000000000Z","requestID":"aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee","foo":"bar","abc":123,"time":"2000-01-02T00:00:00.000000000Z"} +} + +func ExampleInfo() { + setup() + Info(ctx, "Hello, world!") + Info( + ctx, + "Hello, world!", + "foo", "bar", + "abc", 123, + "time", time.Date(2000, 1, 2, 0, 0, 0, 0, time.UTC), + ) + // Output: + // {"level":"info","msg":"Hello, world!","caller":"log_test.go:46","ts":"2000-01-01T00:00:00.000000000Z","requestID":"aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"} + // {"level":"info","msg":"Hello, world!","caller":"log_test.go:47","ts":"2000-01-01T00:00:00.000000000Z","requestID":"aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee","foo":"bar","abc":123,"time":"2000-01-02T00:00:00.000000000Z"} +} + +func ExampleWarn() { + setup() + Warn(ctx, "Hello, world!") + 
Warn( + ctx, + "Hello, world!", + "foo", "bar", + "abc", 123, + "time", time.Date(2000, 1, 2, 0, 0, 0, 0, time.UTC), + ) + // Output: + // {"level":"warn","msg":"Hello, world!","caller":"log_test.go:61","ts":"2000-01-01T00:00:00.000000000Z","requestID":"aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"} + // {"level":"warn","msg":"Hello, world!","caller":"log_test.go:62","ts":"2000-01-01T00:00:00.000000000Z","requestID":"aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee","foo":"bar","abc":123,"time":"2000-01-02T00:00:00.000000000Z"} +} + +func ExampleError() { + setup() + Error(ctx, "Hello, world!") + Error( + ctx, + "Hello, world!", + "foo", "bar", + "abc", 123, + "time", time.Date(2000, 1, 2, 0, 0, 0, 0, time.UTC), + ) + // Output: + // {"level":"error","msg":"Hello, world!","caller":"log_test.go:76","ts":"2000-01-01T00:00:00.000000000Z","requestID":"aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"} + // {"level":"error","msg":"Hello, world!","caller":"log_test.go:77","ts":"2000-01-01T00:00:00.000000000Z","requestID":"aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee","foo":"bar","abc":123,"time":"2000-01-02T00:00:00.000000000Z"} +} + +func ExampleDebugf() { + setup() + Debugf(ctx, "Hello, %s!", "world") + // Output: + // {"level":"debug","msg":"Hello, world!","caller":"log_test.go:91","ts":"2000-01-01T00:00:00.000000000Z","requestID":"aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"} +} + +func ExampleInfof() { + setup() + Infof(ctx, "Hello, %s!", "world") + // Output: + // {"level":"info","msg":"Hello, world!","caller":"log_test.go:98","ts":"2000-01-01T00:00:00.000000000Z","requestID":"aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"} +} + +func ExampleWarnf() { + setup() + Warnf(ctx, "Hello, %s!", "world") + // Output: + // {"level":"warn","msg":"Hello, world!","caller":"log_test.go:105","ts":"2000-01-01T00:00:00.000000000Z","requestID":"aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"} +} + +func ExampleErrorf() { + setup() + Errorf(ctx, "Hello, %s!", "world") + // Output: + // {"level":"error","msg":"Hello, 
world!","caller":"log_test.go:112","ts":"2000-01-01T00:00:00.000000000Z","requestID":"aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"} +} + +func ExampleDebugv() { + setup() + Debugv(ctx, 0, "Hello, world!") + // Output: + // {"level":"debug","msg":"Hello, world!","caller":"log_test.go:119","ts":"2000-01-01T00:00:00.000000000Z","requestID":"aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"} +} + +func ExampleInfov() { + setup() + Infov(ctx, 0, "Hello, world!") + // Output: + // {"level":"info","msg":"Hello, world!","caller":"log_test.go:126","ts":"2000-01-01T00:00:00.000000000Z","requestID":"aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"} +} + +func ExampleWarnv() { + setup() + Warnv(ctx, 0, "Hello, world!") + // Output: + // {"level":"warn","msg":"Hello, world!","caller":"log_test.go:133","ts":"2000-01-01T00:00:00.000000000Z","requestID":"aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"} +} + +func ExampleErrorv() { + setup() + Errorv(ctx, 0, "Hello, world!") + // Output: + // {"level":"error","msg":"Hello, world!","caller":"log_test.go:140","ts":"2000-01-01T00:00:00.000000000Z","requestID":"aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"} +} + +func ExampleDebugNoCtx() { + setup() + DebugNoCtx("Hello, world!") + // Output: + // {"level":"debug","msg":"Hello, world!","caller":"log_test.go:147","ts":"2000-01-01T00:00:00.000000000Z"} +} + +func ExampleInfoNoCtx() { + setup() + InfoNoCtx("Hello, world!") + // Output: + // {"level":"info","msg":"Hello, world!","caller":"log_test.go:154","ts":"2000-01-01T00:00:00.000000000Z"} +} + +func ExampleWarnNoCtx() { + setup() + WarnNoCtx("Hello, world!") + // Output: + // {"level":"warn","msg":"Hello, world!","caller":"log_test.go:161","ts":"2000-01-01T00:00:00.000000000Z"} +} + +func ExampleErrorNoCtx() { + setup() + ErrorNoCtx("Hello, world!") + // Output: + // {"level":"error","msg":"Hello, world!","caller":"log_test.go:168","ts":"2000-01-01T00:00:00.000000000Z"} +} + +func Example_danglingKey() { + setup() + Info(context.Background(), "Hello, world!", "myDanglingKey") + // 
Output: + // {"level":"error","msg":"Ignored key without a value.","caller":"log_test.go:175","ts":"2000-01-01T00:00:00.000000000Z","ignored":"myDanglingKey"} + // {"level":"info","msg":"Hello, world!","caller":"log_test.go:175","ts":"2000-01-01T00:00:00.000000000Z"} +} + +func ExampleDebug_logger() { + setup() + logger.Debug(ctx, "Hello, world!") + logger.Debug( + ctx, + "Hello, world!", + "foo", "bar", + "abc", 123, + "time", time.Date(2000, 1, 2, 0, 0, 0, 0, time.UTC), + ) + // Output: + // {"level":"debug","msg":"Hello, world!","caller":"log_test.go:183","ts":"2000-01-01T00:00:00.000000000Z","requestID":"aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"} + // {"level":"debug","msg":"Hello, world!","caller":"log_test.go:184","ts":"2000-01-01T00:00:00.000000000Z","requestID":"aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee","foo":"bar","abc":123,"time":"2000-01-02T00:00:00.000000000Z"} +} + +func ExampleInfo_logger() { + setup() + logger.Info(ctx, "Hello, world!") + logger.Info( + ctx, + "Hello, world!", + "foo", "bar", + "abc", 123, + "time", time.Date(2000, 1, 2, 0, 0, 0, 0, time.UTC), + ) + // Output: + // {"level":"info","msg":"Hello, world!","caller":"log_test.go:198","ts":"2000-01-01T00:00:00.000000000Z","requestID":"aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"} + // {"level":"info","msg":"Hello, world!","caller":"log_test.go:199","ts":"2000-01-01T00:00:00.000000000Z","requestID":"aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee","foo":"bar","abc":123,"time":"2000-01-02T00:00:00.000000000Z"} +} + +func ExampleWarn_logger() { + setup() + logger.Warn(ctx, "Hello, world!") + logger.Warn( + ctx, + "Hello, world!", + "foo", "bar", + "abc", 123, + "time", time.Date(2000, 1, 2, 0, 0, 0, 0, time.UTC), + ) + // Output: + // {"level":"warn","msg":"Hello, world!","caller":"log_test.go:213","ts":"2000-01-01T00:00:00.000000000Z","requestID":"aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"} + // {"level":"warn","msg":"Hello, 
world!","caller":"log_test.go:214","ts":"2000-01-01T00:00:00.000000000Z","requestID":"aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee","foo":"bar","abc":123,"time":"2000-01-02T00:00:00.000000000Z"} +} + +func ExampleError_logger() { + setup() + logger.Error(ctx, "Hello, world!") + logger.Error( + ctx, + "Hello, world!", + "foo", "bar", + "abc", 123, + "time", time.Date(2000, 1, 2, 0, 0, 0, 0, time.UTC), + ) + // Output: + // {"level":"error","msg":"Hello, world!","caller":"log_test.go:228","ts":"2000-01-01T00:00:00.000000000Z","requestID":"aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"} + // {"level":"error","msg":"Hello, world!","caller":"log_test.go:229","ts":"2000-01-01T00:00:00.000000000Z","requestID":"aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee","foo":"bar","abc":123,"time":"2000-01-02T00:00:00.000000000Z"} +} + +func ExampleDebugf_logger() { + setup() + logger.Debugf(ctx, "Hello, %s!", "world") + // Output: + // {"level":"debug","msg":"Hello, world!","caller":"log_test.go:243","ts":"2000-01-01T00:00:00.000000000Z","requestID":"aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"} +} + +func ExampleInfof_logger() { + setup() + logger.Infof(ctx, "Hello, %s!", "world") + // Output: + // {"level":"info","msg":"Hello, world!","caller":"log_test.go:250","ts":"2000-01-01T00:00:00.000000000Z","requestID":"aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"} +} + +func ExampleWarnf_logger() { + setup() + logger.Warnf(ctx, "Hello, %s!", "world") + // Output: + // {"level":"warn","msg":"Hello, world!","caller":"log_test.go:257","ts":"2000-01-01T00:00:00.000000000Z","requestID":"aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"} +} + +func ExampleErrorf_logger() { + setup() + logger.Errorf(ctx, "Hello, %s!", "world") + // Output: + // {"level":"error","msg":"Hello, world!","caller":"log_test.go:264","ts":"2000-01-01T00:00:00.000000000Z","requestID":"aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"} +} + +func ExampleDebugv_logger() { + setup() + logger.Debugv(ctx, 0, "Hello, world!") + // Output: + // {"level":"debug","msg":"Hello, 
world!","caller":"log_test.go:271","ts":"2000-01-01T00:00:00.000000000Z","requestID":"aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"} +} + +func ExampleInfov_logger() { + setup() + logger.Infov(ctx, 0, "Hello, world!") + // Output: + // {"level":"info","msg":"Hello, world!","caller":"log_test.go:278","ts":"2000-01-01T00:00:00.000000000Z","requestID":"aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"} +} + +func ExampleWarnv_logger() { + setup() + logger.Warnv(ctx, 0, "Hello, world!") + // Output: + // {"level":"warn","msg":"Hello, world!","caller":"log_test.go:285","ts":"2000-01-01T00:00:00.000000000Z","requestID":"aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"} +} + +func ExampleErrorv_logger() { + setup() + logger.Errorv(ctx, 0, "Hello, world!") + // Output: + // {"level":"error","msg":"Hello, world!","caller":"log_test.go:292","ts":"2000-01-01T00:00:00.000000000Z","requestID":"aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"} +} + +func ExampleDebugNoCtx_logger() { + setup() + logger.DebugNoCtx("Hello, world!") + // Output: + // {"level":"debug","msg":"Hello, world!","caller":"log_test.go:299","ts":"2000-01-01T00:00:00.000000000Z"} +} + +func ExampleInfoNoCtx_logger() { + setup() + logger.InfoNoCtx("Hello, world!") + // Output: + // {"level":"info","msg":"Hello, world!","caller":"log_test.go:306","ts":"2000-01-01T00:00:00.000000000Z"} +} + +func ExampleWarnNoCtx_logger() { + setup() + logger.WarnNoCtx("Hello, world!") + // Output: + // {"level":"warn","msg":"Hello, world!","caller":"log_test.go:313","ts":"2000-01-01T00:00:00.000000000Z"} +} + +func ExampleErrorNoCtx_logger() { + setup() + logger.ErrorNoCtx("Hello, world!") + // Output: + // {"level":"error","msg":"Hello, world!","caller":"log_test.go:320","ts":"2000-01-01T00:00:00.000000000Z"} +} + +func Example_level() { + setup() + logger = NewLogger(Config{ + OutputPaths: []string{"stdout"}, + Level: InfoLevel, + }) + logger.now = func() time.Time { + return time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC) + } + Debug(ctx, "Hello, world!") + Info(ctx, 
"Hello, world!") + // Output: + // {"level":"info","msg":"Hello, world!","caller":"log_test.go:335","ts":"2000-01-01T00:00:00.000000000Z","requestID":"aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"} +} + +func Example_envVarLogLevel() { + old := os.Getenv(LOG_LEVEL_ENV_VAR) + os.Setenv(LOG_LEVEL_ENV_VAR, "WARN") + setup() + Info(ctx, "Hello, world!") + Warn(ctx, "Hello, world!") + // Output: + // {"level":"warn","msg":"Hello, world!","caller":"log_test.go:345","ts":"2000-01-01T00:00:00.000000000Z","requestID":"aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"} + os.Setenv(LOG_LEVEL_ENV_VAR, old) +} + +func Example_defaultFields() { + setup() + logger = NewLoggerWithDefaultFields(Config{ + OutputPaths: []string{"stdout"}, + Level: InfoLevel, + }, []interface{}{"foo", "bar"}) + logger.now = func() time.Time { + return time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC) + } + logger.Info(ctx, "Hello, world!") + // Output: + // {"level":"info","msg":"Hello, world!","caller":"log_test.go:360","ts":"2000-01-01T00:00:00.000000000Z","requestID":"aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee","foo":"bar"} +} + +func Example_defaultFieldsDanglingKey() { + setup() + logger = NewLoggerWithDefaultFields(Config{ + OutputPaths: []string{"stdout"}, + Level: InfoLevel, + }, []interface{}{"foo", "bar", "foobar"}) + logger.now = func() time.Time { + return time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC) + } + logger.Info(ctx, "Hello, world!") + // Output: + // {"level":"error","msg":"defaultFields contains a key without a value.","ignored":"foobar"} + // {"level":"info","msg":"Hello, world!","caller":"log_test.go:374","ts":"2000-01-01T00:00:00.000000000Z","requestID":"aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee","foo":"bar"} +} + +func ExampleDebugfNoCtx() { + setup() + DebugfNoCtx("Hello, %s!", "world") + // Output: + // {"level":"debug","msg":"Hello, world!","caller":"log_test.go:382","ts":"2000-01-01T00:00:00.000000000Z"} +} + +func ExampleInfofNoCtx() { + setup() + InfofNoCtx("Hello, %s!", "world") + // Output: + // 
{"level":"info","msg":"Hello, world!","caller":"log_test.go:389","ts":"2000-01-01T00:00:00.000000000Z"} +} + +func ExampleWarnfNoCtx() { + setup() + WarnfNoCtx("Hello, %s!", "world") + // Output: + // {"level":"warn","msg":"Hello, world!","caller":"log_test.go:396","ts":"2000-01-01T00:00:00.000000000Z"} +} + +func ExampleErrorfNoCtx() { + setup() + ErrorfNoCtx("Hello, %s!", "world") + // Output: + // {"level":"error","msg":"Hello, world!","caller":"log_test.go:403","ts":"2000-01-01T00:00:00.000000000Z"} +} + +func ExampleSetLoggerConfig() { + setup() + SetLoggerConfig(Config{ + OutputPaths: TestConfig.OutputPaths, + Level: InfoLevel, + }) + logger.now = func() time.Time { + return time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC) + } + Debug(ctx, "Hello, world!") + Info(ctx, "Goodbye, world!") + // Output: + // {"level":"info","msg":"Goodbye, world!","caller":"log_test.go:418","ts":"2000-01-01T00:00:00.000000000Z","requestID":"aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"} +} diff --git a/common/log/logger.go b/common/log/logger.go new file mode 100644 index 00000000..8bee7350 --- /dev/null +++ b/common/log/logger.go @@ -0,0 +1,407 @@ +package log + +import ( + "context" + "fmt" + "os" + "runtime" + "strings" + "time" + + "github.com/google/uuid" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" + v23 "v.io/v23" + vcontext "v.io/v23/context" +) + +const ( + // DebugLevel logs are typically voluminous. + DebugLevel = zapcore.DebugLevel + // InfoLevel is the default logging priority. + InfoLevel = zapcore.InfoLevel + // WarnLevel logs are more important than Info, but don't need individual human review. + WarnLevel = zapcore.WarnLevel + // ErrorLevel logs are high-priority. + // Applications running smoothly shouldn't generate any error-level logs. + ErrorLevel = zapcore.ErrorLevel + // FatalLevel logs a message, then calls os.Exit(1). + FatalLevel = zapcore.FatalLevel + // RFC3339TrailingNano is RFC3339 format with trailing nanoseconds precision. 
+ RFC3339TrailingNano = "2006-01-02T15:04:05.000000000Z07:00"
+ // LOG_LEVEL_ENV_VAR is the environment variable name used to set logging level.
+ LOG_LEVEL_ENV_VAR = "LOG_LEVEL"
+)
+
+// contextFields is a list of context key-value pairs to be logged.
+// Key is the name of the field.
+// Value is the context key.
+var contextFields = map[string]interface{}{
+ "requestID": RequestIDContextKey,
+}
+
+var logLvls = map[string]zapcore.Level{
+ "debug": DebugLevel,
+ "DEBUG": DebugLevel,
+ "info": InfoLevel,
+ "INFO": InfoLevel,
+ "warn": WarnLevel,
+ "WARN": WarnLevel,
+ "error": ErrorLevel,
+ "ERROR": ErrorLevel,
+ "fatal": FatalLevel,
+ "FATAL": FatalLevel,
+}
+
+type Logger struct {
+ coreLogger *zap.SugaredLogger
+ // Additional information that may be unique to each service (e.g. order UUID for Ensemble orders)
+ defaultFields []interface{}
+ levelToLogger map[zapcore.Level]func(msg string, keysAndValues ...interface{})
+ now func() time.Time
+}
+
+type Config struct {
+ OutputPaths []string
+ // note: setting the environment variable LOG_LEVEL will override Config.Level
+ Level zapcore.Level
+}
+
+func setDefaultLogLevelsMap(logger *Logger) *Logger {
+ logger.levelToLogger = map[zapcore.Level]func(msg string, keysAndValues ...interface{}){
+ DebugLevel: logger.coreLogger.Debugw,
+ InfoLevel: logger.coreLogger.Infow,
+ WarnLevel: logger.coreLogger.Warnw,
+ ErrorLevel: logger.coreLogger.Errorw,
+ FatalLevel: logger.coreLogger.Fatalw,
+ }
+ return logger
+}
+
+func NewLogger(config Config) *Logger {
+ return NewLoggerWithDefaultFields(config, []interface{}{})
+}
+
+// NewLoggerWithDefaultFields creates a new logger instance.
+// defaultFields is a list of key-value pairs to be included in every log message.
+func NewLoggerWithDefaultFields(config Config, defaultFields []interface{}) *Logger { + if len(defaultFields)%2 != 0 { + danglingKey := defaultFields[len(defaultFields)-1] + defaultFields = defaultFields[:len(defaultFields)-1] + errLogger := NewLogger(config) + errLog := []interface{}{ + "ignored", danglingKey, + } + logErr := errLogger.levelToLogger[ErrorLevel] + logErr("defaultFields contains a key without a value.", errLog...) + } + l := Logger{ + coreLogger: mustBuildLogger(config, zap.AddCallerSkip(2)), + defaultFields: defaultFields, + now: time.Now, + } + return setDefaultLogLevelsMap(&l) +} + +// NewLoggerFromCore allows the caller to pass in a zap.SugaredLogger into the logger. +// This allows one to make unit test assertions about logs. +func NewLoggerFromCore(lager *zap.SugaredLogger) *Logger { + l := Logger{ + coreLogger: lager, + now: time.Now, + } + + return setDefaultLogLevelsMap(&l) +} + +func (l *Logger) log(ctx context.Context, level zapcore.Level, callerSkip int, msg string, keysAndValues []interface{}) { + t := l.now() + // Add default fields + keysAndValues = append(keysAndValues, l.defaultFields...) + // If there is a dangling key (i.e. odd length keysAndValues), log an error and then + // drop the dangling key and log original message. + if len(keysAndValues)%2 != 0 { + danglingKey := keysAndValues[len(keysAndValues)-1] + keysAndValues = keysAndValues[:len(keysAndValues)-1] + errLog := withDefaultFields(ctx, callerSkip, t, "ignored", danglingKey) + logErr := l.levelToLogger[ErrorLevel] + logErr("Ignored key without a value.", errLog...) + } + // Add caller and timestamp fields + prefix := withDefaultFields(ctx, callerSkip, t) + // Add context logged fields + if ctx != nil { + for k, v := range contextFields { + if ctxVal := ctx.Value(v); ctxVal != nil { + prefix = append(prefix, k, ctxVal) + } + } + } + keysAndValues = append(prefix, keysAndValues...) 
+ // Log at the appropriate level + logLevel := l.levelToLogger[level] + logLevel(msg, keysAndValues...) +} + +// Debug logs a message, the key-value pairs defined in contextFields from ctx, and variadic key-value pairs. +// If ctx is nil, all fields from contextFields will be omitted. +// If ctx does not contain a key in contextFields, that field will be omitted. +func (l *Logger) Debug(ctx context.Context, msg string, keysAndValues ...interface{}) { + l.Debugv(ctx, 1, msg, keysAndValues...) +} + +// Debugf uses fmt.Sprintf to log a templated message and the key-value pairs defined in contextFields from ctx. +// If ctx is nil, all fields from contextFields will be omitted. +// If ctx does not contain a key in contextFields, that field will be omitted. +func (l *Logger) Debugf(ctx context.Context, fs string, args ...interface{}) { + l.Debugv(ctx, 1, fmt.Sprintf(fs, args...)) +} + +// Debugv logs a message, the key-value pairs defined in contextFields from ctx, and variadic key-value pairs. +// Caller stack field is skipped by skip levels. +// If ctx is nil, all fields from contextFields will be omitted. +// If ctx does not contain a key in contextFields, that field will be omitted. +func (l *Logger) Debugv(ctx context.Context, skip int, msg string, keysAndValues ...interface{}) { + l.log(ctx, DebugLevel, skip, msg, keysAndValues) +} + +// DebugNoCtx logs a message and variadic key-value pairs. +func (l *Logger) DebugNoCtx(msg string, keysAndValues ...interface{}) { + l.Debugv(context.Background(), 1, msg, keysAndValues...) +} + +// Info logs a message, the key-value pairs defined in contextFields from ctx, and variadic key-value pairs. +// If ctx is nil, all fields from contextFields will be omitted. +// If ctx does not contain a key in contextFields, that field will be omitted. +func (l *Logger) Info(ctx context.Context, msg string, keysAndValues ...interface{}) { + l.Infov(ctx, 1, msg, keysAndValues...) 
+} + +// Infof uses fmt.Sprintf to log a templated message and the key-value pairs defined in contextFields from ctx. +// If ctx is nil, all fields from contextFields will be omitted. +// If ctx does not contain a key in contextFields, that field will be omitted. +func (l *Logger) Infof(ctx context.Context, fs string, args ...interface{}) { + l.Infov(ctx, 1, fmt.Sprintf(fs, args...)) +} + +// Infov logs a message, the key-value pairs defined in contextFields from ctx, and variadic key-value pairs. +// Caller stack field is skipped by skip levels. +// If ctx is nil, all fields from contextFields will be omitted. +// If ctx does not contain a key in contextFields, that field will be omitted. +func (l *Logger) Infov(ctx context.Context, skip int, msg string, keysAndValues ...interface{}) { + l.log(ctx, InfoLevel, skip, msg, keysAndValues) +} + +// InfoNoCtx logs a message and variadic key-value pairs. +func (l *Logger) InfoNoCtx(msg string, keysAndValues ...interface{}) { + l.Infov(context.Background(), 1, msg, keysAndValues...) +} + +// Warn logs a message, the key-value pairs defined in contextFields from ctx, and variadic key-value pairs. +// If ctx is nil, all fields from contextFields will be omitted. +// If ctx does not contain a key in contextFields, that field will be omitted. +func (l *Logger) Warn(ctx context.Context, msg string, keysAndValues ...interface{}) { + l.Warnv(ctx, 1, msg, keysAndValues...) +} + +// Warnf uses fmt.Sprintf to log a templated message and the key-value pairs defined in contextFields from ctx. +// If ctx is nil, all fields from contextFields will be omitted. +// If ctx does not contain a key in contextFields, that field will be omitted. +func (l *Logger) Warnf(ctx context.Context, fs string, args ...interface{}) { + l.Warnv(ctx, 1, fmt.Sprintf(fs, args...)) +} + +// Warnv logs a message, the key-value pairs defined in contextFields from ctx, and variadic key-value pairs. +// Caller stack field is skipped by skip levels. 
+// If ctx is nil, all fields from contextFields will be omitted. +// If ctx does not contain a key in contextFields, that field will be omitted. +func (l *Logger) Warnv(ctx context.Context, skip int, msg string, keysAndValues ...interface{}) { + l.log(ctx, WarnLevel, skip, msg, keysAndValues) +} + +// WarnNoCtx logs a message and variadic key-value pairs. +func (l *Logger) WarnNoCtx(msg string, keysAndValues ...interface{}) { + l.Warnv(context.Background(), 1, msg, keysAndValues...) +} + +// Fatal logs a message, the key-value pairs defined in contextFields from ctx, and variadic key-value pairs. +// If ctx is nil, all fields from contextFields will be omitted. +// If ctx does not contain a key in contextFields, that field will be omitted. +func (l *Logger) Fatal(ctx context.Context, msg string, keysAndValues ...interface{}) { + l.Fatalv(ctx, 1, msg, keysAndValues...) +} + +// Fatalf uses fmt.Sprintf to log a templated message and the key-value pairs defined in contextFields from ctx. +// If ctx is nil, all fields from contextFields will be omitted. +// If ctx does not contain a key in contextFields, that field will be omitted. +func (l *Logger) Fatalf(ctx context.Context, fs string, args ...interface{}) { + l.Fatalv(ctx, 1, fmt.Sprintf(fs, args...)) +} + +// Fatalv logs a message, the key-value pairs defined in contextFields from ctx, and variadic key-value pairs. +// Caller stack field is skipped by skip levels. +// If ctx is nil, all fields from contextFields will be omitted. +// If ctx does not contain a key in contextFields, that field will be omitted. +func (l *Logger) Fatalv(ctx context.Context, skip int, msg string, keysAndValues ...interface{}) { + l.log(ctx, FatalLevel, skip, msg, keysAndValues) +} + +// FatalNoCtx logs a message and variadic key-value pairs. +func (l *Logger) FatalNoCtx(msg string, keysAndValues ...interface{}) { + l.Fatalv(context.Background(), 1, msg, keysAndValues...) 
+} + +// Error logs a message, the key-value pairs defined in contextFields from ctx, and variadic key-value pairs. +// If ctx is nil, all fields from contextFields will be omitted. +// If ctx does not contain a key in contextFields, that field will be omitted. +func (l *Logger) Error(ctx context.Context, msg string, keysAndValues ...interface{}) { + l.Errorv(ctx, 1, msg, keysAndValues...) +} + +// Errorf uses fmt.Sprintf to log a templated message and the key-value pairs defined in contextFields from ctx. +// If ctx is nil, all fields from contextFields will be omitted. +// If ctx does not contain a key in contextFields, that field will be omitted. +func (l *Logger) Errorf(ctx context.Context, fs string, args ...interface{}) { + l.Errorv(ctx, 1, fmt.Sprintf(fs, args...)) +} + +// Errorv logs a message, the key-value pairs defined in contextFields from ctx, and variadic key-value pairs. +// Caller stack field is skipped by skip levels. +// If ctx is nil, all fields from contextFields will be omitted. +// If ctx does not contain a key in contextFields, that field will be omitted. +func (l *Logger) Errorv(ctx context.Context, skip int, msg string, keysAndValues ...interface{}) { + l.log(ctx, ErrorLevel, skip, msg, keysAndValues) +} + +// ErrorNoCtx logs a message and variadic key-value pairs. +func (l *Logger) ErrorNoCtx(msg string, keysAndValues ...interface{}) { + l.Errorv(context.Background(), 1, msg, keysAndValues...) +} + +// ErrorAndReturn logs a message, the key-value pairs defined in contextFields from ctx, and variadic key-value pairs. +// If ctx is nil, all fields from contextFields will be omitted. +// If ctx does not contain a key in contextFields, that field will be omitted. +// Returns a new error constructed from the message. +func (l *Logger) ErrorAndReturn(ctx context.Context, msg string, keysAndValues ...interface{}) error { + return l.ErrorvAndReturn(ctx, 1, msg, keysAndValues...) 
+}
+
+// ErrorfAndReturn uses fmt.Errorf to construct an error from the provided arguments.
+// It then logs the error message, along with data from the context.
+// If ctx is nil, all fields from contextFields will be omitted.
+// If ctx does not contain a key in contextFields, that field will be omitted.
+// Returns the error resulting from invoking fmt.Errorf with the provided arguments.
+func (l *Logger) ErrorfAndReturn(ctx context.Context, fs string, args ...interface{}) error {
+ err := fmt.Errorf(fs, args...)
+ l.Errorv(ctx, 1, err.Error())
+ return err
+}
+
+// ErrorvAndReturn logs a message, the key-value pairs defined in contextFields from ctx, and variadic key-value pairs.
+// Caller is skipped by skip.
+// If ctx is nil, all fields from contextFields will be omitted.
+// If ctx does not contain a key in contextFields, that field will be omitted.
+// Returns a new error constructed from the message.
+func (l *Logger) ErrorvAndReturn(ctx context.Context, skip int, msg string, keysAndValues ...interface{}) error {
+ l.Errorv(ctx, skip+1, msg, keysAndValues...)
+ return fmt.Errorf(msg)
+}
+
+// ErrorNoCtxAndReturn logs a message and variadic key-value pairs.
+// Returns a new error constructed from the message.
+func (l *Logger) ErrorNoCtxAndReturn(msg string, keysAndValues ...interface{}) error {
+ // context.Background() is a singleton and gets initialized once
+ return l.ErrorvAndReturn(context.Background(), 1, msg, keysAndValues...)
+}
+
+// rfc3339TrailingNanoTimeEncoder serializes a time.Time to an RFC3339-formatted string
+// with trailing nanosecond precision.
+func rfc3339TrailingNanoTimeEncoder(t time.Time, enc zapcore.PrimitiveArrayEncoder) {
+ enc.AppendString(t.Format(RFC3339TrailingNano))
+}
+
+func mustBuildLogger(config Config, opts ...zap.Option) *zap.SugaredLogger {
+ zapLogger, err := newConfig(config).Build(opts...)
+ if err != nil { + panic(err) + } + return zapLogger.Sugar() +} + +// newEncoderConfig is similar to Zap's NewProductionConfig with a few modifications +// to better fit our needs. +func newEncoderConfig() zapcore.EncoderConfig { + return zapcore.EncoderConfig{ + LevelKey: "level", + NameKey: "logger", + MessageKey: "msg", + LineEnding: zapcore.DefaultLineEnding, + EncodeLevel: zapcore.LowercaseLevelEncoder, + EncodeTime: rfc3339TrailingNanoTimeEncoder, + EncodeDuration: zapcore.SecondsDurationEncoder, + EncodeCaller: zapcore.ShortCallerEncoder, + } +} + +// newConfig is similar to Zap's NewProductionConfig with a few modifications +// to better fit our needs. +func newConfig(override Config) zap.Config { + // Default config + config := zap.Config{ + Level: zap.NewAtomicLevelAt(zap.DebugLevel), + Development: false, + Sampling: &zap.SamplingConfig{ + Initial: 100, + Thereafter: 100, + }, + Encoding: "json", + EncoderConfig: newEncoderConfig(), + OutputPaths: []string{"stderr"}, + ErrorOutputPaths: []string{"stderr"}, + } + // config overrides + if override.OutputPaths != nil { + config.OutputPaths = override.OutputPaths + } + if override.Level != zapcore.DebugLevel { + config.Level = zap.NewAtomicLevelAt(override.Level) + } + // LOG_LEVEL environment variable override + // Note: setting the environment variable LOG_LEVEL will override Config.Level + if logLvl, ok := logLvls[os.Getenv(LOG_LEVEL_ENV_VAR)]; ok { + config.Level = zap.NewAtomicLevelAt(logLvl) + } + return config +} + +func withDefaultFields(ctx context.Context, callerSkip int, t time.Time, + keysAndValues ...interface{}) []interface{} { + defaultFields := []interface{}{ + "caller", getCaller(callerSkip), + "ts", t, + } + if ctx != nil { + if vctx, ok := ctx.(*vcontext.T); ok { + if requestID := v23.GetRequestID(vctx); requestID != uuid.Nil { + defaultFields = append(defaultFields, "v23RequestID", requestID) + } + } + } + return append(defaultFields, keysAndValues...) 
+} + +func getCaller(skip int) string { + skipOffset := 5 + pc := make([]uintptr, 1) + numFrames := runtime.Callers(skip+skipOffset, pc) + if numFrames < 1 { + return "" + } + frame, _ := runtime.CallersFrames(pc).Next() + if frame.PC == 0 { + return "" + } + parts := strings.Split(frame.File, "/") + file := parts[len(parts)-1] + return fmt.Sprintf("%s:%d", file, frame.Line) +} diff --git a/common/log/loginterfaces/logger.go b/common/log/loginterfaces/logger.go new file mode 100644 index 00000000..52f017d0 --- /dev/null +++ b/common/log/loginterfaces/logger.go @@ -0,0 +1,24 @@ +package loginterfaces + +import ( + "context" +) + +type Logger interface { + Debug(ctx context.Context, msg string, keysAndValues ...interface{}) + Debugf(ctx context.Context, fs string, args ...interface{}) + Debugv(ctx context.Context, skip int, msg string, keysAndValues ...interface{}) + DebugNoCtx(msg string, keysAndValues ...interface{}) + Info(ctx context.Context, msg string, keysAndValues ...interface{}) + Infof(ctx context.Context, fs string, args ...interface{}) + Infov(ctx context.Context, skip int, msg string, keysAndValues ...interface{}) + InfoNoCtx(msg string, keysAndValues ...interface{}) + Warn(ctx context.Context, msg string, keysAndValues ...interface{}) + Warnf(ctx context.Context, fs string, args ...interface{}) + Warnv(ctx context.Context, skip int, msg string, keysAndValues ...interface{}) + WarnNoCtx(msg string, keysAndValues ...interface{}) + Error(ctx context.Context, msg string, keysAndValues ...interface{}) + Errorf(ctx context.Context, fs string, args ...interface{}) + Errorv(ctx context.Context, skip int, msg string, keysAndValues ...interface{}) + ErrorNoCtx(msg string, keysAndValues ...interface{}) +} diff --git a/common/log/loginterfaces/mocks/Logger.go b/common/log/loginterfaces/mocks/Logger.go new file mode 100644 index 00000000..aeb7fac8 --- /dev/null +++ b/common/log/loginterfaces/mocks/Logger.go @@ -0,0 +1,142 @@ +// Code generated by mockery v0.0.0-dev. 
DO NOT EDIT. + +package mocks + +import ( + context "context" + + mock "github.com/stretchr/testify/mock" +) + +// Logger is an autogenerated mock type for the Logger type +type Logger struct { + mock.Mock +} + +// Debug provides a mock function with given fields: ctx, msg, keysAndValues +func (_m *Logger) Debug(ctx context.Context, msg string, keysAndValues ...interface{}) { + var _ca []interface{} + _ca = append(_ca, ctx, msg) + _ca = append(_ca, keysAndValues...) + _m.Called(_ca...) +} + +// DebugNoCtx provides a mock function with given fields: msg, keysAndValues +func (_m *Logger) DebugNoCtx(msg string, keysAndValues ...interface{}) { + var _ca []interface{} + _ca = append(_ca, msg) + _ca = append(_ca, keysAndValues...) + _m.Called(_ca...) +} + +// Debugf provides a mock function with given fields: ctx, fs, args +func (_m *Logger) Debugf(ctx context.Context, fs string, args ...interface{}) { + var _ca []interface{} + _ca = append(_ca, ctx, fs) + _ca = append(_ca, args...) + _m.Called(_ca...) +} + +// Debugv provides a mock function with given fields: ctx, skip, msg, keysAndValues +func (_m *Logger) Debugv(ctx context.Context, skip int, msg string, keysAndValues ...interface{}) { + var _ca []interface{} + _ca = append(_ca, ctx, skip, msg) + _ca = append(_ca, keysAndValues...) + _m.Called(_ca...) +} + +// Error provides a mock function with given fields: ctx, msg, keysAndValues +func (_m *Logger) Error(ctx context.Context, msg string, keysAndValues ...interface{}) { + var _ca []interface{} + _ca = append(_ca, ctx, msg) + _ca = append(_ca, keysAndValues...) + _m.Called(_ca...) +} + +// ErrorNoCtx provides a mock function with given fields: msg, keysAndValues +func (_m *Logger) ErrorNoCtx(msg string, keysAndValues ...interface{}) { + var _ca []interface{} + _ca = append(_ca, msg) + _ca = append(_ca, keysAndValues...) + _m.Called(_ca...) 
+} + +// Errorf provides a mock function with given fields: ctx, fs, args +func (_m *Logger) Errorf(ctx context.Context, fs string, args ...interface{}) { + var _ca []interface{} + _ca = append(_ca, ctx, fs) + _ca = append(_ca, args...) + _m.Called(_ca...) +} + +// Errorv provides a mock function with given fields: ctx, skip, msg, keysAndValues +func (_m *Logger) Errorv(ctx context.Context, skip int, msg string, keysAndValues ...interface{}) { + var _ca []interface{} + _ca = append(_ca, ctx, skip, msg) + _ca = append(_ca, keysAndValues...) + _m.Called(_ca...) +} + +// Info provides a mock function with given fields: ctx, msg, keysAndValues +func (_m *Logger) Info(ctx context.Context, msg string, keysAndValues ...interface{}) { + var _ca []interface{} + _ca = append(_ca, ctx, msg) + _ca = append(_ca, keysAndValues...) + _m.Called(_ca...) +} + +// InfoNoCtx provides a mock function with given fields: msg, keysAndValues +func (_m *Logger) InfoNoCtx(msg string, keysAndValues ...interface{}) { + var _ca []interface{} + _ca = append(_ca, msg) + _ca = append(_ca, keysAndValues...) + _m.Called(_ca...) +} + +// Infof provides a mock function with given fields: ctx, fs, args +func (_m *Logger) Infof(ctx context.Context, fs string, args ...interface{}) { + var _ca []interface{} + _ca = append(_ca, ctx, fs) + _ca = append(_ca, args...) + _m.Called(_ca...) +} + +// Infov provides a mock function with given fields: ctx, skip, msg, keysAndValues +func (_m *Logger) Infov(ctx context.Context, skip int, msg string, keysAndValues ...interface{}) { + var _ca []interface{} + _ca = append(_ca, ctx, skip, msg) + _ca = append(_ca, keysAndValues...) + _m.Called(_ca...) +} + +// Warn provides a mock function with given fields: ctx, msg, keysAndValues +func (_m *Logger) Warn(ctx context.Context, msg string, keysAndValues ...interface{}) { + var _ca []interface{} + _ca = append(_ca, ctx, msg) + _ca = append(_ca, keysAndValues...) + _m.Called(_ca...) 
+} + +// WarnNoCtx provides a mock function with given fields: msg, keysAndValues +func (_m *Logger) WarnNoCtx(msg string, keysAndValues ...interface{}) { + var _ca []interface{} + _ca = append(_ca, msg) + _ca = append(_ca, keysAndValues...) + _m.Called(_ca...) +} + +// Warnf provides a mock function with given fields: ctx, fs, args +func (_m *Logger) Warnf(ctx context.Context, fs string, args ...interface{}) { + var _ca []interface{} + _ca = append(_ca, ctx, fs) + _ca = append(_ca, args...) + _m.Called(_ca...) +} + +// Warnv provides a mock function with given fields: ctx, skip, msg, keysAndValues +func (_m *Logger) Warnv(ctx context.Context, skip int, msg string, keysAndValues ...interface{}) { + var _ca []interface{} + _ca = append(_ca, ctx, skip, msg) + _ca = append(_ca, keysAndValues...) + _m.Called(_ca...) +} diff --git a/compress/libdeflate/libdeflate.go b/compress/libdeflate/libdeflate.go index feef97aa..2bb91a86 100644 --- a/compress/libdeflate/libdeflate.go +++ b/compress/libdeflate/libdeflate.go @@ -5,14 +5,12 @@ package libdeflate import ( + "compress/gzip" "encoding/binary" "errors" "fmt" "hash/crc32" "io" - - "github.com/grailbio/base/unsafe" - "github.com/klauspost/compress/gzip" ) // This is a slightly modified version of klauspost/compress/gzip/gzip.go , and @@ -220,7 +218,7 @@ func (z *Writer) Write(p []byte) (int, error) { z.buf = make([]byte, z.bufCap) } else { // No need to zero-reinitialize. 
- unsafe.ExtendBytes(&z.buf, z.bufCap) + z.buf = z.buf[:z.bufCap] } } else if len(z.buf) > z.bufCap { // Likely to be irrelevant, but may as well maintain this invariant @@ -275,8 +273,6 @@ func (z *Writer) Write(p []byte) (int, error) { z.size += uint32(len(p)) z.digest = crc32.Update(z.digest, crc32.IEEETable, p) - // ss, _ := fmt.Printf("libdeflate: %d %d %d %d\n", len(p), z.bufPos, z.bufCap-8, cap(z.buf)) - // panic(ss) n := z.compressor.Compress(z.buf[z.bufPos:z.bufCap-8], p) z.bufPos += n if n == 0 { diff --git a/compress/libdeflate/libdeflate_cgo.go b/compress/libdeflate/libdeflate_cgo.go index f578dc73..e776f83c 100644 --- a/compress/libdeflate/libdeflate_cgo.go +++ b/compress/libdeflate/libdeflate_cgo.go @@ -2,7 +2,8 @@ // Use of this source code is governed by the Apache-2.0 // license that can be found in the LICENSE file. -// +build cgo +//go:build cgo || !arm64 +// +build cgo !arm64 package libdeflate @@ -44,12 +45,39 @@ func (dd *Decompressor) Init() error { // decompressed data is returned on success (it may be smaller than // len(outData)). func (dd *Decompressor) Decompress(outData, inData []byte) (int, error) { + // Tolerate zero-length blocks on input, even though we don't on output. + // (Can be relevant when a BGZF file was formed by raw byte concatenation of + // smaller BGZF files.) + // Note that we can't use the usual *reflect.SliceHeader replacement for + // unsafe.Pointer(&inData[0]): that produces a "cgo argument has Go pointer + // to Go pointer" compile error. 
+ if len(inData) == 0 { + return 0, nil + } var outLen C.size_t errcode := C.libdeflate_deflate_decompress( dd.cobj, unsafe.Pointer(&inData[0]), C.size_t(len(inData)), unsafe.Pointer(&outData[0]), C.size_t(len(outData)), &outLen) if errcode != C.LIBDEFLATE_SUCCESS { - return 0, fmt.Errorf("libdeflate: deflate_decompress() error code %d", errcode) + return 0, fmt.Errorf("libdeflate: libdeflate_deflate_decompress() error code %d", errcode) + } + return int(outLen), nil +} + +// GzipDecompress performs gzip decompression on a byte slice. outData[] must +// be large enough to fit the decompressed data. Byte count of the +// decompressed data is returned on success (it may be smaller than +// len(outData)). +func (dd *Decompressor) GzipDecompress(outData, inData []byte) (int, error) { + if len(inData) == 0 { + return 0, nil + } + var outLen C.size_t + errcode := C.libdeflate_gzip_decompress( + dd.cobj, unsafe.Pointer(&inData[0]), C.size_t(len(inData)), + unsafe.Pointer(&outData[0]), C.size_t(len(outData)), &outLen) + if errcode != C.LIBDEFLATE_SUCCESS { + return 0, fmt.Errorf("libdeflate: libdeflate_gzip_decompress() error code %d", errcode) } return int(outLen), nil } @@ -83,8 +111,13 @@ func (cc *Compressor) Init(compressionLevel int) error { // Compress performs raw DEFLATE compression on a byte slice. outData[] must // be large enough to fit the compressed data. Byte count of the compressed -// data is returned on success; zero is currently returned on failure. +// data is returned on success. +// Zero is currently returned on failure. A side effect is that inData cannot +// be length zero; this function will panic or crash if it is. func (cc *Compressor) Compress(outData, inData []byte) int { + // We *want* to crash on length-zero (that implies an error in the calling + // code, we don't want to be writing zero-length BGZF blocks without knowing + // about it), so we intentionally exclude the len(inData) == 0 check. 
outLen := int(C.libdeflate_deflate_compress( cc.cobj, unsafe.Pointer(&inData[0]), C.size_t(len(inData)), unsafe.Pointer(&outData[0]), C.size_t(len(outData)))) diff --git a/compress/libdeflate/libdeflate_nocgo.go b/compress/libdeflate/libdeflate_nocgo.go index 927280f2..18cbf5ff 100644 --- a/compress/libdeflate/libdeflate_nocgo.go +++ b/compress/libdeflate/libdeflate_nocgo.go @@ -2,18 +2,19 @@ // Use of this source code is governed by the Apache-2.0 // license that can be found in the LICENSE file. -// +build !cgo +//go:build !cgo || arm64 +// +build !cgo arm64 package libdeflate -// Fall back on the pure-go klauspost/compress/flate package if cgo support is +// Fall back on the pure-go compress/flate package if cgo support is // unavailable, to make it safe to include this package unconditionally. import ( "bytes" + "compress/flate" + "compress/gzip" "io" - - "github.com/klauspost/compress/flate" ) type Decompressor struct{} @@ -54,6 +55,40 @@ func (dd *Decompressor) Decompress(outData, inData []byte) (int, error) { return n, err } +// GzipDecompress performs gzip decompression on a byte slice. outData[] must +// be large enough to fit the decompressed data. Byte count of the +// decompressed data is returned on success (it may be smaller than +// len(outData)). +func (dd *Decompressor) GzipDecompress(outData, inData []byte) (int, error) { + dataReader := bytes.NewReader(inData) + actualDecompressor, err := gzip.NewReader(dataReader) + if err != nil { + return 0, err + } + // Copy of readToEOF() in github.com/biogo/hts/bgzf/cache.go. 
+ n := 0 + outDataMax := len(outData) + for err == nil && n < outDataMax { + var nn int + nn, err = actualDecompressor.Read(outData[n:]) + n += nn + } + switch { + case err == io.EOF: + return n, nil + case n == outDataMax && err == nil: + var dummy [1]byte + _, err = actualDecompressor.Read(dummy[:]) + if err == nil { + return 0, io.ErrShortBuffer + } + if err == io.EOF { + err = nil + } + } + return n, err +} + func (dd *Decompressor) Cleanup() { } @@ -68,10 +103,15 @@ func (cc *Compressor) Init(compressionLevel int) error { // Compress performs raw DEFLATE compression on a byte slice. outData[] must // be large enough to fit the compressed data. Byte count of the compressed -// data is returned on success; zero is currently returned on failure. +// data is returned on success. +// Zero is currently returned on failure. A side effect is that inData cannot +// be length zero; this function will panic or crash if it is. func (cc *Compressor) Compress(outData, inData []byte) int { // I suspect this currently makes a few unnecessary allocations and copies; // can optimize later. 
+ if len(inData) == 0 { + panic("libdeflate.Compress: zero-length inData") + } var buf bytes.Buffer actualCompressor, err := flate.NewWriter(&buf, cc.clvl) if err != nil { diff --git a/compress/libdeflate/programs/benchmark.c b/compress/libdeflate/programs/benchmark.c deleted file mode 100644 index 51fd5b9e..00000000 --- a/compress/libdeflate/programs/benchmark.c +++ /dev/null @@ -1,634 +0,0 @@ -/* - * benchmark.c - a compression testing and benchmark program - * - * Copyright 2016 Eric Biggers - * - * Permission is hereby granted, free of charge, to any person - * obtaining a copy of this software and associated documentation - * files (the "Software"), to deal in the Software without - * restriction, including without limitation the rights to use, - * copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the - * Software is furnished to do so, subject to the following - * conditions: - * - * The above copyright notice and this permission notice shall be - * included in all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, - * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES - * OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND - * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT - * HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, - * WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING - * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR - * OTHER DEALINGS IN THE SOFTWARE. 
- */ - -#include /* for comparison purposes */ - -#include "prog_util.h" - -static const tchar *const optstring = T("1::2::3::4::5::6::7::8::9::C:D:ghs:VYZz"); - -enum wrapper { - NO_WRAPPER, - ZLIB_WRAPPER, - GZIP_WRAPPER, -}; - -struct compressor { - int level; - enum wrapper wrapper; - const struct engine *engine; - void *private; -}; - -struct decompressor { - enum wrapper wrapper; - const struct engine *engine; - void *private; -}; - -struct engine { - const tchar *name; - - bool (*init_compressor)(struct compressor *); - size_t (*compress)(struct compressor *, const void *, size_t, - void *, size_t); - void (*destroy_compressor)(struct compressor *); - - bool (*init_decompressor)(struct decompressor *); - bool (*decompress)(struct decompressor *, const void *, size_t, - void *, size_t); - void (*destroy_decompressor)(struct decompressor *); -}; - -/******************************************************************************/ - -static bool -libdeflate_engine_init_compressor(struct compressor *c) -{ - c->private = alloc_compressor(c->level); - return c->private != NULL; -} - -static size_t -libdeflate_engine_compress(struct compressor *c, const void *in, - size_t in_nbytes, void *out, size_t out_nbytes_avail) -{ - switch (c->wrapper) { - case ZLIB_WRAPPER: - return libdeflate_zlib_compress(c->private, in, in_nbytes, - out, out_nbytes_avail); - case GZIP_WRAPPER: - return libdeflate_gzip_compress(c->private, in, in_nbytes, - out, out_nbytes_avail); - default: - return libdeflate_deflate_compress(c->private, in, in_nbytes, - out, out_nbytes_avail); - } -} - -static void -libdeflate_engine_destroy_compressor(struct compressor *c) -{ - libdeflate_free_compressor(c->private); -} - -static bool -libdeflate_engine_init_decompressor(struct decompressor *d) -{ - d->private = alloc_decompressor(); - return d->private != NULL; -} - -static bool -libdeflate_engine_decompress(struct decompressor *d, const void *in, - size_t in_nbytes, void *out, size_t out_nbytes) -{ - 
switch (d->wrapper) { - case ZLIB_WRAPPER: - return !libdeflate_zlib_decompress(d->private, in, in_nbytes, - out, out_nbytes, NULL); - case GZIP_WRAPPER: - return !libdeflate_gzip_decompress(d->private, in, in_nbytes, - out, out_nbytes, NULL); - default: - return !libdeflate_deflate_decompress(d->private, in, in_nbytes, - out, out_nbytes, NULL); - } -} - -static void -libdeflate_engine_destroy_decompressor(struct decompressor *d) -{ - libdeflate_free_decompressor(d->private); -} - -static const struct engine libdeflate_engine = { - .name = T("libdeflate"), - - .init_compressor = libdeflate_engine_init_compressor, - .compress = libdeflate_engine_compress, - .destroy_compressor = libdeflate_engine_destroy_compressor, - - .init_decompressor = libdeflate_engine_init_decompressor, - .decompress = libdeflate_engine_decompress, - .destroy_decompressor = libdeflate_engine_destroy_decompressor, -}; - -/******************************************************************************/ - -static int -get_libz_window_bits(enum wrapper wrapper) -{ - const int windowBits = 15; - switch (wrapper) { - case ZLIB_WRAPPER: - return windowBits; - case GZIP_WRAPPER: - return windowBits + 16; - default: - return -windowBits; - } -} - -static bool -libz_engine_init_compressor(struct compressor *c) -{ - z_stream *z; - - if (c->level > 9) { - msg("libz only supports up to compression level 9"); - return false; - } - - z = xmalloc(sizeof(*z)); - if (z == NULL) - return false; - - z->next_in = NULL; - z->avail_in = 0; - z->zalloc = NULL; - z->zfree = NULL; - z->opaque = NULL; - if (deflateInit2(z, c->level, Z_DEFLATED, - get_libz_window_bits(c->wrapper), - 8, Z_DEFAULT_STRATEGY) != Z_OK) - { - msg("unable to initialize deflater"); - free(z); - return false; - } - - c->private = z; - return true; -} - -static size_t -libz_engine_compress(struct compressor *c, const void *in, size_t in_nbytes, - void *out, size_t out_nbytes_avail) -{ - z_stream *z = c->private; - - deflateReset(z); - - z->next_in 
= (void *)in; - z->avail_in = in_nbytes; - z->next_out = out; - z->avail_out = out_nbytes_avail; - - if (deflate(z, Z_FINISH) != Z_STREAM_END) - return 0; - - return out_nbytes_avail - z->avail_out; -} - -static void -libz_engine_destroy_compressor(struct compressor *c) -{ - z_stream *z = c->private; - - deflateEnd(z); - free(z); -} - -static bool -libz_engine_init_decompressor(struct decompressor *d) -{ - z_stream *z; - - z = xmalloc(sizeof(*z)); - if (z == NULL) - return false; - - z->next_in = NULL; - z->avail_in = 0; - z->zalloc = NULL; - z->zfree = NULL; - z->opaque = NULL; - if (inflateInit2(z, get_libz_window_bits(d->wrapper)) != Z_OK) { - msg("unable to initialize inflater"); - free(z); - return false; - } - - d->private = z; - return true; -} - -static bool -libz_engine_decompress(struct decompressor *d, const void *in, size_t in_nbytes, - void *out, size_t out_nbytes) -{ - z_stream *z = d->private; - - inflateReset(z); - - z->next_in = (void *)in; - z->avail_in = in_nbytes; - z->next_out = out; - z->avail_out = out_nbytes; - - return inflate(z, Z_FINISH) == Z_STREAM_END && z->avail_out == 0; -} - -static void -libz_engine_destroy_decompressor(struct decompressor *d) -{ - z_stream *z = d->private; - - inflateEnd(z); - free(z); -} - -static const struct engine libz_engine = { - .name = T("libz"), - - .init_compressor = libz_engine_init_compressor, - .compress = libz_engine_compress, - .destroy_compressor = libz_engine_destroy_compressor, - - .init_decompressor = libz_engine_init_decompressor, - .decompress = libz_engine_decompress, - .destroy_decompressor = libz_engine_destroy_decompressor, -}; - -/******************************************************************************/ - -static const struct engine * const all_engines[] = { - &libdeflate_engine, - &libz_engine, -}; - -#define DEFAULT_ENGINE libdeflate_engine - -static const struct engine * -name_to_engine(const tchar *name) -{ - size_t i; - - for (i = 0; i < ARRAY_LEN(all_engines); i++) - if 
(tstrcmp(all_engines[i]->name, name) == 0) - return all_engines[i]; - return NULL; -} - -/******************************************************************************/ - -static bool -compressor_init(struct compressor *c, int level, enum wrapper wrapper, - const struct engine *engine) -{ - c->level = level; - c->wrapper = wrapper; - c->engine = engine; - return engine->init_compressor(c); -} - -static size_t -do_compress(struct compressor *c, const void *in, size_t in_nbytes, - void *out, size_t out_nbytes_avail) -{ - return c->engine->compress(c, in, in_nbytes, out, out_nbytes_avail); -} - -static void -compressor_destroy(struct compressor *c) -{ - c->engine->destroy_compressor(c); -} - -static bool -decompressor_init(struct decompressor *d, enum wrapper wrapper, - const struct engine *engine) -{ - d->wrapper = wrapper; - d->engine = engine; - return engine->init_decompressor(d); -} - -static bool -do_decompress(struct decompressor *d, const void *in, size_t in_nbytes, - void *out, size_t out_nbytes) -{ - return d->engine->decompress(d, in, in_nbytes, out, out_nbytes); -} - -static void -decompressor_destroy(struct decompressor *d) -{ - d->engine->destroy_decompressor(d); -} - -/******************************************************************************/ - -static void -show_available_engines(FILE *fp) -{ - size_t i; - - fprintf(fp, "Available ENGINEs are: "); - for (i = 0; i < ARRAY_LEN(all_engines); i++) { - fprintf(fp, "%"TS, all_engines[i]->name); - if (i < ARRAY_LEN(all_engines) - 1) - fprintf(fp, ", "); - } - fprintf(fp, ". 
Default is %"TS"\n", DEFAULT_ENGINE.name); -} - -static void -show_usage(FILE *fp) -{ - fprintf(fp, -"Usage: %"TS" [-LVL] [-C ENGINE] [-D ENGINE] [-ghVz] [-s SIZE] [FILE]...\n" -"Benchmark DEFLATE compression and decompression on the specified FILEs.\n" -"\n" -"Options:\n" -" -1 fastest (worst) compression\n" -" -6 medium compression (default)\n" -" -12 slowest (best) compression\n" -" -C ENGINE compression engine\n" -" -D ENGINE decompression engine\n" -" -g use gzip wrapper\n" -" -h print this help\n" -" -s SIZE chunk size\n" -" -V show version and legal information\n" -" -z use zlib wrapper\n" -"\n", program_invocation_name); - - show_available_engines(fp); -} - -static void -show_version(void) -{ - printf( -"libdeflate compression benchmark program v" LIBDEFLATE_VERSION_STRING "\n" -"Copyright 2016 Eric Biggers\n" -"\n" -"This program is free software which may be modified and/or redistributed\n" -"under the terms of the MIT license. There is NO WARRANTY, to the extent\n" -"permitted by law. See the COPYING file for details.\n" - ); -} - - -/******************************************************************************/ - -static int -do_benchmark(struct file_stream *in, void *original_buf, void *compressed_buf, - void *decompressed_buf, u32 chunk_size, - struct compressor *compressor, - struct decompressor *decompressor) -{ - u64 total_uncompressed_size = 0; - u64 total_compressed_size = 0; - u64 total_compress_time = 0; - u64 total_decompress_time = 0; - ssize_t ret; - - while ((ret = xread(in, original_buf, chunk_size)) > 0) { - u32 original_size = ret; - u32 compressed_size; - u64 start_time; - bool ok; - - total_uncompressed_size += original_size; - - /* Compress the chunk of data. */ - start_time = timer_ticks(); - compressed_size = do_compress(compressor, - original_buf, - original_size, - compressed_buf, - original_size - 1); - total_compress_time += timer_ticks() - start_time; - - if (compressed_size) { - /* Successfully compressed the chunk of data. 
*/ - - /* Decompress the data we just compressed and compare - * the result with the original. */ - start_time = timer_ticks(); - ok = do_decompress(decompressor, - compressed_buf, compressed_size, - decompressed_buf, original_size); - total_decompress_time += timer_ticks() - start_time; - - if (!ok) { - msg("%"TS": failed to decompress data", - in->name); - return -1; - } - - if (memcmp(original_buf, decompressed_buf, - original_size) != 0) - { - msg("%"TS": data did not decompress to " - "original", in->name); - return -1; - } - - total_compressed_size += compressed_size; - } else { - /* Compression did not make the chunk smaller. */ - total_compressed_size += original_size; - } - } - - if (ret < 0) - return ret; - - if (total_uncompressed_size == 0) { - printf("\tFile was empty.\n"); - return 0; - } - - if (total_compress_time == 0) - total_compress_time = 1; - if (total_decompress_time == 0) - total_decompress_time = 1; - - printf("\tCompressed %"PRIu64 " => %"PRIu64" bytes (%u.%03u%%)\n", - total_uncompressed_size, total_compressed_size, - (unsigned int)(total_compressed_size * 100 / - total_uncompressed_size), - (unsigned int)(total_compressed_size * 100000 / - total_uncompressed_size % 1000)); - printf("\tCompression time: %"PRIu64" ms (%"PRIu64" MB/s)\n", - timer_ticks_to_ms(total_compress_time), - timer_MB_per_s(total_uncompressed_size, total_compress_time)); - printf("\tDecompression time: %"PRIu64" ms (%"PRIu64" MB/s)\n", - timer_ticks_to_ms(total_decompress_time), - timer_MB_per_s(total_uncompressed_size, total_decompress_time)); - - return 0; -} - -int -tmain(int argc, tchar *argv[]) -{ - u32 chunk_size = 1048576; - int level = 6; - enum wrapper wrapper = NO_WRAPPER; - const struct engine *compress_engine = &DEFAULT_ENGINE; - const struct engine *decompress_engine = &DEFAULT_ENGINE; - void *original_buf = NULL; - void *compressed_buf = NULL; - void *decompressed_buf = NULL; - struct compressor compressor; - struct decompressor decompressor; - tchar 
*default_file_list[] = { NULL }; - int opt_char; - int i; - int ret; - - program_invocation_name = get_filename(argv[0]); - - while ((opt_char = tgetopt(argc, argv, optstring)) != -1) { - switch (opt_char) { - case '1': - case '2': - case '3': - case '4': - case '5': - case '6': - case '7': - case '8': - case '9': - level = parse_compression_level(opt_char, toptarg); - if (level == 0) - return 1; - break; - case 'C': - compress_engine = name_to_engine(toptarg); - if (compress_engine == NULL) { - msg("invalid compression engine: \"%"TS"\"", toptarg); - show_available_engines(stderr); - return 1; - } - break; - case 'D': - decompress_engine = name_to_engine(toptarg); - if (decompress_engine == NULL) { - msg("invalid decompression engine: \"%"TS"\"", toptarg); - show_available_engines(stderr); - return 1; - } - break; - case 'g': - wrapper = GZIP_WRAPPER; - break; - case 'h': - show_usage(stdout); - return 0; - case 's': - chunk_size = tstrtoul(toptarg, NULL, 10); - if (chunk_size == 0) { - msg("invalid chunk size: \"%"TS"\"", toptarg); - return 1; - } - break; - case 'V': - show_version(); - return 0; - case 'Y': /* deprecated, use '-C libz' instead */ - compress_engine = &libz_engine; - break; - case 'Z': /* deprecated, use '-D libz' instead */ - decompress_engine = &libz_engine; - break; - case 'z': - wrapper = ZLIB_WRAPPER; - break; - default: - show_usage(stderr); - return 1; - } - } - - argc -= toptind; - argv += toptind; - - original_buf = xmalloc(chunk_size); - compressed_buf = xmalloc(chunk_size - 1); - decompressed_buf = xmalloc(chunk_size); - - ret = -1; - if (original_buf == NULL || compressed_buf == NULL || - decompressed_buf == NULL) - goto out0; - - if (!compressor_init(&compressor, level, wrapper, compress_engine)) - goto out0; - - if (!decompressor_init(&decompressor, wrapper, decompress_engine)) - goto out1; - - if (argc == 0) { - argv = default_file_list; - argc = ARRAY_LEN(default_file_list); - } else { - for (i = 0; i < argc; i++) - if (argv[i][0] 
== '-' && argv[i][1] == '\0') - argv[i] = NULL; - } - - printf("Benchmarking DEFLATE compression:\n"); - printf("\tCompression level: %d\n", level); - printf("\tChunk size: %"PRIu32"\n", chunk_size); - printf("\tWrapper: %s\n", - wrapper == NO_WRAPPER ? "None" : - wrapper == ZLIB_WRAPPER ? "zlib" : "gzip"); - printf("\tCompression engine: %"TS"\n", compress_engine->name); - printf("\tDecompression engine: %"TS"\n", decompress_engine->name); - - for (i = 0; i < argc; i++) { - struct file_stream in; - - ret = xopen_for_read(argv[i], true, &in); - if (ret != 0) - goto out2; - - printf("Processing %"TS"...\n", in.name); - - ret = do_benchmark(&in, original_buf, compressed_buf, - decompressed_buf, chunk_size, &compressor, - &decompressor); - xclose(&in); - if (ret != 0) - goto out2; - } - ret = 0; -out2: - decompressor_destroy(&decompressor); -out1: - compressor_destroy(&compressor); -out0: - free(decompressed_buf); - free(compressed_buf); - free(original_buf); - return -ret; -} diff --git a/compress/libdeflate/programs/checksum.c b/compress/libdeflate/programs/checksum.c deleted file mode 100644 index 3e52ca42..00000000 --- a/compress/libdeflate/programs/checksum.c +++ /dev/null @@ -1,197 +0,0 @@ -/* - * checksum.c - Adler-32 and CRC-32 checksumming program - * - * Copyright 2016 Eric Biggers - * - * Permission is hereby granted, free of charge, to any person - * obtaining a copy of this software and associated documentation - * files (the "Software"), to deal in the Software without - * restriction, including without limitation the rights to use, - * copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the - * Software is furnished to do so, subject to the following - * conditions: - * - * The above copyright notice and this permission notice shall be - * included in all copies or substantial portions of the Software. 
- * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, - * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES - * OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND - * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT - * HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, - * WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING - * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR - * OTHER DEALINGS IN THE SOFTWARE. - */ - -#include - -#include "prog_util.h" - -static const tchar *const optstring = T("Ahs:tZ"); - -static void -show_usage(FILE *fp) -{ - fprintf(fp, -"Usage: %"TS" [-A] [-h] [-s SIZE] [-t] [-Z] [FILE]...\n" -"Calculate Adler-32 or CRC-32 checksums of the specified FILEs.\n" -"\n" -"Options:\n" -" -A use Adler-32 (default is CRC-32)\n" -" -h print this help\n" -" -s SIZE chunk size\n" -" -t show checksum speed, excluding I/O\n" -" -Z use zlib implementation instead of libdeflate\n", - program_invocation_name); -} - -static u32 -zlib_adler32(u32 adler, const void *buf, size_t len) -{ - return adler32(adler, buf, len); -} - -static u32 -zlib_crc32(u32 crc, const void *buf, size_t len) -{ - return crc32(crc, buf, len); -} - -typedef u32 (*cksum_fn_t)(u32, const void *, size_t); - -static int -checksum_stream(struct file_stream *in, cksum_fn_t cksum, u32 *sum, - void *buf, size_t bufsize, u64 *size_ret, u64 *elapsed_ret) -{ - u64 size = 0; - u64 elapsed = 0; - - for (;;) { - ssize_t ret; - u64 start_time; - - ret = xread(in, buf, bufsize); - if (ret < 0) - return ret; - if (ret == 0) - break; - - size += ret; - start_time = timer_ticks(); - *sum = cksum(*sum, buf, ret); - elapsed += timer_ticks() - start_time; - } - - if (elapsed == 0) - elapsed = 1; - *size_ret = size; - *elapsed_ret = elapsed; - return 0; -} - -int -tmain(int argc, tchar *argv[]) -{ - bool use_adler32 = false; - bool use_zlib_impl = false; - bool do_timing = false; - void *buf; - size_t bufsize = 131072; - tchar 
*default_file_list[] = { NULL }; - cksum_fn_t cksum; - int opt_char; - int i; - int ret; - - program_invocation_name = get_filename(argv[0]); - - while ((opt_char = tgetopt(argc, argv, optstring)) != -1) { - switch (opt_char) { - case 'A': - use_adler32 = true; - break; - case 'h': - show_usage(stdout); - return 0; - case 's': - bufsize = tstrtoul(toptarg, NULL, 10); - if (bufsize == 0) { - msg("invalid chunk size: \"%"TS"\"", toptarg); - return 1; - } - break; - case 't': - do_timing = true; - break; - case 'Z': - use_zlib_impl = true; - break; - default: - show_usage(stderr); - return 1; - } - } - - argc -= toptind; - argv += toptind; - - if (use_adler32) { - if (use_zlib_impl) - cksum = zlib_adler32; - else - cksum = libdeflate_adler32; - } else { - if (use_zlib_impl) - cksum = zlib_crc32; - else - cksum = libdeflate_crc32; - } - - buf = xmalloc(bufsize); - if (buf == NULL) - return 1; - - if (argc == 0) { - argv = default_file_list; - argc = ARRAY_LEN(default_file_list); - } else { - for (i = 0; i < argc; i++) - if (argv[i][0] == '-' && argv[i][1] == '\0') - argv[i] = NULL; - } - - for (i = 0; i < argc; i++) { - struct file_stream in; - u32 sum = cksum(0, NULL, 0); - u64 size = 0; - u64 elapsed = 0; - - ret = xopen_for_read(argv[i], true, &in); - if (ret != 0) - goto out; - - ret = checksum_stream(&in, cksum, &sum, buf, bufsize, - &size, &elapsed); - if (ret == 0) { - if (do_timing) { - printf("%08"PRIx32"\t%"TS"\t" - "%"PRIu64" ms\t%"PRIu64" MB/s\n", - sum, in.name, timer_ticks_to_ms(elapsed), - timer_MB_per_s(size, elapsed)); - } else { - printf("%08"PRIx32"\t%"TS"\t\n", sum, in.name); - } - } - - xclose(&in); - - if (ret != 0) - goto out; - } - ret = 0; -out: - free(buf); - return -ret; -} diff --git a/compress/libdeflate/programs/detect.sh b/compress/libdeflate/programs/detect.sh deleted file mode 100755 index 9139cf68..00000000 --- a/compress/libdeflate/programs/detect.sh +++ /dev/null @@ -1,62 +0,0 @@ -#!/bin/sh - -if [ -z "$CC" ]; then - CC=cc -fi - 
-echo "/* THIS FILE WAS AUTOMATICALLY GENERATED. DO NOT EDIT. */" -echo "#ifndef CONFIG_H" -echo "#define CONFIG_H" - -tmpfile="$(mktemp -t libdeflate_config.XXXXXXXX)" -trap "rm -f \"$tmpfile\"" EXIT - -program_compiles() { - echo "$1" > "$tmpfile" - $CC $CFLAGS -x c "$tmpfile" -o /dev/null > /dev/null 2>&1 -} - -check_function() { - funcname="$1" - macro="HAVE_$(echo $funcname | tr a-z A-Z)" - echo - echo "/* Is the $funcname() function available? */" - if program_compiles "int main() { $funcname(); }"; then - echo "#define $macro 1" - else - echo "/* $macro is not set */" - fi -} - -have_stat_field() { - program_compiles "#include - #include - int main() { struct stat st; st.$1; }" -} - -check_stat_nanosecond_precision() { - echo - echo "/* Does stat() provide nanosecond-precision timestamps? */" - if have_stat_field st_atim; then - echo "#define HAVE_STAT_NANOSECOND_PRECISION 1" - elif have_stat_field st_atimespec; then - # Nonstandard field names used by OS X and older BSDs - echo "#define HAVE_STAT_NANOSECOND_PRECISION 1" - echo "#define st_atim st_atimespec" - echo "#define st_mtim st_mtimespec" - echo "#define st_ctim st_ctimespec" - else - echo "/* HAVE_STAT_NANOSECOND_PRECISION is not set */" - fi -} - -check_function clock_gettime -check_function futimens -check_function futimes -check_function posix_fadvise -check_function posix_madvise - -check_stat_nanosecond_precision - -echo -echo "#endif /* CONFIG_H */" diff --git a/compress/libdeflate/programs/gzip.c b/compress/libdeflate/programs/gzip.c deleted file mode 100644 index a08d4151..00000000 --- a/compress/libdeflate/programs/gzip.c +++ /dev/null @@ -1,632 +0,0 @@ -/* - * gzip.c - a file compression and decompression program - * - * Copyright 2016 Eric Biggers - * - * Permission is hereby granted, free of charge, to any person - * obtaining a copy of this software and associated documentation - * files (the "Software"), to deal in the Software without - * restriction, including without limitation the 
rights to use, - * copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the - * Software is furnished to do so, subject to the following - * conditions: - * - * The above copyright notice and this permission notice shall be - * included in all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, - * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES - * OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND - * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT - * HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, - * WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING - * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR - * OTHER DEALINGS IN THE SOFTWARE. - */ - -#include "prog_util.h" - -#include -#include -#include -#ifdef _WIN32 -# include -#else -# include -# include -# include -#endif - -struct options { - bool to_stdout; - bool decompress; - bool force; - bool keep; - int compression_level; - const tchar *suffix; -}; - -static const tchar *const optstring = T("1::2::3::4::5::6::7::8::9::cdfhknS:V"); - -static void -show_usage(FILE *fp) -{ - fprintf(fp, -"Usage: %"TS" [-LEVEL] [-cdfhkV] [-S SUF] FILE...\n" -"Compress or decompress the specified FILEs.\n" -"\n" -"Options:\n" -" -1 fastest (worst) compression\n" -" -6 medium compression (default)\n" -" -12 slowest (best) compression\n" -" -c write to standard output\n" -" -d decompress\n" -" -f overwrite existing output files\n" -" -h print this help\n" -" -k don't delete input files\n" -" -S SUF use suffix SUF instead of .gz\n" -" -V show version and legal information\n", - program_invocation_name); -} - -static void -show_version(void) -{ - printf( -"gzip compression program v" LIBDEFLATE_VERSION_STRING "\n" -"Copyright 2016 Eric Biggers\n" -"\n" -"This program is free software which may be modified and/or 
redistributed\n" -"under the terms of the MIT license. There is NO WARRANTY, to the extent\n" -"permitted by law. See the COPYING file for details.\n" - ); -} - -/* Was the program invoked in decompression mode? */ -static bool -is_gunzip(void) -{ - if (tstrxcmp(program_invocation_name, T("gunzip")) == 0) - return true; - if (tstrxcmp(program_invocation_name, T("libdeflate-gunzip")) == 0) - return true; -#ifdef _WIN32 - if (tstrxcmp(program_invocation_name, T("gunzip.exe")) == 0) - return true; - if (tstrxcmp(program_invocation_name, T("libdeflate-gunzip.exe")) == 0) - return true; -#endif - return false; -} - -static const tchar * -get_suffix(const tchar *path, const tchar *suffix) -{ - size_t path_len = tstrlen(path); - size_t suffix_len = tstrlen(suffix); - const tchar *p; - - if (path_len <= suffix_len) - return NULL; - p = &path[path_len - suffix_len]; - if (tstrxcmp(p, suffix) == 0) - return p; - return NULL; -} - -static bool -has_suffix(const tchar *path, const tchar *suffix) -{ - return get_suffix(path, suffix) != NULL; -} - -static tchar * -append_suffix(const tchar *path, const tchar *suffix) -{ - size_t path_len = tstrlen(path); - size_t suffix_len = tstrlen(suffix); - tchar *suffixed_path; - - suffixed_path = xmalloc((path_len + suffix_len + 1) * sizeof(tchar)); - if (suffixed_path == NULL) - return NULL; - tmemcpy(suffixed_path, path, path_len); - tmemcpy(&suffixed_path[path_len], suffix, suffix_len + 1); - return suffixed_path; -} - -static int -do_compress(struct libdeflate_compressor *compressor, - struct file_stream *in, struct file_stream *out) -{ - const void *uncompressed_data = in->mmap_mem; - size_t uncompressed_size = in->mmap_size; - void *compressed_data; - size_t actual_compressed_size; - size_t max_compressed_size; - int ret; - - max_compressed_size = libdeflate_gzip_compress_bound(compressor, - uncompressed_size); - compressed_data = xmalloc(max_compressed_size); - if (compressed_data == NULL) { - msg("%"TS": file is probably too large 
to be processed by this " - "program", in->name); - ret = -1; - goto out; - } - - actual_compressed_size = libdeflate_gzip_compress(compressor, - uncompressed_data, - uncompressed_size, - compressed_data, - max_compressed_size); - if (actual_compressed_size == 0) { - msg("Bug in libdeflate_gzip_compress_bound()!"); - ret = -1; - goto out; - } - - ret = full_write(out, compressed_data, actual_compressed_size); -out: - free(compressed_data); - return ret; -} - -static u32 -load_u32_gzip(const u8 *p) -{ - return ((u32)p[0] << 0) | ((u32)p[1] << 8) | - ((u32)p[2] << 16) | ((u32)p[3] << 24); -} - -static int -do_decompress(struct libdeflate_decompressor *decompressor, - struct file_stream *in, struct file_stream *out) -{ - const u8 *compressed_data = in->mmap_mem; - size_t compressed_size = in->mmap_size; - void *uncompressed_data = NULL; - size_t uncompressed_size; - size_t actual_in_nbytes; - size_t actual_out_nbytes; - enum libdeflate_result result; - int ret = 0; - - if (compressed_size < sizeof(u32)) { - msg("%"TS": not in gzip format", in->name); - ret = -1; - goto out; - } - - uncompressed_size = load_u32_gzip(&compressed_data[compressed_size - 4]); - - do { - if (uncompressed_data == NULL) { - uncompressed_data = xmalloc(uncompressed_size); - if (uncompressed_data == NULL) { - msg("%"TS": file is probably too large to be " - "processed by this program", in->name); - ret = -1; - goto out; - } - } - - result = libdeflate_gzip_decompress_ex(decompressor, - compressed_data, - compressed_size, - uncompressed_data, - uncompressed_size, - &actual_in_nbytes, - &actual_out_nbytes); - - if (result == LIBDEFLATE_INSUFFICIENT_SPACE) { - if (uncompressed_size * 2 <= uncompressed_size) { - msg("%"TS": file corrupt or too large to be " - "processed by this program", in->name); - ret = -1; - goto out; - } - uncompressed_size *= 2; - free(uncompressed_data); - uncompressed_data = NULL; - continue; - } - - if (result != LIBDEFLATE_SUCCESS) { - msg("%"TS": file corrupt or not in 
gzip format", - in->name); - ret = -1; - goto out; - } - - if (actual_in_nbytes == 0 || - actual_in_nbytes > compressed_size || - actual_out_nbytes > uncompressed_size) { - msg("Bug in libdeflate_gzip_decompress_ex()!"); - ret = -1; - goto out; - } - - ret = full_write(out, uncompressed_data, actual_out_nbytes); - if (ret != 0) - goto out; - - compressed_data += actual_in_nbytes; - compressed_size -= actual_in_nbytes; - - } while (compressed_size != 0); -out: - free(uncompressed_data); - return ret; -} - -static int -stat_file(struct file_stream *in, stat_t *stbuf, bool allow_hard_links) -{ - if (tfstat(in->fd, stbuf) != 0) { - msg("%"TS": unable to stat file", in->name); - return -1; - } - - if (!S_ISREG(stbuf->st_mode) && !in->is_standard_stream) { - msg("%"TS" is %s -- skipping", - in->name, S_ISDIR(stbuf->st_mode) ? "a directory" : - "not a regular file"); - return -2; - } - - if (stbuf->st_nlink > 1 && !allow_hard_links) { - msg("%"TS" has multiple hard links -- skipping " - "(use -f to process anyway)", in->name); - return -2; - } - - return 0; -} - -static void -restore_mode(struct file_stream *out, const stat_t *stbuf) -{ -#ifndef _WIN32 - if (fchmod(out->fd, stbuf->st_mode) != 0) - msg_errno("%"TS": unable to preserve mode", out->name); -#endif -} - -static void -restore_owner_and_group(struct file_stream *out, const stat_t *stbuf) -{ -#ifndef _WIN32 - if (fchown(out->fd, stbuf->st_uid, stbuf->st_gid) != 0) { - msg_errno("%"TS": unable to preserve owner and group", - out->name); - } -#endif -} - -static void -restore_timestamps(struct file_stream *out, const tchar *newpath, - const stat_t *stbuf) -{ - int ret; -#if defined(HAVE_FUTIMENS) && defined(HAVE_STAT_NANOSECOND_PRECISION) - struct timespec times[2] = { - stbuf->st_atim, stbuf->st_mtim, - }; - ret = futimens(out->fd, times); -#elif defined(HAVE_FUTIMES) && defined(HAVE_STAT_NANOSECOND_PRECISION) - struct timeval times[2] = { - { stbuf->st_atim.tv_sec, stbuf->st_atim.tv_nsec / 1000, }, - { 
stbuf->st_mtim.tv_sec, stbuf->st_mtim.tv_nsec / 1000, }, - }; - ret = futimes(out->fd, times); -#else - struct tutimbuf times = { - stbuf->st_atime, stbuf->st_mtime, - }; - ret = tutime(newpath, ×); -#endif - if (ret != 0) - msg_errno("%"TS": unable to preserve timestamps", out->name); -} - -static void -restore_metadata(struct file_stream *out, const tchar *newpath, - const stat_t *stbuf) -{ - restore_mode(out, stbuf); - restore_owner_and_group(out, stbuf); - restore_timestamps(out, newpath, stbuf); -} - -static int -decompress_file(struct libdeflate_decompressor *decompressor, const tchar *path, - const struct options *options) -{ - tchar *oldpath = (tchar *)path; - tchar *newpath = NULL; - struct file_stream in; - struct file_stream out; - stat_t stbuf; - int ret; - int ret2; - - if (path != NULL) { - const tchar *suffix = get_suffix(path, options->suffix); - if (suffix == NULL) { - /* - * Input file is unsuffixed. If the file doesn't exist, - * then try it suffixed. Otherwise, if we're not - * writing to stdout, skip the file with warning status. - * Otherwise, go ahead and try to open the file anyway - * (which will very likely fail). - */ - if (tstat(path, &stbuf) != 0 && errno == ENOENT) { - oldpath = append_suffix(path, options->suffix); - if (oldpath == NULL) - return -1; - if (!options->to_stdout) - newpath = (tchar *)path; - } else if (!options->to_stdout) { - msg("\"%"TS"\" does not end with the %"TS" " - "suffix -- skipping", - path, options->suffix); - return -2; - } - } else if (!options->to_stdout) { - /* - * Input file is suffixed, and we're not writing to - * stdout. Strip the suffix to get the path to the - * output file. 
- */ - newpath = xmalloc((suffix - oldpath + 1) * - sizeof(tchar)); - if (newpath == NULL) - return -1; - tmemcpy(newpath, oldpath, suffix - oldpath); - newpath[suffix - oldpath] = '\0'; - } - } - - ret = xopen_for_read(oldpath, options->force || options->to_stdout, - &in); - if (ret != 0) - goto out_free_paths; - - if (!options->force && isatty(in.fd)) { - msg("Refusing to read compressed data from terminal. " - "Use -f to override.\nFor help, use -h."); - ret = -1; - goto out_close_in; - } - - ret = stat_file(&in, &stbuf, options->force || options->keep || - oldpath == NULL || newpath == NULL); - if (ret != 0) - goto out_close_in; - - ret = xopen_for_write(newpath, options->force, &out); - if (ret != 0) - goto out_close_in; - - /* TODO: need a streaming-friendly solution */ - ret = map_file_contents(&in, stbuf.st_size); - if (ret != 0) - goto out_close_out; - - ret = do_decompress(decompressor, &in, &out); - if (ret != 0) - goto out_close_out; - - if (oldpath != NULL && newpath != NULL) - restore_metadata(&out, newpath, &stbuf); - ret = 0; -out_close_out: - ret2 = xclose(&out); - if (ret == 0) - ret = ret2; - if (ret != 0 && newpath != NULL) - tunlink(newpath); -out_close_in: - xclose(&in); - if (ret == 0 && oldpath != NULL && newpath != NULL && !options->keep) - tunlink(oldpath); -out_free_paths: - if (newpath != path) - free(newpath); - if (oldpath != path) - free(oldpath); - return ret; -} - -static int -compress_file(struct libdeflate_compressor *compressor, const tchar *path, - const struct options *options) -{ - tchar *newpath = NULL; - struct file_stream in; - struct file_stream out; - stat_t stbuf; - int ret; - int ret2; - - if (path != NULL && !options->to_stdout) { - if (!options->force && has_suffix(path, options->suffix)) { - msg("%"TS": already has %"TS" suffix -- skipping", - path, options->suffix); - return 0; - } - newpath = append_suffix(path, options->suffix); - if (newpath == NULL) - return -1; - } - - ret = xopen_for_read(path, options->force 
|| options->to_stdout, &in); - if (ret != 0) - goto out_free_newpath; - - ret = stat_file(&in, &stbuf, options->force || options->keep || - path == NULL || newpath == NULL); - if (ret != 0) - goto out_close_in; - - ret = xopen_for_write(newpath, options->force, &out); - if (ret != 0) - goto out_close_in; - - if (!options->force && isatty(out.fd)) { - msg("Refusing to write compressed data to terminal. " - "Use -f to override.\nFor help, use -h."); - ret = -1; - goto out_close_out; - } - - /* TODO: need a streaming-friendly solution */ - ret = map_file_contents(&in, stbuf.st_size); - if (ret != 0) - goto out_close_out; - - ret = do_compress(compressor, &in, &out); - if (ret != 0) - goto out_close_out; - - if (path != NULL && newpath != NULL) - restore_metadata(&out, newpath, &stbuf); - ret = 0; -out_close_out: - ret2 = xclose(&out); - if (ret == 0) - ret = ret2; - if (ret != 0 && newpath != NULL) - tunlink(newpath); -out_close_in: - xclose(&in); - if (ret == 0 && path != NULL && newpath != NULL && !options->keep) - tunlink(path); -out_free_newpath: - free(newpath); - return ret; -} - -int -tmain(int argc, tchar *argv[]) -{ - tchar *default_file_list[] = { NULL }; - struct options options; - int opt_char; - int i; - int ret; - - program_invocation_name = get_filename(argv[0]); - - options.to_stdout = false; - options.decompress = is_gunzip(); - options.force = false; - options.keep = false; - options.compression_level = 6; - options.suffix = T(".gz"); - - while ((opt_char = tgetopt(argc, argv, optstring)) != -1) { - switch (opt_char) { - case '1': - case '2': - case '3': - case '4': - case '5': - case '6': - case '7': - case '8': - case '9': - options.compression_level = - parse_compression_level(opt_char, toptarg); - if (options.compression_level == 0) - return 1; - break; - case 'c': - options.to_stdout = true; - break; - case 'd': - options.decompress = true; - break; - case 'f': - options.force = true; - break; - case 'h': - show_usage(stdout); - return 0; - case 
'k': - options.keep = true; - break; - case 'n': - /* - * -n means don't save or restore the original filename - * in the gzip header. Currently this implementation - * already behaves this way by default, so accept the - * option as a no-op. - */ - break; - case 'S': - options.suffix = toptarg; - if (options.suffix[0] == T('\0')) { - msg("invalid suffix"); - return 1; - } - break; - case 'V': - show_version(); - return 0; - default: - show_usage(stderr); - return 1; - } - } - - argv += toptind; - argc -= toptind; - - if (argc == 0) { - argv = default_file_list; - argc = ARRAY_LEN(default_file_list); - } else { - for (i = 0; i < argc; i++) - if (argv[i][0] == '-' && argv[i][1] == '\0') - argv[i] = NULL; - } - - ret = 0; - if (options.decompress) { - struct libdeflate_decompressor *d; - - d = alloc_decompressor(); - if (d == NULL) - return 1; - - for (i = 0; i < argc; i++) - ret |= -decompress_file(d, argv[i], &options); - - libdeflate_free_decompressor(d); - } else { - struct libdeflate_compressor *c; - - c = alloc_compressor(options.compression_level); - if (c == NULL) - return 1; - - for (i = 0; i < argc; i++) - ret |= -compress_file(c, argv[i], &options); - - libdeflate_free_compressor(c); - } - - /* - * If ret=0, there were no warnings or errors. Exit with status 0. - * If ret=2, there was at least one warning. Exit with status 2. - * Else, there was at least one error. Exit with status 1. 
- */ - if (ret != 0 && ret != 2) - ret = 1; - - return ret; -} diff --git a/compress/libdeflate/programs/prog_util.c b/compress/libdeflate/programs/prog_util.c deleted file mode 100644 index 68e9ae34..00000000 --- a/compress/libdeflate/programs/prog_util.c +++ /dev/null @@ -1,530 +0,0 @@ -/* - * prog_util.c - utility functions for programs - * - * Copyright 2016 Eric Biggers - * - * Permission is hereby granted, free of charge, to any person - * obtaining a copy of this software and associated documentation - * files (the "Software"), to deal in the Software without - * restriction, including without limitation the rights to use, - * copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the - * Software is furnished to do so, subject to the following - * conditions: - * - * The above copyright notice and this permission notice shall be - * included in all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, - * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES - * OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND - * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT - * HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, - * WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING - * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR - * OTHER DEALINGS IN THE SOFTWARE. 
- */ - -#include "prog_util.h" - -#include -#include -#include -#include -#ifdef _WIN32 -# include -#else -# include -# include -# include -#endif - -#ifndef O_BINARY -# define O_BINARY 0 -#endif -#ifndef O_SEQUENTIAL -# define O_SEQUENTIAL 0 -#endif -#ifndef O_NOFOLLOW -# define O_NOFOLLOW 0 -#endif -#ifndef O_NONBLOCK -# define O_NONBLOCK 0 -#endif -#ifndef O_NOCTTY -# define O_NOCTTY 0 -#endif - -/* The invocation name of the program (filename component only) */ -const tchar *program_invocation_name; - -static void -do_msg(const char *format, bool with_errno, va_list va) -{ - int saved_errno = errno; - - fprintf(stderr, "%"TS": ", program_invocation_name); - vfprintf(stderr, format, va); - if (with_errno) - fprintf(stderr, ": %s\n", strerror(saved_errno)); - else - fprintf(stderr, "\n"); - - errno = saved_errno; -} - -/* Print a message to standard error */ -void -msg(const char *format, ...) -{ - va_list va; - - va_start(va, format); - do_msg(format, false, va); - va_end(va); -} - -/* Print a message to standard error, including a description of errno */ -void -msg_errno(const char *format, ...) 
-{ - va_list va; - - va_start(va, format); - do_msg(format, true, va); - va_end(va); -} - -/* malloc() wrapper */ -void * -xmalloc(size_t size) -{ - void *p = malloc(size); - if (p == NULL && size == 0) - p = malloc(1); - if (p == NULL) - msg("Out of memory"); - return p; -} - -/* - * Return the number of timer ticks that have elapsed since some unspecified - * point fixed at the start of program execution - */ -u64 -timer_ticks(void) -{ -#ifdef _WIN32 - LARGE_INTEGER count; - QueryPerformanceCounter(&count); - return count.QuadPart; -#elif defined(HAVE_CLOCK_GETTIME) - struct timespec ts; - clock_gettime(CLOCK_MONOTONIC, &ts); - return (1000000000 * (u64)ts.tv_sec) + ts.tv_nsec; -#else - struct timeval tv; - gettimeofday(&tv, NULL); - return (1000000 * (u64)tv.tv_sec) + tv.tv_usec; -#endif -} - -/* - * Return the number of timer ticks per second - */ -static u64 -timer_frequency(void) -{ -#ifdef _WIN32 - LARGE_INTEGER freq; - QueryPerformanceFrequency(&freq); - return freq.QuadPart; -#elif defined(HAVE_CLOCK_GETTIME) - return 1000000000; -#else - return 1000000; -#endif -} - -/* - * Convert a number of elapsed timer ticks to milliseconds - */ -u64 timer_ticks_to_ms(u64 ticks) -{ - return ticks * 1000 / timer_frequency(); -} - -/* - * Convert a byte count and a number of elapsed timer ticks to MB/s - */ -u64 timer_MB_per_s(u64 bytes, u64 ticks) -{ - return bytes * timer_frequency() / ticks / 1000000; -} - -/* - * Retrieve a pointer to the filename component of the specified path. - * - * Note: this does not modify the path. Therefore, it is not guaranteed to work - * properly for directories, since a path to a directory might have trailing - * slashes. 
- */ -const tchar * -get_filename(const tchar *path) -{ - const tchar *slash = tstrrchr(path, '/'); -#ifdef _WIN32 - const tchar *backslash = tstrrchr(path, '\\'); - if (backslash != NULL && (slash == NULL || backslash > slash)) - slash = backslash; -#endif - if (slash != NULL) - return slash + 1; - return path; -} - -/* Create a copy of 'path' surrounded by double quotes */ -static tchar * -quote_path(const tchar *path) -{ - size_t len = tstrlen(path); - tchar *result; - - result = xmalloc((1 + len + 1 + 1) * sizeof(tchar)); - if (result == NULL) - return NULL; - result[0] = '"'; - tmemcpy(&result[1], path, len); - result[1 + len] = '"'; - result[1 + len + 1] = '\0'; - return result; -} - -/* Open a file for reading, or set up standard input for reading */ -int -xopen_for_read(const tchar *path, bool symlink_ok, struct file_stream *strm) -{ - strm->mmap_token = NULL; - strm->mmap_mem = NULL; - - if (path == NULL) { - strm->is_standard_stream = true; - strm->name = T("standard input"); - strm->fd = STDIN_FILENO; - #ifdef _WIN32 - _setmode(strm->fd, O_BINARY); - #endif - return 0; - } - - strm->is_standard_stream = false; - - strm->name = quote_path(path); - if (strm->name == NULL) - return -1; - - strm->fd = topen(path, O_RDONLY | O_BINARY | O_NONBLOCK | O_NOCTTY | - (symlink_ok ? 
0 : O_NOFOLLOW) | O_SEQUENTIAL); - if (strm->fd < 0) { - msg_errno("Can't open %"TS" for reading", strm->name); - free(strm->name); - return -1; - } - -#if defined(HAVE_POSIX_FADVISE) && (O_SEQUENTIAL == 0) - posix_fadvise(strm->fd, 0, 0, POSIX_FADV_SEQUENTIAL); -#endif - - return 0; -} - -/* Open a file for writing, or set up standard output for writing */ -int -xopen_for_write(const tchar *path, bool overwrite, struct file_stream *strm) -{ - int ret = -1; - - strm->mmap_token = NULL; - strm->mmap_mem = NULL; - - if (path == NULL) { - strm->is_standard_stream = true; - strm->name = T("standard output"); - strm->fd = STDOUT_FILENO; - #ifdef _WIN32 - _setmode(strm->fd, O_BINARY); - #endif - return 0; - } - - strm->is_standard_stream = false; - - strm->name = quote_path(path); - if (strm->name == NULL) - goto err; -retry: - strm->fd = topen(path, O_WRONLY | O_BINARY | O_NOFOLLOW | - O_CREAT | O_EXCL, 0644); - if (strm->fd < 0) { - if (errno != EEXIST) { - msg_errno("Can't open %"TS" for writing", strm->name); - goto err; - } - if (!overwrite) { - if (!isatty(STDERR_FILENO) || !isatty(STDIN_FILENO)) { - msg("%"TS" already exists; use -f to overwrite", - strm->name); - ret = -2; /* warning only */ - goto err; - } - fprintf(stderr, "%"TS": %"TS" already exists; " - "overwrite? 
(y/n) ", - program_invocation_name, strm->name); - if (getchar() != 'y') { - msg("Not overwriting."); - goto err; - } - } - if (tunlink(path) != 0) { - msg_errno("Unable to delete %"TS, strm->name); - goto err; - } - goto retry; - } - - return 0; - -err: - free(strm->name); - return ret; -} - -/* Read the full contents of a file into memory */ -static int -read_full_contents(struct file_stream *strm) -{ - size_t filled = 0; - size_t capacity = 4096; - char *buf; - int ret; - - buf = xmalloc(capacity); - if (buf == NULL) - return -1; - do { - if (filled == capacity) { - char *newbuf; - - if (capacity == SIZE_MAX) - goto oom; - capacity += MIN(SIZE_MAX - capacity, capacity); - newbuf = realloc(buf, capacity); - if (newbuf == NULL) - goto oom; - buf = newbuf; - } - ret = xread(strm, &buf[filled], capacity - filled); - if (ret < 0) - goto err; - filled += ret; - } while (ret != 0); - - strm->mmap_mem = buf; - strm->mmap_size = filled; - return 0; - -err: - free(buf); - return ret; -oom: - msg("Out of memory! 
%"TS" is too large to be processed by " - "this program as currently implemented.", strm->name); - ret = -1; - goto err; -} - -/* Map the contents of a file into memory */ -int -map_file_contents(struct file_stream *strm, u64 size) -{ - if (size == 0) /* mmap isn't supported on empty files */ - return read_full_contents(strm); - - if (size > SIZE_MAX) { - msg("%"TS" is too large to be processed by this program", - strm->name); - return -1; - } -#ifdef _WIN32 - strm->mmap_token = CreateFileMapping( - (HANDLE)(intptr_t)_get_osfhandle(strm->fd), - NULL, PAGE_READONLY, 0, 0, NULL); - if (strm->mmap_token == NULL) { - DWORD err = GetLastError(); - if (err == ERROR_BAD_EXE_FORMAT) /* mmap unsupported */ - return read_full_contents(strm); - msg("Unable create file mapping for %"TS": Windows error %u", - strm->name, (unsigned int)err); - return -1; - } - - strm->mmap_mem = MapViewOfFile((HANDLE)strm->mmap_token, - FILE_MAP_READ, 0, 0, size); - if (strm->mmap_mem == NULL) { - msg("Unable to map %"TS" into memory: Windows error %u", - strm->name, (unsigned int)GetLastError()); - CloseHandle((HANDLE)strm->mmap_token); - return -1; - } -#else /* _WIN32 */ - strm->mmap_mem = mmap(NULL, size, PROT_READ, MAP_SHARED, strm->fd, 0); - if (strm->mmap_mem == MAP_FAILED) { - strm->mmap_mem = NULL; - if (errno == ENODEV) /* mmap isn't supported on this file */ - return read_full_contents(strm); - if (errno == ENOMEM) { - msg("%"TS" is too large to be processed by this " - "program", strm->name); - } else { - msg_errno("Unable to map %"TS" into memory", - strm->name); - } - return -1; - } - -#ifdef HAVE_POSIX_MADVISE - posix_madvise(strm->mmap_mem, size, POSIX_MADV_SEQUENTIAL); -#endif - strm->mmap_token = strm; /* anything that's not NULL */ - -#endif /* !_WIN32 */ - strm->mmap_size = size; - return 0; -} - -/* - * Read from a file, returning the full count to indicate all bytes were read, a - * short count (possibly 0) to indicate EOF, or -1 to indicate error. 
- */ -ssize_t -xread(struct file_stream *strm, void *buf, size_t count) -{ - char *p = buf; - size_t orig_count = count; - - while (count != 0) { - ssize_t res = read(strm->fd, p, MIN(count, INT_MAX)); - if (res == 0) - break; - if (res < 0) { - if (errno == EAGAIN || errno == EINTR) - continue; - msg_errno("Error reading from %"TS, strm->name); - return -1; - } - p += res; - count -= res; - } - return orig_count - count; -} - -/* Write to a file, returning 0 if all bytes were written or -1 on error */ -int -full_write(struct file_stream *strm, const void *buf, size_t count) -{ - const char *p = buf; - - while (count != 0) { - ssize_t res = write(strm->fd, p, MIN(count, INT_MAX)); - if (res <= 0) { - msg_errno("Error writing to %"TS, strm->name); - return -1; - } - p += res; - count -= res; - } - return 0; -} - -/* Close a file, returning 0 on success or -1 on error */ -int -xclose(struct file_stream *strm) -{ - int ret = 0; - - if (!strm->is_standard_stream) { - if (close(strm->fd) != 0) { - msg_errno("Error closing %"TS, strm->name); - ret = -1; - } - free(strm->name); - } - - if (strm->mmap_token != NULL) { -#ifdef _WIN32 - UnmapViewOfFile(strm->mmap_mem); - CloseHandle((HANDLE)strm->mmap_token); -#else - munmap(strm->mmap_mem, strm->mmap_size); -#endif - strm->mmap_token = NULL; - } else { - free(strm->mmap_mem); - } - strm->mmap_mem = NULL; - strm->fd = -1; - strm->name = NULL; - return ret; -} - -/* - * Parse the compression level given on the command line, returning the - * compression level on success or 0 on error - */ -int -parse_compression_level(tchar opt_char, const tchar *arg) -{ - unsigned long level = opt_char - '0'; - const tchar *p; - - if (arg == NULL) - arg = T(""); - - for (p = arg; *p >= '0' && *p <= '9'; p++) - level = (level * 10) + (*p - '0'); - - if (level < 1 || level > 12 || *p != '\0') { - msg("Invalid compression level: \"%"TC"%"TS"\". 
" - "Must be an integer in the range [1, 12].", opt_char, arg); - return 0; - } - - return level; -} - -/* Allocate a new DEFLATE compressor */ -struct libdeflate_compressor * -alloc_compressor(int level) -{ - struct libdeflate_compressor *c; - - c = libdeflate_alloc_compressor(level); - if (c == NULL) { - msg_errno("Unable to allocate compressor with " - "compression level %d", level); - } - return c; -} - -/* Allocate a new DEFLATE decompressor */ -struct libdeflate_decompressor * -alloc_decompressor(void) -{ - struct libdeflate_decompressor *d; - - d = libdeflate_alloc_decompressor(); - if (d == NULL) - msg_errno("Unable to allocate decompressor"); - - return d; -} diff --git a/compress/libdeflate/programs/prog_util.h b/compress/libdeflate/programs/prog_util.h deleted file mode 100644 index 0dcad015..00000000 --- a/compress/libdeflate/programs/prog_util.h +++ /dev/null @@ -1,162 +0,0 @@ -/* - * prog_util.h - utility functions for programs - * - * Copyright 2016 Eric Biggers - * - * Permission is hereby granted, free of charge, to any person - * obtaining a copy of this software and associated documentation - * files (the "Software"), to deal in the Software without - * restriction, including without limitation the rights to use, - * copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the - * Software is furnished to do so, subject to the following - * conditions: - * - * The above copyright notice and this permission notice shall be - * included in all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, - * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES - * OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND - * NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT - * HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, - * WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING - * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR - * OTHER DEALINGS IN THE SOFTWARE. - */ - -#ifndef PROGRAMS_PROG_UTIL_H -#define PROGRAMS_PROG_UTIL_H - -#ifdef HAVE_CONFIG_H -# include "config.h" -#endif - -#include "libdeflate.h" - -#include -#include -#include -#include - -#include "common_defs.h" - -#ifdef __GNUC__ -# define _printf(str_idx, args_idx) \ - __attribute__((format(printf, str_idx, args_idx))) -#else -# define _printf(str_idx, args_idx) -#endif - -#ifdef _WIN32 - -/* - * Definitions for Windows builds. Mainly, 'tchar' is defined to be the 2-byte - * 'wchar_t' type instead of 'char'. This is the only "easy" way I know of to - * get full Unicode support on Windows... - */ - -#include -extern int wmain(int argc, wchar_t **argv); -# define tmain wmain -# define tchar wchar_t -# define _T(text) L##text -# define T(text) _T(text) -# define TS "ls" -# define TC "lc" -# define tmemcpy wmemcpy -# define topen _wopen -# define tstrchr wcschr -# define tstrcmp wcscmp -# define tstrcpy wcscpy -# define tstrlen wcslen -# define tstrrchr wcsrchr -# define tstrtoul wcstoul -# define tstrxcmp wcsicmp -# define tunlink _wunlink -# define tutimbuf __utimbuf64 -# define tutime _wutime64 -# define tstat _wstat64 -# define tfstat _fstat64 -# define stat_t struct _stat64 -# ifdef _MSC_VER -# define STDIN_FILENO 0 -# define STDOUT_FILENO 1 -# define STDERR_FILENO 2 -# define S_ISREG(m) (((m) & S_IFMT) == S_IFREG) -# define S_ISDIR(m) (((m) & S_IFMT) == S_IFDIR) -# endif - -#else /* _WIN32 */ - -/* Standard definitions for everyone else */ - -# define tmain main -# define tchar char -# define T(text) text -# define TS "s" -# define TC "c" -# define tmemcpy memcpy -# define topen open -# define tstrchr strchr -# define tstrcmp strcmp -# define tstrcpy strcpy -# define tstrlen strlen -# 
define tstrrchr strrchr -# define tstrtoul strtoul -# define tstrxcmp strcmp -# define tunlink unlink -# define tutimbuf utimbuf -# define tutime utime -# define tstat stat -# define tfstat fstat -# define stat_t struct stat - -#endif /* !_WIN32 */ - -extern const tchar *program_invocation_name; - -extern void _printf(1, 2) msg(const char *fmt, ...); -extern void _printf(1, 2) msg_errno(const char *fmt, ...); - -extern void *xmalloc(size_t size); - -extern u64 timer_ticks(void); -extern u64 timer_ticks_to_ms(u64 ticks); -extern u64 timer_MB_per_s(u64 bytes, u64 ticks); - -extern const tchar *get_filename(const tchar *path); - -struct file_stream { - int fd; - tchar *name; - bool is_standard_stream; - void *mmap_token; - void *mmap_mem; - size_t mmap_size; -}; - -extern int xopen_for_read(const tchar *path, bool symlink_ok, - struct file_stream *strm); -extern int xopen_for_write(const tchar *path, bool force, - struct file_stream *strm); -extern int map_file_contents(struct file_stream *strm, u64 size); - -extern ssize_t xread(struct file_stream *strm, void *buf, size_t count); -extern int full_write(struct file_stream *strm, const void *buf, size_t count); - -extern int xclose(struct file_stream *strm); - -extern int parse_compression_level(tchar opt_char, const tchar *arg); - -extern struct libdeflate_compressor *alloc_compressor(int level); -extern struct libdeflate_decompressor *alloc_decompressor(void); - -/* tgetopt.c */ - -extern tchar *toptarg; -extern int toptind, topterr, toptopt; - -extern int tgetopt(int argc, tchar *argv[], const tchar *optstring); - -#endif /* PROGRAMS_PROG_UTIL_H */ diff --git a/compress/libdeflate/programs/test_checksums.c b/compress/libdeflate/programs/test_checksums.c deleted file mode 100644 index 48cc56ba..00000000 --- a/compress/libdeflate/programs/test_checksums.c +++ /dev/null @@ -1,182 +0,0 @@ -/* - * test_checksums.c - * - * Verify that libdeflate's Adler-32 and CRC-32 functions produce the same - * results as their zlib 
equivalents. - */ - -#include -#include -#include - -#include "prog_util.h" - -static unsigned int rng_seed; - -static void -assertion_failed(const char *file, int line) -{ - fprintf(stderr, "Assertion failed at %s:%d\n", file, line); - fprintf(stderr, "RNG seed was %u\n", rng_seed); - abort(); -} - -#define ASSERT(expr) if (!(expr)) assertion_failed(__FILE__, __LINE__); - -typedef u32 (*cksum_fn_t)(u32, const void *, size_t); - -static u32 -zlib_adler32(u32 adler, const void *buf, size_t len) -{ - return adler32(adler, buf, len); -} - -static u32 -zlib_crc32(u32 crc, const void *buf, size_t len) -{ - return crc32(crc, buf, len); -} - -static u32 -select_initial_crc(void) -{ - if (rand() & 1) - return 0; - return ((u32)rand() << 16) | rand(); -} - -static u32 -select_initial_adler(void) -{ - u32 lo, hi; - - if (rand() & 1) - return 1; - - lo = (rand() % 4 == 0 ? 65520 : rand() % 65521); - hi = (rand() % 4 == 0 ? 65520 : rand() % 65521); - return (hi << 16) | lo; -} - -static void -test_initial_values(cksum_fn_t cksum, u32 expected) -{ - ASSERT(cksum(0, NULL, 0) == expected); - if (cksum != zlib_adler32) /* broken */ - ASSERT(cksum(0, NULL, 1) == expected); - ASSERT(cksum(0, NULL, 1234) == expected); - ASSERT(cksum(1234, NULL, 0) == expected); - ASSERT(cksum(1234, NULL, 1234) == expected); -} - -static void -test_multipart(const u8 *buffer, size_t size, const char *name, - cksum_fn_t cksum, u32 v, u32 expected) -{ - size_t division = rand() % (size + 1); - v = cksum(v, buffer, division); - v = cksum(v, buffer + division, size - division); - if (v != expected) { - fprintf(stderr, "%s checksum failed multipart test\n", name); - ASSERT(0); - } -} - -static void -test_checksums(const void *buffer, size_t size, const char *name, - cksum_fn_t cksum1, cksum_fn_t cksum2, u32 initial_value) -{ - u32 v1 = cksum1(initial_value, buffer, size); - u32 v2 = cksum2(initial_value, buffer, size); - - if (v1 != v2) { - fprintf(stderr, "%s checksum mismatch\n", name); - fprintf(stderr, 
"initial_value=0x%08"PRIx32", buffer=%p, " - "size=%zu, buffer=", initial_value, buffer, size); - for (size_t i = 0; i < MIN(size, 256); i++) - fprintf(stderr, "%02x", ((const u8 *)buffer)[i]); - if (size > 256) - fprintf(stderr, "..."); - fprintf(stderr, "\n"); - ASSERT(0); - } - - if ((rand() & 15) == 0) { - test_multipart(buffer, size, name, cksum1, initial_value, v1); - test_multipart(buffer, size, name, cksum2, initial_value, v1); - } -} - -static void -test_crc32(const void *buffer, size_t size, u32 initial_value) -{ - test_checksums(buffer, size, "CRC-32", - libdeflate_crc32, zlib_crc32, initial_value); -} - -static void -test_adler32(const void *buffer, size_t size, u32 initial_value) -{ - test_checksums(buffer, size, "Adler-32", - libdeflate_adler32, zlib_adler32, initial_value); -} - -static void test_random_buffers(u8 *buffer, size_t limit, u32 num_iter) -{ - for (u32 i = 0; i < num_iter; i++) { - size_t start = rand() % limit; - size_t len = rand() % (limit - start); - - for (size_t j = start; j < start + len; j++) - buffer[j] = rand(); - - test_adler32(&buffer[start], len, select_initial_adler()); - test_crc32(&buffer[start], len, select_initial_crc()); - } -} - -int -tmain(int argc, tchar *argv[]) -{ - u8 *buffer = malloc(32768); - - rng_seed = time(NULL); - srand(rng_seed); - - test_initial_values(libdeflate_adler32, 1); - test_initial_values(zlib_adler32, 1); - test_initial_values(libdeflate_crc32, 0); - test_initial_values(zlib_crc32, 0); - - /* Test different buffer sizes and alignments */ - test_random_buffers(buffer, 256, 5000); - test_random_buffers(buffer, 1024, 500); - test_random_buffers(buffer, 32768, 50); - - /* - * Test Adler-32 overflow cases. For example, given all 0xFF bytes and - * the highest possible initial (s1, s2) of (65520, 65520), then s2 if - * stored as a 32-bit unsigned integer will overflow if > 5552 bytes are - * processed. Implementations must make sure to reduce s2 modulo 65521 - * before that point. 
Also, some implementations make use of 16-bit - * counters which can overflow earlier. - */ - memset(buffer, 0xFF, 32768); - for (u32 i = 0; i < 20; i++) { - u32 initial_value; - - if (i == 0) - initial_value = ((u32)65520 << 16) | 65520; - else - initial_value = select_initial_adler(); - - test_adler32(buffer, 5553, initial_value); - test_adler32(buffer, rand() % 32769, initial_value); - buffer[rand() % 32768] = 0xFE; - } - - printf("Adler-32 and CRC-32 checksum tests passed!\n"); - - free(buffer); - return 0; -} diff --git a/compress/libdeflate/programs/tgetopt.c b/compress/libdeflate/programs/tgetopt.c deleted file mode 100644 index 868600d9..00000000 --- a/compress/libdeflate/programs/tgetopt.c +++ /dev/null @@ -1,118 +0,0 @@ -/* - * tgetopt.c - portable replacement for GNU getopt() - * - * Copyright 2016 Eric Biggers - * - * Permission is hereby granted, free of charge, to any person - * obtaining a copy of this software and associated documentation - * files (the "Software"), to deal in the Software without - * restriction, including without limitation the rights to use, - * copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the - * Software is furnished to do so, subject to the following - * conditions: - * - * The above copyright notice and this permission notice shall be - * included in all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, - * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES - * OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND - * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT - * HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, - * WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING - * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR - * OTHER DEALINGS IN THE SOFTWARE. 
- */ - -#include "prog_util.h" - -tchar *toptarg; -int toptind = 1, topterr = 1, toptopt; - -/* - * This is a simple implementation of getopt(). It can be compiled with either - * 'char' or 'wchar_t' as the character type. - * - * Do *not* use this implementation if you need any of the following features, - * as they are not supported: - * - Long options - * - Option-related arguments retained in argv, not nulled out - * - '+' and '-' characters in optstring - */ -int -tgetopt(int argc, tchar *argv[], const tchar *optstring) -{ - static tchar empty[1]; - static tchar *nextchar; - static bool done; - - if (toptind == 1) { - /* Starting to scan a new argument vector */ - nextchar = NULL; - done = false; - } - - while (!done && (nextchar != NULL || toptind < argc)) { - if (nextchar == NULL) { - /* Scanning a new argument */ - tchar *arg = argv[toptind++]; - if (arg[0] == '-' && arg[1] != '\0') { - if (arg[1] == '-' && arg[2] == '\0') { - /* All args after "--" are nonoptions */ - argv[toptind - 1] = NULL; - done = true; - } else { - /* Start of short option characters */ - nextchar = &arg[1]; - } - } - } else { - /* More short options in previous arg */ - tchar opt = *nextchar; - tchar *p = tstrchr(optstring, opt); - if (p == NULL) { - if (topterr) - msg("invalid option -- '%"TC"'", opt); - toptopt = opt; - return '?'; - } - /* 'opt' is a valid short option character */ - nextchar++; - toptarg = NULL; - if (*(p + 1) == ':') { - /* 'opt' can take an argument */ - if (*nextchar != '\0') { - /* Optarg is in same argv argument */ - toptarg = nextchar; - nextchar = empty; - } else if (toptind < argc && *(p + 2) != ':') { - /* Optarg is next argv argument */ - argv[toptind - 1] = NULL; - toptarg = argv[toptind++]; - } else if (*(p + 2) != ':') { - if (topterr && *optstring != ':') { - msg("option requires an " - "argument -- '%"TC"'", opt); - } - toptopt = opt; - opt = (*optstring == ':') ? 
':' : '?'; - } - } - if (*nextchar == '\0') { - argv[toptind - 1] = NULL; - nextchar = NULL; - } - return opt; - } - } - - /* Done scanning. Move all nonoptions to the end, set optind to the - * index of the first nonoption, and return -1. */ - toptind = argc; - while (--argc > 0) - if (argv[argc] != NULL) - argv[--toptind] = argv[argc]; - done = true; - return -1; -} diff --git a/compress/rw.go b/compress/rw.go new file mode 100644 index 00000000..af13d633 --- /dev/null +++ b/compress/rw.go @@ -0,0 +1,222 @@ +// Package compress provides convenience functions for creating compressors and +// uncompressors based on filenames. +package compress + +import ( + "bytes" + "compress/bzip2" + "context" + "fmt" + "io" + "io/ioutil" + + "github.com/grailbio/base/compress/zstd" + "github.com/grailbio/base/errors" + "github.com/grailbio/base/file" + "github.com/grailbio/base/fileio" + "github.com/grailbio/base/ioctx" + "github.com/klauspost/compress/gzip" + "github.com/yasushi-saito/zlibng" +) + +// errorReader is a ReadCloser implementation that always returns the given +// error. +type errorReader struct{ err error } + +func (r *errorReader) Read(buf []byte) (int, error) { return 0, r.err } +func (r *errorReader) Close() error { return r.err } + +// nopWriteCloser adds a noop Closer to io.Writer. +type nopWriteCloser struct{ io.Writer } + +func (w *nopWriteCloser) Close() error { return nil } + +func isBzip2Header(buf []byte) bool { + // https://www.forensicswiki.org/wiki/Bzip2 + if len(buf) < 10 { + return false + } + if !(buf[0] == 'B' && buf[1] == 'Z' && buf[2] == 'h' && buf[3] >= '1' && buf[3] <= '9') { + return false + } + if buf[4] == 0x31 && buf[5] == 0x41 && + buf[6] == 0x59 && buf[7] == 0x26 && + buf[8] == 0x53 && buf[9] == 0x59 { // block magic + return true + } + if buf[4] == 0x17 && buf[5] == 0x72 && + buf[6] == 0x45 && buf[7] == 0x38 && + buf[8] == 0x50 && buf[9] == 0x90 { // eos magic, happens only for an empty bz2 file. 
+ return true + } + return false +} + +func isGzipHeader(buf []byte) bool { + if len(buf) < 10 { + return false + } + if !(buf[0] == 0x1f && buf[1] == 0x8b) { + return false + } + if !(buf[2] <= 3 || buf[2] == 8) { + return false + } + if (buf[3] & 0xc0) != 0 { + return false + } + if !(buf[9] <= 0xd || buf[9] == 0xff) { + return false + } + return true +} + +// https://tools.ietf.org/html/rfc8478 +func isZstdHeader(buf []byte) bool { + if len(buf) < 4 { + return false + } + if buf[0] != 0x28 || buf[1] != 0xB5 || buf[2] != 0x2F || buf[3] != 0xFD { + return false + } + return true +} + +// NewReader creates an uncompressing reader by reading the first few bytes of +// the input and finding a magic header for either gzip, zstd, bzip2. If the +// magic header is found , it returns an uncompressing ReadCloser and +// true. Else, it returns ioutil.NopCloser(r) and false. +// +// CAUTION: this function will misbehave when the input is a binary string that +// happens to have the same magic gzip, zstd, or bzip2 header. Thus, you should +// use this function only when the input is expected to be ASCII. +func NewReader(r io.Reader) (io.ReadCloser, bool) { + buf := bytes.Buffer{} + _, err := io.CopyN(&buf, r, 128) + var m io.Reader + switch err { + case io.EOF: + m = &buf + case nil: + m = io.MultiReader(&buf, r) + default: + m = io.MultiReader(&buf, &errorReader{err}) + } + if isGzipHeader(buf.Bytes()) { + z, err := zlibng.NewReader(m) + if err != nil { + return &errorReader{err}, false + } + return z, true + } + if isZstdHeader(buf.Bytes()) { + zr, err := zstd.NewReader(m) + if err != nil { + return &errorReader{err}, false + } + return zr, true + } + if isBzip2Header(buf.Bytes()) { + return ioutil.NopCloser(bzip2.NewReader(m)), true + } + return ioutil.NopCloser(m), false +} + +// NewReaderPath creates a reader that uncompresses data read from the given +// reader. The compression format is determined by the pathname extensions. 
If +// the pathname ends with one of the following extensions, it creates an +// uncompressing ReadCloser and returns true. +// +// .gz => gzip format +// .zst => zstd format +// .bz2 => bz2 format +// +// For other extensions, this function returns an ioutil.NopCloser(r) and false. +// +// The caller must close the ReadCloser after use. For some file formats, +// Close() is the only place that reports file corruption. +func NewReaderPath(r io.Reader, path string) (io.ReadCloser, bool) { + switch fileio.DetermineType(path) { + case fileio.Gzip: + gz, err := zlibng.NewReader(r) + if err != nil { + return file.NewError(err), false + } + return gz, true + case fileio.Zstd: + zr, err := zstd.NewReader(r) + if err != nil { + return file.NewError(err), false + } + return zr, true + case fileio.Bzip2: + return ioutil.NopCloser(bzip2.NewReader(r)), true + } + return ioutil.NopCloser(r), false +} + +// Open opens path with file.Open and decompresses with NewReaderPath. +func Open(ctx context.Context, path string) (io.ReadCloser, bool) { + f, err := file.Open(ctx, path) + if err != nil { + return file.NewError(err), false + } + r, isCompressed := NewReaderPath(f.Reader(ctx), path) + + return struct { + io.Reader + io.Closer + }{r, doubleCloser{r, ioctx.ToStdCloser(ctx, f)}}, isCompressed +} + +// NewWriterPath creates a WriteCloser that compresses data. The compression +// format is determined by the pathname extensions. If the pathname ends with +// one of the following extensions, it creates an compressing WriteCloser and +// returns true. +// +// .gz => gzip format +// .zst => zstd format +// +// For other extensions, this function creates a noop WriteCloser and returns +// false. The caller must close the WriteCloser after use. 
+func NewWriterPath(w io.Writer, path string) (io.WriteCloser, bool) { + switch fileio.DetermineType(path) { + case fileio.Gzip: + return gzip.NewWriter(w), true + case fileio.Zstd: + zw, err := zstd.NewWriter(w) + if err != nil { + return file.NewError(err), false + } + return zw, true + case fileio.Bzip2: + return file.NewError(fmt.Errorf("%s: bzip2 writer not supported", path)), false + } + return &nopWriteCloser{w}, false +} + +// Create creates path with file.Create and compresses with NewWriterPath. +func Create(ctx context.Context, path string, opts ...file.Opts) (io.WriteCloser, bool) { + f, err := file.Create(ctx, path, opts...) + if err != nil { + return file.NewError(err), false + } + w, isCompressed := NewWriterPath(f.Writer(ctx), path) + return struct { + io.Writer + io.Closer + }{w, doubleCloser{w, ioctx.ToStdCloser(ctx, f)}}, isCompressed +} + +// doubleCloser implements io.Closer and serves to clean up the boilerplate +// around closing both the files and reader/writer objects created in +// Open and Create. 
+type doubleCloser struct { + c, d io.Closer +} + +func (c doubleCloser) Close() (err error) { + errors.CleanUp(c.c.Close, &err) + errors.CleanUp(c.d.Close, &err) + return +} diff --git a/compress/rw_test.go b/compress/rw_test.go new file mode 100644 index 00000000..d056df9e --- /dev/null +++ b/compress/rw_test.go @@ -0,0 +1,175 @@ +package compress_test + +import ( + "bytes" + "compress/gzip" + "fmt" + "io" + "io/ioutil" + "math/rand" + "os" + "os/exec" + "strings" + "testing" + + "github.com/grailbio/base/compress" + "github.com/grailbio/testutil/assert" + "github.com/klauspost/compress/zstd" +) + +func testReader(t *testing.T, plaintext string, comp func(t *testing.T, in []byte) []byte) { + compressed := comp(t, []byte(plaintext)) + cr := bytes.NewReader(compressed) + r, n := compress.NewReader(cr) + assert.True(t, n) + assert.NotNil(t, r) + got := bytes.Buffer{} + _, err := io.Copy(&got, r) + assert.NoError(t, err) + assert.NoError(t, r.Close()) + assert.EQ(t, got.String(), plaintext) +} + +// Generate a random ASCII text. 
+func randomText(buf *strings.Builder, r *rand.Rand, n int) { + for i := 0; i < n; i++ { + buf.WriteByte(byte(r.Intn(96) + 32)) + } +} + +func gzipCompress(t *testing.T, in []byte) []byte { + buf := bytes.Buffer{} + w := gzip.NewWriter(&buf) + _, err := io.Copy(w, bytes.NewReader(in)) + assert.NoError(t, err) + assert.NoError(t, w.Close()) + return buf.Bytes() +} + +func bzip2Compress(t *testing.T, in []byte) []byte { + temp, err := ioutil.TempFile("", "test") + assert.NoError(t, err) + _, err = temp.Write(in) + assert.NoError(t, err) + assert.NoError(t, temp.Close()) + cmd := exec.Command("bzip2", temp.Name()) + assert.NoError(t, cmd.Run()) + + compressed, err := ioutil.ReadFile(temp.Name() + ".bz2") + assert.NoError(t, err) + assert.NoError(t, os.Remove(temp.Name()+".bz2")) + return compressed +} + +func zstdCompress(t *testing.T, in []byte) []byte { + buf := bytes.Buffer{} + // WithZeroFrames ensures that a zero-length input (like in TestReaderSmall) yields + // a non-empty output with a header that compress.NewReader can sniff. 
+ w, err := zstd.NewWriter(&buf, zstd.WithZeroFrames(true)) + assert.NoError(t, err) + _, err = io.Copy(w, bytes.NewReader(in)) + assert.NoError(t, err) + assert.NoError(t, w.Close()) + return buf.Bytes() +} + +type compressor struct { + fn func(t *testing.T, in []byte) []byte + ext string +} + +var compressors = []compressor{ + {zstdCompress, "zst"}, + {gzipCompress, "gz"}, + {bzip2Compress, "bz2"}, +} + +func TestReaderSmall(t *testing.T) { + for _, c := range compressors { + t.Run(c.ext, func(t *testing.T) { + testReader(t, "", c.fn) + testReader(t, "hello", c.fn) + }) + n := 1 + for i := 1; i < 25; i++ { + t.Run(fmt.Sprint("format=", c.ext, ",n=", n), func(t *testing.T) { + r := rand.New(rand.NewSource(int64(i))) + n = (n + 1) * 3 / 2 + buf := strings.Builder{} + randomText(&buf, r, n) + testReader(t, buf.String(), c.fn) + }) + } + } +} + +func TestGzipReaderUncompressed(t *testing.T) { + data := make([]byte, 128<<10+1) + got := bytes.Buffer{} + + runTest := func(t *testing.T, n int) { + for i := range data[:n] { + // gzip/bzip2 header contains at least one char > 128, so the plaintext should + // never be conflated with a gzip header. 
+ data[i] = byte(n + i%128) + } + cr := bytes.NewReader(data[:n]) + r, compressed := compress.NewReader(cr) + assert.False(t, compressed) + got.Reset() + nRead, err := io.Copy(&got, r) + assert.NoError(t, err) + assert.EQ(t, int(nRead), n) + assert.NoError(t, r.Close()) + assert.EQ(t, got.Bytes(), data[:n]) + } + + dataSize := 1 + for dataSize <= len(data) { + n := dataSize + t.Run(fmt.Sprint(n), func(t *testing.T) { runTest(t, n) }) + t.Run(fmt.Sprint(n-1), func(t *testing.T) { runTest(t, n-1) }) + t.Run(fmt.Sprint(n+1), func(t *testing.T) { runTest(t, n+1) }) + dataSize *= 2 + } +} + +func TestReaderWriterPath(t *testing.T) { + for _, c := range compressors { + t.Run(c.ext, func(t *testing.T) { + if c.ext == "bz2" { // bz2 compression not yet supported + t.Skip("bz2") + } + buf := bytes.Buffer{} + w, compressed := compress.NewWriterPath(&buf, "foo."+c.ext) + assert.True(t, compressed) + _, err := io.WriteString(w, "hello") + assert.NoError(t, w.Close()) + assert.NoError(t, err) + + r, compressed := compress.NewReaderPath(&buf, "foo."+c.ext) + assert.True(t, compressed) + data, err := ioutil.ReadAll(r) + assert.NoError(t, err) + assert.EQ(t, string(data), "hello") + assert.NoError(t, r.Close()) + }) + } +} + +// NewReaderPath and NewWriterPath for non-compressed extensions. 
+func TestReaderWriterPathNop(t *testing.T) { + buf := bytes.Buffer{} + w, compressed := compress.NewWriterPath(&buf, "foo.txt") + assert.False(t, compressed) + _, err := io.WriteString(w, "hello") + assert.NoError(t, w.Close()) + assert.NoError(t, err) + + r, compressed := compress.NewReaderPath(&buf, "foo.txt") + assert.False(t, compressed) + data, err := ioutil.ReadAll(r) + assert.NoError(t, err) + assert.EQ(t, string(data), "hello") + assert.NoError(t, r.Close()) +} diff --git a/compress/zstd/zstd_cgo.go b/compress/zstd/zstd_cgo.go new file mode 100644 index 00000000..1a847f13 --- /dev/null +++ b/compress/zstd/zstd_cgo.go @@ -0,0 +1,53 @@ +// Package zstd wraps github.com/DataDog/zstd and +// github.com/klauspost/compress/zstd. It uses DataDog/zstd in cgo mode, and +// klauspost/compress/zstd in noncgo mode. + +// +build cgo + +package zstd + +import ( + "io" + + cgozstd "github.com/DataDog/zstd" +) + +// Compress compresses the given source data. Scratch can be passed to prevent +// prevent allocation. If it is too small, or if nil is passed, a new buffer +// will be allocated and returned. Arg level specifies the compression +// level. level < 0 means the default compression level. +func CompressLevel(scratch []byte, in []byte, level int) ([]byte, error) { + if level < 0 { + level = cgozstd.DefaultCompression + } + if cap(scratch) == 0 { + scratch = nil + } else { + scratch = scratch[:cap(scratch)] + } + return cgozstd.CompressLevel(scratch, in, level) +} + +// Decompress uncompresses the given source data. Scratch can be passed to +// prevent allocation. If it is too small, or if nil is passed, a new buffer +// will be allocated and returned. +func Decompress(scratch []byte, in []byte) ([]byte, error) { + if cap(scratch) == 0 { + scratch = nil + } else { + scratch = scratch[:cap(scratch)] + } + return cgozstd.Decompress(scratch, in) +} + +// NewReader creates a ReadCloser that uncompresses data. The returned object +// must be Closed after use. 
+func NewReader(r io.Reader) (io.ReadCloser, error) { + return cgozstd.NewReader(r), nil +} + +// NewWriter creates a WriterCloser that compresses data. The returned object +// must be Closed after use. +func NewWriter(w io.Writer) (io.WriteCloser, error) { + return cgozstd.NewWriter(w), nil +} diff --git a/compress/zstd/zstd_nocgo.go b/compress/zstd/zstd_nocgo.go new file mode 100644 index 00000000..96b6e320 --- /dev/null +++ b/compress/zstd/zstd_nocgo.go @@ -0,0 +1,67 @@ +// +build !cgo + +package zstd + +import ( + "bytes" + "io" + + nocgozstd "github.com/klauspost/compress/zstd" +) + +func CompressLevel(scratch []byte, in []byte, level int) ([]byte, error) { + if level < 0 { + level = 5 // 5 is the default compression const in cgo zstd + } + wBuf := bytes.NewBuffer(scratch[:0]) + w, err := nocgozstd.NewWriter(wBuf, + nocgozstd.WithEncoderLevel(nocgozstd.EncoderLevelFromZstd(level))) + if err != nil { + return nil, err + } + rBuf := bytes.NewReader(in) + _, err = io.Copy(w, rBuf) + if err != nil { + return nil, err + } + if err := w.Close(); err != nil { + return nil, err + } + return wBuf.Bytes(), nil +} + +func Decompress(scratch []byte, in []byte) ([]byte, error) { + rBuf := bytes.NewReader(in) + r, err := nocgozstd.NewReader(rBuf) + if err != nil { + return nil, err + } + + wBuf := bytes.NewBuffer(scratch[:0]) + if _, err = io.Copy(wBuf, r); err != nil { + return nil, err + } + r.Close() + return wBuf.Bytes(), nil +} + +type readerWrapper struct { + *nocgozstd.Decoder +} + +func (r *readerWrapper) Close() error { + r.Decoder.Close() + return nil +} + +func NewReader(r io.Reader) (io.ReadCloser, error) { + zr, err := nocgozstd.NewReader(r) + if err != nil { + return nil, err + } + return &readerWrapper{zr}, nil +} + +func NewWriter(w io.Writer) (io.WriteCloser, error) { + return nocgozstd.NewWriter(w) +} diff --git a/compress/zstd/zstd_test.go b/compress/zstd/zstd_test.go new file mode 100644 index 00000000..56ce4e9e --- /dev/null +++ 
b/compress/zstd/zstd_test.go @@ -0,0 +1,95 @@ +package zstd_test + +import ( + "flag" + "io/ioutil" + "os" + "testing" + + "bytes" + "io" + + "github.com/grailbio/base/compress/zstd" + "github.com/grailbio/testutil/assert" +) + +func TestCompress(t *testing.T) { + z, err := zstd.CompressLevel(nil, []byte("hello"), -1) + assert.NoError(t, err) + assert.GT(t, len(z), 0) + d, err := zstd.Decompress(nil, z) + assert.NoError(t, err) + assert.EQ(t, d, []byte("hello")) +} + +func TestCompressScratch(t *testing.T) { + z, err := zstd.CompressLevel(make([]byte, 3), []byte("hello"), -1) + assert.NoError(t, err) + assert.GT(t, len(z), 0) + d, err := zstd.Decompress(make([]byte, 3), z) + assert.NoError(t, err) + assert.EQ(t, d, []byte("hello")) +} + +func TestReadWrite(t *testing.T) { + buf := bytes.Buffer{} + w, err := zstd.NewWriter(&buf) + assert.NoError(t, err) + _, err = io.WriteString(w, "hello2") + assert.NoError(t, err) + assert.NoError(t, w.Close()) + + r, err := zstd.NewReader(&buf) + assert.NoError(t, err) + d, err := ioutil.ReadAll(r) + assert.NoError(t, err) + assert.EQ(t, d, []byte("hello2")) +} + +var plaintextFlag = flag.String("plaintext", "", "plaintext file used in compression test") + +func BenchmarkCompress(b *testing.B) { + if *plaintextFlag == "" { + b.Skip("--plaintext not set") + } + + for i := 0; i < b.N; i++ { + buf := bytes.Buffer{} + w, err := zstd.NewWriter(&buf) + assert.NoError(b, err) + r, err := os.Open(*plaintextFlag) + assert.NoError(b, err) + _, err = io.Copy(w, r) + assert.NoError(b, err) + assert.NoError(b, w.Close()) + assert.NoError(b, r.Close()) + } +} + +func BenchmarkUncompress(b *testing.B) { + if *plaintextFlag == "" { + b.Skip("--plaintext not set") + } + + b.StopTimer() + buf := bytes.Buffer{} + w, err := zstd.NewWriter(&buf) + assert.NoError(b, err) + r, err := os.Open(*plaintextFlag) + assert.NoError(b, err) + _, err = io.Copy(w, r) + assert.NoError(b, err) + assert.NoError(b, w.Close()) + assert.NoError(b, r.Close()) + 
b.StartTimer() + + for i := 0; i < b.N; i++ { + zr, err := zstd.NewReader(bytes.NewReader(buf.Bytes())) + assert.NoError(b, err) + + w := bytes.Buffer{} + _, err = io.Copy(&w, zr) + assert.NoError(b, err) + assert.NoError(b, zr.Close()) + } +} diff --git a/config/aws/aws.go b/config/aws/aws.go new file mode 100644 index 00000000..21910ecb --- /dev/null +++ b/config/aws/aws.go @@ -0,0 +1,24 @@ +// Copyright 2019 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. + +package aws + +import ( + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/grailbio/base/config" +) + +func init() { + config.Register("aws/env", func(constr *config.Constructor[*session.Session]) { + var cfg aws.Config + cfg.Region = constr.String("region", "us-west-2", "the default AWS region for the session") + constr.Doc = "configure an AWS session from the environment" + constr.New = func() (*session.Session, error) { + return session.NewSession(&cfg) + } + }) + + config.Default("aws", "aws/env") +} diff --git a/config/awsticket/awsticket.go b/config/awsticket/awsticket.go new file mode 100644 index 00000000..f9bf5aa3 --- /dev/null +++ b/config/awsticket/awsticket.go @@ -0,0 +1,36 @@ +// Copyright 2019 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. 
+ +package awsticket + +import ( + "time" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/grailbio/base/cloud/awssession" + "github.com/grailbio/base/config" + "github.com/grailbio/base/vcontext" +) + +func init() { + config.Register("aws/ticket", func(constr *config.Constructor[*session.Session]) { + var ( + region = constr.String("region", "us-west-2", "the default AWS region for the session") + path = constr.String("path", "tickets/eng/dev/aws", "path to AWS ticket") + ) + constr.Doc = "configure an AWS session from a GRAIL ticket server path" + constr.New = func() (*session.Session, error) { + return session.NewSession(&aws.Config{ + Credentials: credentials.NewCredentials(&awssession.Provider{ + Ctx: vcontext.Background(), + Timeout: 10 * time.Second, + TicketPath: *path, + }), + Region: region, + }) + } + }) +} diff --git a/config/cmd/demosuggestimpl/main.go b/config/cmd/demosuggestimpl/main.go new file mode 100644 index 00000000..a60c1498 --- /dev/null +++ b/config/cmd/demosuggestimpl/main.go @@ -0,0 +1,59 @@ +package main + +import ( + "flag" + "fmt" + + "github.com/grailbio/base/config" + "github.com/grailbio/base/must" +) + +type ( + Fruit interface{ IsFruit() } + Apple struct{ color string } + Orange struct{} +) + +func (Apple) IsFruit() {} +func (Orange) IsFruit() {} + +func init() { + config.Register("fruits/apple-red", func(c *config.Constructor[Apple]) { + c.Doc = "Some people like apples." + c.New = func() (Apple, error) { return Apple{"red"}, nil } + }) + config.Register("fruits/apple-green", func(c *config.Constructor[Apple]) { + c.Doc = "Another apple." + c.New = func() (Apple, error) { return Apple{"green"}, nil } + }) + config.Register("fruits/orange", func(c *config.Constructor[Orange]) { + c.Doc = "Some people like oranges." 
+ c.New = func() (Orange, error) { return Orange{}, nil } + }) + config.Register("favorite", func(c *config.Constructor[Fruit]) { + c.Doc = "My favorite fruit." + var favorite Fruit + c.InstanceVar(&favorite, "is", "favorite-apple", "Favorite fruit?") + c.New = func() (Fruit, error) { return favorite, nil } + }) + config.Register("favorite-apple", func(c *config.Constructor[Apple]) { + c.Doc = "My favorite apple." + var favorite Apple + c.InstanceVar(&favorite, "is", "fruits/apple-green", "Favorite apple?") + c.New = func() (Apple, error) { return favorite, nil } + }) +} + +func main() { + config.RegisterFlags("", "") + flag.Parse() + must.Nil(config.ProcessFlags()) + + var fruit Fruit + must.Nil(config.Instance("favorite", &fruit)) + fmt.Printf("My favorite fruit is %#v.\n", fruit) + + var apple Apple + must.Nil(config.Instance("favorite-apple", &apple)) + fmt.Printf("My favorite apple is %#v.\n", apple) +} diff --git a/config/flag.go b/config/flag.go new file mode 100644 index 00000000..98ca7e88 --- /dev/null +++ b/config/flag.go @@ -0,0 +1,182 @@ +// Copyright 2019 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. + +package config + +import ( + "context" + "flag" + "fmt" + "os" + "strings" + + "github.com/grailbio/base/backgroundcontext" + "github.com/grailbio/base/errors" + "github.com/grailbio/base/file" +) + +type ( + // flags is an ordered representation of profile flags. Each entry (implementing flagEntry) + // is a type of flag, and entry types may be interleaved. They're handled in the order + // the user passed them. + // + // The flags object is wrapped for each entry type, and each wrapper's flag.Value implementation + // appends the appropriate entry. 
+ flags struct { + defaultProfilePath string + entries []flagEntry + } + flagsProfilePaths flags + flagsProfileInlines flags + flagsSets flags + + flagEntry interface { + process(context.Context, *Profile) error + } + flagEntryProfilePath struct{ string } + flagEntryProfileInline struct{ string } + flagEntrySet struct{ key, value string } +) + +var ( + _ flagEntry = flagEntryProfilePath{} + _ flagEntry = flagEntryProfileInline{} + _ flagEntry = flagEntrySet{} + + _ flag.Value = (*flagsProfilePaths)(nil) + _ flag.Value = (*flagsProfileInlines)(nil) + _ flag.Value = (*flagsSets)(nil) +) + +func (e flagEntryProfilePath) process(ctx context.Context, p *Profile) error { + return p.loadFile(ctx, e.string) +} +func (e flagEntryProfileInline) process(_ context.Context, p *Profile) error { + return p.Parse(strings.NewReader(e.string)) +} +func (e flagEntrySet) process(_ context.Context, p *Profile) error { + return p.Set(e.key, e.value) +} + +func (f *flagsProfilePaths) String() string { return f.defaultProfilePath } +func (*flagsProfileInlines) String() string { return "" } +func (*flagsSets) String() string { return "" } + +func (f *flagsProfilePaths) Set(s string) error { + if s == "" { + return errors.New("empty path to profile") + } + f.entries = append(f.entries, flagEntryProfilePath{s}) + return nil +} +func (f *flagsProfileInlines) Set(s string) error { + if s != "" { + f.entries = append(f.entries, flagEntryProfileInline{s}) + } + return nil +} +func (f *flagsSets) Set(s string) error { + elems := strings.SplitN(s, "=", 2+1) // Split an additional part to detect errors. + if len(elems) != 2 || elems[0] == "" { + return fmt.Errorf("wrong argument format, expected key=value, got %q", s) + } + f.entries = append(f.entries, flagEntrySet{elems[0], elems[1]}) + return nil +} + +// RegisterFlags registers a set of flags on the provided FlagSet. +// These flags configure the profile when ProcessFlags is called +// (after flag parsing). 
The flags are:
+//
+//	-profile path
+//		Parses and loads the profile at the given path. This flag may be
+//		repeated, loading each profile in turn. If no -profile flags are
+//		specified, then the provided default path is loaded instead. If
+//		the default path does not exist, it is skipped; other profile loading
+//		errors cause ProcessFlags to return an error.
+//
+//	-set key=value
+//		Sets the value of the named parameter. See Profile.Set for
+//		details. This flag may be repeated.
+//
+//	-profileinline text
+//		Parses the argument. This is equivalent to writing the text to a file
+//		and using -profile.
+//
+//	-profiledump
+//		Writes the profile (after processing the above flags) to standard
+//		error and exits.
+//
+// The flag names are prefixed with the provided prefix.
+func (p *Profile) RegisterFlags(fs *flag.FlagSet, prefix string, defaultProfilePath string) {
+	p.flags.defaultProfilePath = defaultProfilePath
+	fs.Var((*flagsProfilePaths)(&p.flags), prefix+"profile", "load the profile at the provided path; may be repeated")
+	fs.Var((*flagsSets)(&p.flags), prefix+"set", "set a profile parameter; may be repeated")
+	fs.Var((*flagsProfileInlines)(&p.flags), prefix+"profileinline", "parse the profile passed as an argument; may be repeated")
+	// Honor the documented prefix for -profiledump as well. It was previously
+	// registered unprefixed, contradicting the doc comment above and colliding
+	// if RegisterFlags is called twice with different prefixes on one FlagSet.
+	fs.BoolVar(&p.flagDump, prefix+"profiledump", false, "dump the profile to stderr and exit")
+}
+
+// NeedProcessFlags returns true when a call to p.ProcessFlags should
+// not be delayed -- i.e., the flag values have user-visible side effects.
+func (p *Profile) NeedProcessFlags() bool { + return p.flagDump +} + +func (f *flags) hasProfilePathEntry() bool { + for _, entry := range f.entries { + if _, ok := entry.(flagEntryProfilePath); ok { + return true + } + } + return false +} + +func (p *Profile) loadFile(ctx context.Context, path string) (err error) { + f, err := file.Open(ctx, path) + if err != nil { + return err + } + defer errors.CleanUpCtx(ctx, f.Close, &err) + return p.Parse(f.Reader(ctx)) +} + +// ProcessFlags processes the flags as registered by RegisterFlags, +// and is documented by that method. +func (p *Profile) ProcessFlags() error { + ctx := backgroundcontext.Get() + if p.flags.defaultProfilePath != "" && !p.flags.hasProfilePathEntry() { + if err := p.loadFile(ctx, p.flags.defaultProfilePath); err != nil { + if !errors.Is(errors.NotExist, err) { + return err + } + } + } + for _, entry := range p.flags.entries { + if err := entry.process(ctx, p); err != nil { + return err + } + } + if p.flagDump { + // TODO(marius): also prune uninstantiable instances? + for _, inst := range p.sorted() { + if len(inst.params) == 0 && inst.parent == "" { + continue + } + fmt.Fprintln(os.Stderr, inst.SyntaxString(p.docs(inst))) + } + os.Exit(1) + } + return nil +} + +// RegisterFlags registers the default profile on flag.CommandLine +// with the provided prefix. See Profile.RegisterFlags for details. +func RegisterFlags(prefix string, defaultProfilePath string) { + Application().RegisterFlags(flag.CommandLine, prefix, defaultProfilePath) +} + +// ProcessFlags processes the flags as registered by RegisterFlags. +func ProcessFlags() error { + return Application().ProcessFlags() +} diff --git a/config/flag_test.go b/config/flag_test.go new file mode 100644 index 00000000..fc981f19 --- /dev/null +++ b/config/flag_test.go @@ -0,0 +1,217 @@ +// Copyright 2019 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. 
+ +package config + +import ( + "errors" + "flag" + "fmt" + "os" + "runtime" + "strconv" + "testing" + + "github.com/grailbio/base/must" +) + +// This test uses a mock "app" to demonstrate various aspects of package config. + +type credentials interface { + Creds() string +} + +type userCredentials string + +func (u userCredentials) Creds() string { return string(u) } + +type envCredentials struct{} + +func (e envCredentials) Creds() string { return "environment" } + +type database struct { + table string + creds credentials +} + +type frontend struct { + db database + creds credentials + limit int +} + +func init() { + Register("app/auth/env", func(constr *Constructor[envCredentials]) { + constr.New = func() (envCredentials, error) { + return envCredentials{}, nil + } + }) + Register("app/auth/login", func(constr *Constructor[userCredentials]) { + var ( + username = constr.String("user", "test", "the username") + password = constr.String("password", "secret", "the password") + ) + constr.New = func() (userCredentials, error) { + return userCredentials(fmt.Sprintf("%s:%s", *username, *password)), nil + } + }) + + Register("app/database", func(constr *Constructor[database]) { + var db database + constr.StringVar(&db.table, "table", "defaulttable", "the database table") + constr.InstanceVar(&db.creds, "credentials", "app/auth/env", "credentials used for database access") + constr.New = func() (database, error) { + if db.creds == nil { + return database{}, errors.New("credentials not defined") + } + return db, nil + } + }) + + Register("app/frontend", func(constr *Constructor[frontend]) { + var fe frontend + constr.InstanceVar(&fe.db, "database", "app/database", "the database to be used") + constr.InstanceVar(&fe.creds, "credentials", "app/auth/env", "credentials to use for authentication") + constr.IntVar(&fe.limit, "limit", 128, "maximum number of concurrent requests to handle") + constr.New = func() (frontend, error) { + if fe.db == (database{}) || fe.creds == nil 
{ + return frontend{}, errors.New("missing configuration") + } + return fe, nil + } + }) +} + +func TestFlag(t *testing.T) { + profile := func(args ...string) *Profile { + t.Helper() + p := New() + f, err := os.Open("testdata/profile") + must.Nil(err) + defer f.Close() + if err := p.Parse(f); err != nil { + t.Fatal(err) + } + fs := flag.NewFlagSet("test", flag.PanicOnError) + p.RegisterFlags(fs, "", "testdata/profile") + if err := fs.Parse(args); err != nil { + t.Fatal(err) + } + if err := p.ProcessFlags(); err != nil { + t.Fatal(err) + } + return p + } + + for _, test := range []struct { + line int + args []string + wantFE, wantDB string + }{ + { + callerLine(), + nil, + "marius:supersecret", "marius:supersecret", + }, + { + callerLine(), + []string{"-set", "app/auth/login.password=public"}, + "marius:public", "marius:public", + }, + { + callerLine(), + []string{"-set", "app/frontend.credentials=app/auth/env"}, + "environment", "marius:supersecret", + }, + { + callerLine(), + []string{"-profileinline", `param app/auth/login password = "public"`}, + "marius:public", "marius:public", + }, + { + callerLine(), + []string{ + "-set", "app/auth/login.password=public", + "-profile", "testdata/profile", + }, + // Parameter settings in profile file should override, since they come later. 
+ "marius:supersecret", "marius:supersecret", + }, + { + callerLine(), + []string{ + "-set", "app/auth/login.password=public", + "-profile", "testdata/profile", + "-set", "app/auth/login.password=hunter2", + }, + "marius:hunter2", "marius:hunter2", + }, + { + callerLine(), + []string{ + "-set", "app/auth/login.password=public", + "-profile", "testdata/profile", + "-set", "app/auth/login.password=hunter2", + "-profileinline", ` + instance test/felogin app/auth/login ( + user = "tester" + ) + param app/frontend credentials = test/felogin + `, + }, + "tester:hunter2", "marius:hunter2", + }, + { + callerLine(), + []string{ + "-set", "app/auth/login.password=public", + "-profile", "testdata/profile", + "-set", "app/auth/login.password=hunter2", + "-profileinline", ` + instance test/felogin app/auth/login ( + user = "tester" + ) + param app/frontend credentials = test/felogin + param test/felogin password = "abc" + `, + }, + "tester:abc", "marius:hunter2", + }, + { + callerLine(), + []string{ + "-set", "app/auth/login.password=public", + "-profile", "testdata/profile", + "-set", "app/auth/login.password=hunter2", + "-profileinline", ` + instance test/felogin app/auth/login ( + user = "tester" + ) + param app/frontend credentials = test/felogin + `, + "-profile", "testdata/profile_felogin_password", + }, + "tester:abc", "marius:hunter2", + }, + } { + t.Run(strconv.Itoa(test.line), func(t *testing.T) { + p := profile(test.args...) + var fe frontend + if err := p.Instance("app/frontend", &fe); err != nil { + t.Fatal(err) + } + if got, want := fe.creds.Creds(), test.wantFE; got != want { + t.Errorf("got %v, want %v", got, want) + } + if got, want := fe.db.creds.Creds(), test.wantDB; got != want { + t.Errorf("got %v, want %v", got, want) + } + }) + } +} + +func callerLine() int { + _, _, line, _ := runtime.Caller(1) // 1 skips the callerLine() frame. 
+ return line +} diff --git a/config/http/http.go b/config/http/http.go new file mode 100644 index 00000000..2d71da15 --- /dev/null +++ b/config/http/http.go @@ -0,0 +1,29 @@ +// Copyright 2019 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. + +// Package http defines profile providers for local HTTP servers. +// It is imported for its side effects. +package http + +import ( + "net/http" + + "github.com/grailbio/base/config" + "github.com/grailbio/base/log" +) + +func init() { + config.Register("http", func(constr *config.Constructor[config.Nil]) { + addr := constr.String("addr", ":3333", "the address used for serving http") + constr.Doc = "configure a local HTTP server, using the default http muxer" + constr.New = func() (config.Nil, error) { + go func() { + log.Print("http: serve ", *addr) + err := http.ListenAndServe(*addr, nil) + log.Error.Print("http: serve ", *addr, ": ", err) + }() + return nil, nil + } + }) +} diff --git a/config/instance.go b/config/instance.go new file mode 100644 index 00000000..3934e339 --- /dev/null +++ b/config/instance.go @@ -0,0 +1,255 @@ +// Copyright 2019 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. + +package config + +import ( + "fmt" + "reflect" + "runtime" + "sync" +) + +// typedConfigure is a type-erased version of func(*Constructor[T]). +type typedConfigure struct { + configure func(*Constructor[any]) + typ reflect.Type +} + +var ( + globalsMu sync.Mutex + globals = make(map[string]typedConfigure) + defaults = make(map[string]string) +) + +// Register registers a constructor and later invokes the provided +// function whenever a new profile instance is created. Register +// panics if multiple constructors are registered with the same name. 
+// Constructors should typically be registered in package init
+// functions, and the configure function must define at least
+// Constructor.New. For example, the following configures a
+// constructor with a single parameter, n, which simply returns its
+// value.
+//
+//	config.Register("config/test", func(constr *config.Constructor[int]) {
+//		n := constr.Int("n", 32, "the number configured")
+//		constr.New = func() (int, error) {
+//			return *n, nil
+//		}
+//		constr.Doc = "a customizable integer"
+//	})
+func Register[T any](name string, configure func(*Constructor[T])) {
+	globalsMu.Lock()
+	defer globalsMu.Unlock()
+	if _, found := globals[name]; found {
+		panic("config.Register: instance with name " + name + " has already been registered")
+	}
+	// Store a type-erased wrapper so constructors of different T can share
+	// the single globals registry; the wrapper forwards Doc and New from the
+	// typed constructor back onto the untyped one.
+	globals[name] = typedConfigure{
+		func(untyped *Constructor[any]) {
+			typed := Constructor[T]{params: untyped.params}
+			configure(&typed)
+			untyped.Doc = typed.Doc
+			untyped.New = func() (any, error) { return typed.New() }
+		},
+		reflect.TypeOf(new(T)).Elem(),
+	}
+}
+
+// Default declares a new derived instance. It is a convenience
+// function used to provide a default implementation among multiple
+// choices, and is equivalent to the profile directive
+//
+//	instance name instance
+//
+// Default panics if name is already the name of an instance, or if
+// the specified parent instance does not exist.
+func Default(name, instance string) {
+	globalsMu.Lock()
+	defer globalsMu.Unlock()
+	if _, found := globals[name]; found {
+		panic("config.Default: default " + name + " has same name as a global")
+	}
+	// The parent may itself be a registered constructor or another default.
+	if _, found := globals[instance]; !found {
+		if _, found = defaults[instance]; !found {
+			panic("config.Default: instance " + instance + " does not exist")
+		}
+	}
+	defaults[name] = instance
+}
+
+type (
+	// Constructor defines a constructor, as configured by Register.
+	// Typically a constructor registers a set of parameters through the
+	// flags-like methods provided by Constructor.
The value returned by + // New is configured by these parameters. + Constructor[T any] struct { + New func() (T, error) + Doc string + params map[string]*param + } + // Nil is an interface type with no implementations. Constructor[Nil] + // indicates an instance is created just for its side effects. + Nil interface{ neverImplemented() } +) + +func newConstructor() *Constructor[any] { + return &Constructor[any]{ + params: make(map[string]*param), + } +} + +// InstanceVar registers a parameter that is satisfied by another +// instance; the method panics if ptr is not a pointer. The default +// value is always an indirection; if it is left empty it is taken as +// the nil value: it remains uninitialized by default. +func (c *Constructor[_]) InstanceVar(ptr interface{}, name string, value string, help string) { + ptrTyp := reflect.TypeOf(ptr) + if ptrTyp.Kind() != reflect.Ptr { + panic(fmt.Sprintf( + "Instance.InterfaceVar: passed ptr %s is not a pointer", + ptrTyp, + )) + } + param := c.define(name, paramInterface, help) + param.ifaceptr = ptr + if value == "nil" { + value = "" + } + if value == "" && !isNilAssignable(ptrTyp.Elem()) { + // TODO: Consider allowing empty values to mean zero values for types + // that are not nil-assignable. We currently do not allow the empty + // string to be consistent with parsing, as there is no way to set a + // parameter to an empty value, as we require an identifier. + panic(fmt.Sprintf( + "Instance.InterfaceVar: ptr element %s cannot have nil/empty value", + ptrTyp.Elem(), + )) + } + param.ifaceindir = indirect(value) +} + +// Int registers an integer parameter with a default value. The returned +// pointer points to its value. +func (c *Constructor[_]) Int(name string, value int, help string) *int { + p := new(int) + c.IntVar(p, name, value, help) + return p +} + +// IntVar registers an integer parameter with a default value. The parameter's +// value written to the location pointed to by ptr. 
+func (c *Constructor[_]) IntVar(ptr *int, name string, value int, help string) {
+	*ptr = value
+	c.define(name, paramInt, help).intptr = ptr
+}
+
+// Float registers a floating point parameter with a default value. The returned
+// pointer points to its value.
+func (c *Constructor[_]) Float(name string, value float64, help string) *float64 {
+	p := new(float64)
+	c.FloatVar(p, name, value, help)
+	return p
+}
+
+// FloatVar registers a floating point parameter with a default value. The
+// parameter's value is written to the provided pointer.
+func (c *Constructor[_]) FloatVar(ptr *float64, name string, value float64, help string) {
+	*ptr = value
+	c.define(name, paramFloat, help).floatptr = ptr
+}
+
+// String registers a string parameter with a default value. The returned pointer
+// points to its value.
+func (c *Constructor[_]) String(name string, value string, help string) *string {
+	p := new(string)
+	c.StringVar(p, name, value, help)
+	return p
+}
+
+// StringVar registers a string parameter with a default value. The parameter's
+// value is written to the location pointed to by ptr.
+func (c *Constructor[_]) StringVar(ptr *string, name string, value string, help string) {
+	*ptr = value
+	c.define(name, paramString, help).strptr = ptr
+}
+
+// Bool registers a boolean parameter with a default value. The returned pointer
+// points to its value.
+func (c *Constructor[_]) Bool(name string, value bool, help string) *bool {
+	p := new(bool)
+	c.BoolVar(p, name, value, help)
+	return p
+}
+
+// BoolVar registers a boolean parameter with a default value. The parameter's
+// value is written to the location pointed to by ptr.
+func (c *Constructor[_]) BoolVar(ptr *bool, name string, value bool, help string) { + *ptr = value + c.define(name, paramBool, help).boolptr = ptr +} + +func (c *Constructor[_]) define(name string, kind int, help string) *param { + if c.params[name] != nil { + panic("config: parameter " + name + " already defined") + } + p := ¶m{kind: kind, help: help} + _, p.file, p.line, _ = runtime.Caller(2) + c.params[name] = p + return c.params[name] +} + +const ( + paramInterface = iota + paramInt + paramFloat + paramString + paramBool +) + +type param struct { + kind int + help string + + file string + line int + + intptr *int + floatptr *float64 + ifaceptr interface{} + ifaceindir indirect + strptr *string + boolptr *bool +} + +func (p *param) Interface() interface{} { + switch p.kind { + case paramInterface: + return reflect.ValueOf(p.ifaceptr).Elem().Interface() + case paramInt: + return *p.intptr + case paramFloat: + return *p.floatptr + case paramString: + return *p.strptr + case paramBool: + return *p.boolptr + default: + panic(p.kind) + } +} + +func isNilAssignable(typ reflect.Type) bool { + switch typ.Kind() { + case reflect.Chan: + case reflect.Func: + case reflect.Interface: + case reflect.Map: + case reflect.Ptr: + case reflect.Slice: + case reflect.UnsafePointer: + default: + return false + } + return true +} diff --git a/config/parse.go b/config/parse.go new file mode 100644 index 00000000..4a1f2a71 --- /dev/null +++ b/config/parse.go @@ -0,0 +1,551 @@ +// Copyright 2019 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. + +package config + +import ( + "errors" + "fmt" + "io" + "log" + "os" + "path/filepath" + "sort" + "strconv" + "strings" + "text/scanner" + "unicode" +) + +// insertionToks defines the sets of tokens after which +// a semicolon is inserted. 
+var insertionToks = map[rune]bool{ + scanner.Ident: true, + scanner.String: true, + scanner.RawString: true, + scanner.Int: true, + scanner.Float: true, + scanner.Char: true, + ')': true, + '}': true, + ']': true, +} + +// def wraps a value to indicate that it is a default. +type def struct{ value any } + +// unwrap returns the value v, unwrapped from def. +func unwrap(v interface{}) (_ any, wasDef bool) { + if v, ok := v.(def); ok { + u, _ := unwrap(v.value) + return u, true + } + return v, false +} + +// indirect is a type that indicates an indirection. +type indirect string + +// GoString renders an indirect type as a string without quotes, +// matching the concrete representation of indirections. +func (i indirect) GoString() string { + if i == "" { + return "nil" + } + return string(i) +} + +// An instance stores a parsed configuration clause. +type instance struct { + // name is the global name of the instance. + name string + // parent is the instance of which this is derived, if any. + parent string + // params contains the set of parameters defined by this instance. + // The values of the parameter map takes on valid config literal + // values. They are: indirect, bool, int, float64, and string. + params map[string]interface{} +} + +// Merge merges the provided instance into inst. Any +// nondefault parameter values in other are set in this +// instance. +func (inst *instance) Merge(other *instance) { + if other.parent != "" { + inst.parent = other.parent + } + for k, v := range other.params { + if _, ok := v.(def); ok { + continue + } + inst.params[k] = v + } +} + +// Equal tells whether two instances are equal. 
+func (inst *instance) Equal(other *instance) bool {
+	if inst.name != other.name || inst.parent != other.parent || len(inst.params) != len(other.params) {
+		return false
+	}
+	for k, v := range inst.params {
+		w, ok := other.params[k]
+		if !ok {
+			return false
+		}
+		// Compare unwrapped values: a def-wrapped (default) value is equal
+		// to an explicitly set value with the same contents.
+		v, _ = unwrap(v)
+		w, _ = unwrap(w)
+		switch vval := v.(type) {
+		case indirect:
+			wval, ok := w.(indirect)
+			if !ok || vval != wval {
+				return false
+			}
+		case string:
+			wval, ok := w.(string)
+			if !ok || vval != wval {
+				return false
+			}
+		case bool:
+			wval, ok := w.(bool)
+			if !ok || vval != wval {
+				return false
+			}
+		case int:
+			wval, ok := w.(int)
+			if !ok || vval != wval {
+				return false
+			}
+		case float64:
+			wval, ok := w.(float64)
+			if !ok || vval != wval {
+				return false
+			}
+		}
+	}
+	return true
+}
+
+// instances stores a collection of named instances.
+type instances map[string]*instance
+
+// Merge merges an instance into this collection.
+func (m instances) Merge(inst *instance) {
+	if m[inst.name] == nil {
+		m[inst.name] = inst
+		return
+	}
+	m[inst.name].Merge(inst)
+}
+
+// Equal tells whether instances m is equal to instances n.
+func (m instances) Equal(n instances) bool {
+	if len(m) != len(n) {
+		return false
+	}
+	for name, minst := range m {
+		ninst, ok := n[name]
+		if !ok {
+			return false
+		}
+		if !minst.Equal(ninst) {
+			return false
+		}
+	}
+	return true
+}
+
+// SyntaxString returns a string representation of this instance
+// which is also valid config syntax. Docs optionally provides
+// documentation for the parameters in the instance.
+func (inst *instance) SyntaxString(docs map[string]string) string {
+	// TODO: Consider printing floats with minimum precision (1 appears as 1.0) so users
+	// can easily contrast them with integers.
+	var b strings.Builder
+	writeDoc(&b, "", docs[""])
+	if inst.parent == "" {
+		b.WriteString("param ")
+		b.WriteString(inst.name)
+		if len(inst.params) == 0 {
+			b.WriteString(" ()\n")
+			return b.String()
+		}
+		b.WriteString(" (\n")
+		writeParams(&b, inst.params, docs)
+		b.WriteString(")\n")
+		return b.String()
+	}
+	b.WriteString("instance ")
+	b.WriteString(inst.name)
+	b.WriteString(" ")
+	b.WriteString(inst.parent)
+	if len(inst.params) > 0 {
+		b.WriteString(" (\n")
+		writeParams(&b, inst.params, docs)
+		b.WriteString(")")
+	}
+	b.WriteString("\n")
+	return b.String()
+}
+
+// writeDoc writes doc to b as "//" comment lines, each preceded by prefix.
+// Empty docs produce no output.
+func writeDoc(b *strings.Builder, prefix string, doc string) {
+	if doc == "" {
+		return
+	}
+	for _, line := range strings.Split(doc, "\n") {
+		b.WriteString(prefix)
+		b.WriteString("// ")
+		b.WriteString(line)
+		b.WriteString("\n")
+	}
+}
+
+// writeParams writes one "name = value" line per parameter, in sorted key
+// order, annotating unmodified defaults with a trailing "// default" comment.
+func writeParams(b *strings.Builder, params map[string]any, docs map[string]string) {
+	forEachParam(params, func(name string, v any) {
+		writeDoc(b, "\t", docs[name])
+		v, wasDef := unwrap(v)
+		var repr string
+		switch vt := v.(type) {
+		case string:
+			// Improve readability by using a raw literal (no quote-escaping), if possible.
+			if strings.ContainsRune(vt, '"') && !strings.ContainsRune(vt, '`') {
+				repr = "`" + vt + "`"
+			} else {
+				repr = strconv.Quote(vt)
+			}
+		default:
+			repr = fmt.Sprintf("%#v", v)
+		}
+		fmt.Fprintf(b, "\t%s = %s", name, repr)
+		if wasDef {
+			b.WriteString(" // default")
+		}
+		b.WriteString("\n")
+	})
+}
+
+// forEachParam invokes fn for each parameter in sorted key order, giving
+// deterministic output independent of map iteration order.
+func forEachParam(params map[string]any, fn func(k string, v any)) {
+	keys := make([]string, 0, len(params))
+	for k := range params {
+		keys = append(keys, k)
+	}
+	sort.Strings(keys)
+	for _, k := range keys {
+		fn(k, params[k])
+	}
+}
+
+// A parser stores parser state and defines the productions
+// in the profile grammar.
+type parser struct { + scanner scanner.Scanner + errors []string + + insertion bool + scanned rune +} + +// parse parses the config read by the provided reader into a +// concrete profile into a set of instances. If the reader r +// implements +// +// Name() string +// +// then this is used as a filename to display positional information +// in error messages. +func parse(r io.Reader) (instances, error) { + var p parser + p.scanner.Whitespace &= ^uint64(1 << '\n') + p.scanner.Mode = scanner.ScanIdents | scanner.ScanFloats | scanner.ScanChars | + scanner.ScanStrings | scanner.ScanRawStrings + p.scanner.IsIdentRune = func(ch rune, i int) bool { + return unicode.IsLetter(ch) || (unicode.IsDigit(ch) || ch == '_' || ch == '/' || ch == '-') && i > 0 + } + if named, ok := r.(interface{ Name() string }); ok { + filename := named.Name() + if cwd, err := os.Getwd(); err == nil { + if rel, err := filepath.Rel(cwd, filename); err == nil && len(rel) < len(filename) { + filename = rel + } + } + p.scanner.Position.Filename = filename + } + p.scanner.Error = func(s *scanner.Scanner, msg string) { + // TODO(marius): report these in error + log.Printf("%s: %s", s.Position, msg) + } + p.scanner.Init(r) + if insts, ok := p.toplevel(); ok { + return insts, nil + } + switch len(p.errors) { + case 0: + return nil, errors.New("parse error") + case 1: + return nil, fmt.Errorf("parse error: %s", p.errors[0]) + default: + return nil, fmt.Errorf("parse error:\n%s", strings.Join(p.errors, "\n")) + } +} + +// toplevel parses the config grammar. 
It is as follows: +// +// toplevel: +// clause +// clause ';' toplevel +// +// +// clause: +// param +// instance +// +// param: +// ident assign +// ident assignlist +// +// instance: +// ident ident +// ident ident assignlist +// +// assign: +// key = value +// +// assignlist: +// ( list ) +// +// list: +// assign +// assign ';' list +// +// value: +// 'true' +// 'false' +// 'nil' +// ident +// integer +// float +// string +func (p *parser) toplevel() (insts instances, ok bool) { + ok = true // Empty input is okay. + insts = make(instances) + for { + switch p.next() { + case scanner.EOF: + return + case ';': + case scanner.Ident: + switch p.text() { + case "param": + var ( + name string + params map[string]interface{} + ) + name, params, ok = p.param() + if !ok { + return + } + insts.Merge(&instance{name: name, params: params}) + case "instance": + var inst *instance + inst, ok = p.instance() + if !ok { + return + } + insts.Merge(inst) + default: + p.errorf("unrecognized toplevel clause: %s", p.text()) + return nil, false + } + } + } +} + +// param: +// ident assign +// ident assignlist +func (p *parser) param() (instance string, params map[string]interface{}, ok bool) { + if p.next() != scanner.Ident { + p.errorf("expected identifier") + return + } + instance = p.text() + switch tok := p.peek(); tok { + case scanner.Ident: + var ( + key string + value interface{} + ) + key, value, ok = p.assign() + if !ok { + return + } + params = map[string]interface{}{key: value} + case '(': + params, ok = p.assignlist() + default: + p.next() + p.errorf("unexpected: %s", scanner.TokenString(tok)) + } + return +} + +// instance: +// ident ident +// ident ident assignlist +func (p *parser) instance() (inst *instance, ok bool) { + if p.next() != scanner.Ident { + p.errorf("expected identifier") + return + } + inst = &instance{name: p.text()} + if p.next() != scanner.Ident { + p.errorf("expected identifier") + return + } + inst.parent = p.text() + if p.peek() != '(' { + ok = true 
+ return + } + inst.params, ok = p.assignlist() + return +} + +// assign: +// key = value +func (p *parser) assign() (key string, value interface{}, ok bool) { + if p.next() != scanner.Ident { + p.errorf("expected identifier") + return + } + key = p.text() + if p.next() != '=' { + p.errorf(`expected "="`) + return + } + value, ok = p.value() + return +} + +// assignlist: +// ( list ) +// +// list: +// assign +// assign ';' list +func (p *parser) assignlist() (assigns map[string]interface{}, ok bool) { + if p.next() != '(' { + p.errorf(`parse error: expected "("`) + return + } + assigns = make(map[string]interface{}) + for { + switch p.peek() { + default: + var ( + key string + value interface{} + ) + key, value, ok = p.assign() + if !ok { + return + } + assigns[key] = value + case ';': + p.next() + case ')': + p.next() + ok = true + return + } + } +} + +// value: +// 'true' +// 'false' +// 'nil' +// identifier +// integer +// float +// string +func (p *parser) value() (value any, ok bool) { + switch tok := p.next(); tok { + case scanner.Ident: + switch p.text() { + case "true": + return true, true + case "false": + return false, true + case "nil": + return indirect(""), true + default: + return indirect(p.text()), true + } + case scanner.String, scanner.RawString: + text, err := strconv.Unquote(p.text()) + if err != nil { + p.errorf("could not parse string: %v", err) + return nil, false + } + return text, true + case '-': + return p.parseNumber(p.next(), true) + default: + return p.parseNumber(tok, false) + } +} + +func (p *parser) parseNumber(tok rune, negate bool) (value any, ok bool) { + switch tok { + case scanner.Int: + v, err := strconv.ParseInt(p.text(), 0, 64) + if err != nil { + p.errorf("could not parse integer: %v", err) + return nil, false + } + if negate { + v = -v + } + return int(v), true + case scanner.Float: + v, err := strconv.ParseFloat(p.text(), 64) + if err != nil { + p.errorf("could not parse float: %v", err) + return nil, false + } + if 
negate { + v = -v + } + return v, true + default: + p.errorf("parse error: not a value") + return nil, false + } +} + +func (p *parser) next() rune { + tok := p.peek() + p.insertion = insertionToks[tok] + p.scanned = 0 + return tok +} + +func (p *parser) peek() rune { + if p.scanned == 0 { + p.scanned = p.scanner.Scan() + } + if p.insertion && p.scanned == '\n' { + return ';' + } + return p.scanned +} + +func (p *parser) text() string { + return p.scanner.TokenText() +} + +func (p *parser) errorf(format string, args ...interface{}) { + e := fmt.Sprintf("%s: %s", p.scanner.Position, fmt.Sprintf(format, args...)) + p.errors = append(p.errors, e) +} diff --git a/config/parse_test.go b/config/parse_test.go new file mode 100644 index 00000000..d798f8ec --- /dev/null +++ b/config/parse_test.go @@ -0,0 +1,145 @@ +// Copyright 2019 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. + +package config + +import ( + "strings" + "testing" +) + +// TestParseEmpty verifies that parsing a (logically) empty file is valid. 
+func TestParseEmpty(t *testing.T) { + for _, c := range []struct { + name string + text string + }{ + {"Empty", ""}, + {"Whitespace", "\t\n \n\t \n"}, + {"Semicolons", ";;"}, + {"Mix", " \t \n \n;\n ;"}, + } { + t.Run(c.name, func(t *testing.T) { + instances, err := parse(strings.NewReader(c.text)) + if err != nil { + t.Fatal(err) + } + if got, want := len(instances), 0; got != want { + t.Errorf("got %v, want %v", got, want) + } + }) + } +} + +func TestParse(t *testing.T) { + got, err := parse(strings.NewReader(strings.ReplaceAll(` +param x y = "okay" +param y z = 123 +param y a = "a"; param y b = b +param y c = nil + +param y ( + x = "blah" + y = 333 + fprec = 0.123456789 + raw = ${"it's json": +12.3}$ +) + +instance z blah ( + bloop = 123 + negint = -3 + negfloat = -3.14 +) + +param z x = 89898 + +instance bigslice/system blah ( + region = "us-west-2" +) + +param zero-params () +`, "$", "`"))) + if err != nil { + t.Fatal(err) + } + want := instances{ + "x": &instance{ + name: "x", + params: map[string]interface{}{ + "y": "okay", + }, + }, + "y": &instance{ + name: "y", + params: map[string]interface{}{ + "x": "blah", + "y": 333, + "z": 123, + "a": "a", + "b": indirect("b"), + "c": indirect(""), + "fprec": 0.123456789, + "raw": `{"it's json": +12.3}`, + }, + }, + "z": &instance{ + name: "z", + parent: "blah", + params: map[string]interface{}{ + "bloop": 123, + "x": 89898, + "negint": -3, + "negfloat": -3.14, + }, + }, + "bigslice/system": &instance{ + name: "bigslice/system", + parent: "blah", + params: map[string]interface{}{ + "region": "us-west-2", + }, + }, + "zero-params": &instance{ + name: "zero-params", + parent: "", + params: map[string]interface{}{}, + }, + } + if !got.Equal(want) { + t.Errorf("got %v, want %v", got, want) + } + for name, wantInst := range want { + t.Run(name, func(t *testing.T) { + syntax := wantInst.SyntaxString(nil) + insts, err := parse(strings.NewReader(syntax)) + if err != nil { + t.Fatalf("%v. 
syntax:\n%s", err, syntax) + } + if gotInst := insts[wantInst.name]; !wantInst.Equal(gotInst) { + t.Errorf("got %v, want %v, syntax:\n%s", got, want, syntax) + } + }) + } +} + +func TestParseError(t *testing.T) { + testError(t, `parm x y = 1`, `parse error: :1:1: unrecognized toplevel clause: parm`) + testError(t, `param x _y = "hey"`, `parse error: :1:9: unexpected: "_"`) + testError(t, `param x 123 = blah`, `parse error: :1:9: unexpected: Int`) + testError(t, `param x y z`, `parse error: :1:11: expected "="`) +} + +func testError(t *testing.T, s, expect string) { + t.Helper() + _, err := parse(strings.NewReader(s)) + if err == nil { + t.Error("expected error") + return + } + if got, want := err.Error(), expect; got != want { + t.Errorf("got %q, want %q", got, want) + } +} diff --git a/config/profile.go b/config/profile.go new file mode 100644 index 00000000..a877e704 --- /dev/null +++ b/config/profile.go @@ -0,0 +1,708 @@ +// Copyright 2019 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. + +// Package config is used to configure software systems. A +// configuration managed by package config is called a profile. The +// binary that loads a profile declares a set of named, global +// objects through the APIs in this package. A profile configures +// these objects (objects may depend on each other, forming a DAG) +// and lets the user retrieve configured objects through its API. +// +// The semantics of profiles provide the kind of flexibility that is +// often required in operational contexts. Profiles define a +// principled overriding so that a base configuration can be extended +// by the user, either by composing multiple configuration or by +// editing the configuration through a command-line integration. +// Profiles may also derive multiple instances from the same base +// instance in order to provide small variations on instance +// configuration. 
Profiles define a concrete syntax so that they may +// be stored (e.g., centrally) or transmitted over a network +// connection (e.g., to bootstrap a remote binary with a particular +// configuration). Profiles are also self-documenting in the manner +// of Go's flag package. Profiles are resolved lazily, and thus +// maintain configuration for unknown instances, so long as these are +// never retrieved. This permits a single profile to be reused across +// many binaries without concern for compatibility. +// +// Profile syntax +// +// A profile contains a set of clauses, or directives. Each clause +// either declares a new instance or configures an existing instance. +// Clauses are interpreted in order, top-to-bottom, and later +// configurations override earlier configurations. These semantics +// accommodate for "overlays", where for example a user profile is +// loaded after a base profile to provide customization. Within +// GRAIL, a base profile is declared in the standard package +// github.com/grailbio/base/grail, which also loads a user +// profile from $HOME/grail/profile. +// +// A parameter is set by the directive param. For example, the +// following sets the parallelism parameter on the instance bigslice +// to 1024: +// +// param bigslice parallelism = 1024 +// +// The values supported by profiles are: integers, strings, booleans, floats, +// and indirections (naming other instances). The following shows an example +// of each: +// +// param bigslice load-factor = 0.8 +// param bigmachine/ec2system username = "marius" +// param bigmachine/ec2system on-demand = false +// param s3 retries = 8 +// +// As a shortcut, parameters for the same instance may be grouped +// together. For example, the two parameters on the instance +// bigmachine/ec2system may be grouped together as follows: +// +// param bigmachine/ec2system ( +// username = "marius" +// on-demand = false +// ) +// +// Instances may refer to each other by name. 
The following
+// configures the aws/ticket instance to use a particular ticket path
+// and region; it then configures bigmachine/ec2system to use this
+// AWS session.
+//
+//	param aws/ticket (
+//		path = "eng/dev/aws"
+//		region = "us-west-2"
+//	)
+//
+//	param bigmachine/ec2system aws = aws/ticket
+//
+// Profiles may also define new instances with different configurations.
+// This is done via the instance directive. For example, if we wanted to
+// declare a new bigmachine/ec2system that used on-demand instances
+// instead of spot instances, we could define a profile as follows:
+//
+//	instance bigmachine/ec2ondemand bigmachine/ec2system
+//
+//	param bigmachine/ec2ondemand on-demand = true
+//
+// Since it is common to declare an instance and configure it, the
+// profile syntax provides an affordance for combining the two,
+// also through grouping. The above is equivalent to:
+//
+//	instance bigmachine/ec2ondemand bigmachine/ec2system (
+//		on-demand = true
+//		username = "marius-ondemand"
+//		// (any other configuration to be changed from the base)
+//	)
+//
+// New instances may depend on any instance. For example, the above
+// may be further customized as follows.
+//
+//	instance bigmachine/ec2ondemand-anonymous bigmachine/ec2ondemand (
+//		username = "anonymous"
+//	)
+//
+// Customization through flags
+//
+// Profile parameters may be adjusted via command-line flags. Profile
+// provides utility methods to register flags and interpret them. See
+// the appropriate methods for more details. Any parameter may be
+// set through the provided command-line flags by specifying the path
+// to the parameter. As an example, the following invocations customize
+// aspects of the above profile.
+//
+// # Override the ticket path and the default ec2system username.
+// # -set flags are interpreted in order, and the following is equivalent +// # to the clauses +// # param aws/ticket path = "eng/prod/aws" +// # param bigmachine/ec2system username = "anonymous" +// $ program -set aws/ticket.path=eng/prod/aws -set bigmachine/ec2system.username=anonymous +// +// # User the aws/env instance instead of aws/ticket, as above. +// # The type of a flag is interpreted based on underlying type, so +// # the following is equivalent to the clause +// # param bigmachine/ec2system aws = aws/env +// $ program -set bigmachine/ec2system.aws=aws/env +// +// Default profile +// +// Package config also defines a default profile and a set of package-level +// methods that operate on this profile. Most users should make use only +// of the default profile. This package also exports an http handler on the +// path /debug/profile on the default (global) ServeMux, which returns the +// global profile in parseable form. +package config + +import ( + "fmt" + "io" + "log" + "net/http" + "reflect" + "runtime" + "sort" + "strconv" + "strings" + "sync" + "unsafe" +) + +func init() { + http.HandleFunc("/debug/profile", func(w http.ResponseWriter, r *http.Request) { + if err := Application().PrintTo(w); err != nil { + http.Error(w, fmt.Sprintf("writing profile: %v", err), http.StatusInternalServerError) + } + }) +} + +type ( + // Profile stores a set of parameters and configures instances based + // on these. It is the central data structure of this package as + // detailed in the package docs. Each Profile instance maintains its + // own set of instances. Most users should use the package-level + // methods that operate on the default profile. + Profile struct { + // The following are used by the flag registration and + // handling mechanism. + flags flags + flagDump bool + + globals map[string]typedConstructor + + mu sync.Mutex + instances instances + cached map[string]interface{} + } + // typedConstructor is a type-erased *Constructor[T]. 
+ typedConstructor struct { + constructor *Constructor[any] + typ reflect.Type + } +) + +// New creates and returns a new profile, installing all currently +// registered global objects. Global objects registered after a call +// to New are not reflected in the returned profile. +func New() *Profile { + p := &Profile{ + globals: make(map[string]typedConstructor), + instances: make(instances), + cached: make(map[string]interface{}), + } + + globalsMu.Lock() + for name, ct := range globals { + p.globals[name] = typedConstructor{newConstructor(), ct.typ} + ct.configure(p.globals[name].constructor) + } + globalsMu.Unlock() + + // Make a shadow instance for each global instance. This helps keep + // the downstream code simple. We also populate any defaults + // provided by the configured instances, so that printing the + // profile shows the true, global (and re-createable) state of the + // profile. + for name, global := range p.globals { + inst := &instance{name: name, params: make(map[string]interface{})} + for pname, param := range global.constructor.params { + // Special case for interface params: use their indirections + // instead of their value; this is always how they are satisfied + // in practice. + if param.kind == paramInterface { + inst.params[pname] = def{param.ifaceindir} + } else { + inst.params[pname] = def{param.Interface()} + } + } + p.instances[name] = inst + } + + // Populate defaults as empty instance declarations, effectively + // redirecting the instance and making it overridable, etc. + globalsMu.Lock() + for name, parent := range defaults { + p.instances[name] = &instance{name: name, parent: parent} + } + globalsMu.Unlock() + + return p +} + +// Set sets the value of the parameter at the provided path to the +// provided value, which is intepreted according to the type of the +// parameter at that path. Set returns an error if the parameter does +// not exist or if the value cannot be parsed into the expected type. 
+// The path is a set of identifiers separated by dots ("."). Paths may +// traverse multiple indirections. +func (p *Profile) Set(path string, value string) error { + p.mu.Lock() + defer p.mu.Unlock() + + // Special case: toplevel instance assignment. + elems := strings.Split(path, ".") + if len(elems) == 1 { + if value == "" || value == "nil" { + return fmt.Errorf( + "%s: top-level path may only be set to an instance; cannot be set to nil/empty", + elems[0], + ) + } + p.instances[elems[0]] = &instance{ + name: elems[0], + parent: value, + } + return nil + } + + // Otherwise infer the type and parse it accordingly. + inst := p.instances[elems[0]] + if inst == nil { + return fmt.Errorf("%s: path not found: instance not found", path) + } + for i := 1; i < len(elems)-1; i++ { + var v interface{} + for { + var ok bool + v, ok = inst.params[elems[i]] + if ok { + break + } + if inst.parent == "" || p.instances[inst.parent] == nil { + return fmt.Errorf("%s: path not found: instance not found: %s", path, strings.Join(elems[:i], ".")) + } + inst = p.instances[inst.parent] + } + v, _ = unwrap(v) + indir, ok := v.(indirect) + if !ok { + return fmt.Errorf("%s: path not found: %s is not an instance", path, strings.Join(elems[:i], ".")) + } + inst = p.instances[string(indir)] + if inst == nil { + return fmt.Errorf("%s: path not found: instance not found: %s", path, strings.Join(elems[:i], ".")) + } + } + + name := elems[len(elems)-1] + for { + if _, ok := inst.params[name]; ok { + break + } + if inst.parent == "" || p.instances[inst.parent] == nil { + return fmt.Errorf("%s: no such parameter", path) + } + inst = p.instances[inst.parent] + } + + switch v, _ := unwrap(inst.params[name]); v.(type) { + case indirect: + // TODO(marius): validate that it's a good identifier? 
+ if value == "nil" { + value = "" + } + inst.params[name] = indirect(value) + case string: + inst.params[name] = value + case bool: + v, err := strconv.ParseBool(value) + if err != nil { + return fmt.Errorf("param %s is a bool, but could not parse %s into bool: %v", path, value, err) + } + inst.params[name] = v + case int: + v, err := strconv.ParseInt(value, 0, 64) + if err != nil { + return fmt.Errorf("param %s is an int, but could not parse %s into int: %v", path, value, err) + } + inst.params[name] = int(v) + case float64: + v, err := strconv.ParseFloat(value, 64) + if err != nil { + return fmt.Errorf("param %s is a float, but could not parse %s into float: %v", path, value, err) + } + inst.params[name] = v + default: + panic(fmt.Sprintf("%T", v)) + } + return nil +} + +// Get returns the value of the configured parameter at the provided +// dot-separated path. +func (p *Profile) Get(path string) (value string, ok bool) { + p.mu.Lock() + defer p.mu.Unlock() + + var ( + elems = strings.Split(path, ".") + inst = p.instances[elems[0]] + ) + if inst == nil { + return "", false + } + // Special case: toplevels are "set" only if they are inherited. + // We return only the first level of inheritance. + if len(elems) == 1 { + return inst.parent, inst.parent != "" + } + + for i := 1; i < len(elems)-1; i++ { + elem := elems[i] + for inst != nil && inst.params[elem] == nil { + inst = p.instances[inst.parent] + } + if inst == nil { + return "", false + } + v, _ := unwrap(inst.params[elem]) + indir, ok := v.(indirect) + if !ok { + return "", false + } + inst = p.instances[string(indir)] + if inst == nil { + return "", false + } + } + + for elem := elems[len(elems)-1]; inst != nil; inst = p.instances[inst.parent] { + if v, ok := inst.params[elem]; ok { + v, _ = unwrap(v) + return fmt.Sprintf("%#v", v), true + } + } + return "", false +} + +// Merge merges the instance parameters in profile q into p, +// so that parameters defined in q override those in p. 
+func (p *Profile) Merge(q *Profile) { + defer lock(p, q)() + for _, inst := range q.instances { + p.instances.Merge(inst) + } +} + +// Parse parses a profile from the provided reader into p. On +// success, the instances defined by the profile in src are merged into +// profile p. If the reader implements +// +// Name() string +// +// then the result of calling Name is used as a filename to provide +// positional information in errors. +func (p *Profile) Parse(r io.Reader) error { + insts, err := parse(r) + if err != nil { + return err + } + p.mu.Lock() + defer p.mu.Unlock() + for _, inst := range insts { + p.instances.Merge(inst) + } + return nil +} + +// InstanceNames returns the set of names of instances provided by p. +func (p *Profile) InstanceNames() map[string]struct{} { + p.mu.Lock() + defer p.mu.Unlock() + names := make(map[string]struct{}, len(p.instances)) + for name := range p.instances { + names[name] = struct{}{} + } + return names +} + +// Instance retrieves the named instance from this profile into the +// pointer ptr. All of its parameters are fully resolved and the +// underlying global object is instantiated according to the desired +// parameterization. Instance panics if ptr is not a pointer type. If the +// type of the instance cannot be assigned to the value pointed to by +// ptr, an error is returned. Since such errors may occur +// transitively (e.g., the type of an instance required by another +// instance may be wrong), the source location of the type mismatch +// is included in the error to help with debugging. Instances are +// cached and are only initialized the first time they are requested. +// +// If ptr is nil, the instance is created without populating the pointer. 
+func (p *Profile) Instance(name string, ptr interface{}) error { + var ptrv reflect.Value + if ptr != nil { + ptrv = reflect.ValueOf(ptr) + if ptrv.Kind() != reflect.Ptr { + panic("profile.Get: not a pointer") + } + } + _, file, line, _ := runtime.Caller(1) + p.mu.Lock() + err := p.getLocked(name, ptrv, file, line) + p.mu.Unlock() + return err +} + +func (p *Profile) PrintTo(w io.Writer) error { + p.mu.Lock() + defer p.mu.Unlock() + instances := p.sorted() + for _, inst := range instances { + if len(inst.params) == 0 && inst.parent == "" { + continue + } + if _, err := fmt.Fprintln(w, inst.SyntaxString(p.docs(inst))); err != nil { + return err + } + } + return nil +} + +// docs collects the documentation strings for inst and its parameters. +// Special key "" holds the documentation for the instance itself. +// Remaining keys are inst's param names. +func (p *Profile) docs(inst *instance) map[string]string { + global, ok := p.globals[inst.name] + if !ok { + return nil + } + docs := map[string]string{"": global.constructor.Doc} + for name, param := range global.constructor.params { + docs[name] = param.help + + var paramType reflect.Type + if param.kind == paramInterface { + paramType = reflect.ValueOf(param.ifaceptr).Elem().Type() + } else { + paramType = reflect.TypeOf(param.Interface()) + } + var insts []string + // TODO: This is asymptotically slow. Make it faster, perhaps with indexing. 
+ for name, other := range p.globals { + if name != inst.name && other.typ != nil && other.typ.AssignableTo(paramType) { + insts = append(insts, name) + } + } + sort.Strings(insts) + if len(insts) > 0 { + docs[name] += "\n\nAvailable instances:\n\t" + strings.Join(insts, "\n\t") + } + } + return docs +} + +func (p *Profile) getLocked(name string, ptr reflect.Value, file string, line int) error { + if v, ok := p.cached[name]; ok { + return assign(name, v, ptr, file, line) + } + inst := p.instances[name] + if inst == nil { + return fmt.Errorf("no instance named %q", name) + } + + resolved := make(map[string]interface{}) + for { + for k, v := range inst.params { + if _, ok := resolved[k]; !ok { + resolved[k] = v + } + } + if inst.parent == "" { + break + } + parent := p.instances[inst.parent] + if parent == nil { + return fmt.Errorf("no such instance: %q", inst.parent) + } + inst = parent + } + + if _, ok := p.globals[inst.name]; !ok { + return fmt.Errorf("missing global instance: %q", inst.name) + } + // Even though we have a configured instance in globals, we create + // a new one to reduce the changes that multiple instances clobber + // each other. + globalsMu.Lock() + ct := globals[inst.name] + globalsMu.Unlock() + instance := newConstructor() + ct.configure(instance) + + for pname, param := range instance.params { + val, ok := resolved[pname] + if !ok { + continue + } + // Skip defaults except for paramInterface since these need to be resolved. 
+ if _, ok := val.(def); ok && param.kind != paramInterface { + continue + } + val, _ = unwrap(val) + if indir, ok := val.(indirect); ok { + if param.kind != paramInterface { + return fmt.Errorf("resolving %s.%s: cannot indirect parameters of type %T", name, pname, val) + } + if indir == "" { + typ := reflect.ValueOf(param.ifaceptr).Elem().Type() + if !isNilAssignable(typ) { + return fmt.Errorf("resolving %s.%s: cannot assign nil/empty to parameter of type %s", name, pname, typ) + } + continue // nil: skip + } + if err := p.getLocked(string(indir), reflect.ValueOf(param.ifaceptr), param.file, param.line); err != nil { + return err + } + continue + } + + switch param.kind { + case paramInterface: + var ( + dst = reflect.ValueOf(param.ifaceptr).Elem() + src = reflect.ValueOf(val) + ) + // TODO: include embedded fields, etc? + if !src.Type().AssignableTo(dst.Type()) { + return fmt.Errorf("%s.%s: cannot assign value of type %s to type %s", name, pname, src.Type(), dst.Type()) + } + dst.Set(src) + case paramInt: + ival, ok := val.(int) + if !ok { + return fmt.Errorf("%s.%s: wrong parameter type: expected int, got %T", name, pname, val) + } + *param.intptr = ival + case paramFloat: + switch tv := val.(type) { + case int: + *param.floatptr = float64(tv) + case float64: + *param.floatptr = tv + default: + return fmt.Errorf("%s.%s: wrong parameter type: expected float64, got %T", name, pname, val) + } + case paramString: + sval, ok := val.(string) + if !ok { + return fmt.Errorf("%s.%s: wrong parameter type: expected string, got %T", name, pname, val) + } + *param.strptr = sval + case paramBool: + bval, ok := val.(bool) + if !ok { + return fmt.Errorf("%s.%s: wrong parameter type: expected bool, got %T", name, pname, val) + } + *param.boolptr = bval + default: + panic(param.kind) + } + } + + v, err := instance.New() + if err != nil { + return err + } + p.cached[name] = v + return assign(name, v, ptr, file, line) +} + +func (p *Profile) sorted() []*instance { + instances := 
make([]*instance, 0, len(p.instances)) + for _, inst := range p.instances { + instances = append(instances, inst) + } + sort.Slice(instances, func(i, j int) bool { + return instances[i].name < instances[j].name + }) + return instances +} + +var ( + defaultInit sync.Once + defaultInstance *Profile +) + +// NewDefault is used to initialize the default profile. It can be +// set by a program before the application profile has been created +// in order to support asynchronous profile retrieval. +var NewDefault = New + +// Application returns the default application profile. The default +// instance is initialized during the first call to Application (and thus +// of the package-level methods that operate on the default profile). +// Because of this, Application (and the other package-level methods +// operating on the default profile) should not be called during +// package initialization as doing so means that some global objects +// may not yet have been registered. +func Application() *Profile { + // TODO(marius): freeze registration after this? + defaultInit.Do(func() { + defaultInstance = NewDefault() + }) + return defaultInstance +} + +// Merge merges profile p into the default profile. +// See Profile.Merge for more details. +func Merge(p *Profile) { + Application().Merge(p) +} + +// Parse parses the profile in reader r into the default +// profile. See Profile.Parse for more details. +func Parse(r io.Reader) error { + return Application().Parse(r) +} + +// Instance retrieves the instance with the provided name into the +// provided pointer from the default profile. See Profile.Instance for +// more details. +func Instance(name string, ptr interface{}) error { + return Application().Instance(name, ptr) +} + +// Set sets the value of the parameter named by the provided path on +// the default profile. See Profile.Set for more details. 
+func Set(path, value string) error { + return Application().Set(path, value) +} + +// Get retrieves the value of the parameter named by the provided path +// on the default profile. +func Get(path string) (value string, ok bool) { + return Application().Get(path) +} + +// Must is a version of get which calls log.Fatal on error. +func Must(name string, ptr interface{}) { + if err := Instance(name, ptr); err != nil { + log.Fatal(err) + } +} + +func assign(name string, instance interface{}, ptr reflect.Value, file string, line int) error { + if ptr == (reflect.Value{}) { + return nil + } + v := reflect.ValueOf(instance) + if !v.IsValid() { + ptr.Elem().Set(reflect.Zero(ptr.Elem().Type())) + return nil + } + if !v.Type().AssignableTo(ptr.Elem().Type()) { + return fmt.Errorf( + "%s:%d: instance %q of type %s is not assignable to provided pointer element type %s", + file, line, name, v.Type(), ptr.Elem().Type()) + } + ptr.Elem().Set(v) + return nil +} + +func lock(p, q *Profile) (unlock func()) { + if uintptr(unsafe.Pointer(q)) < uintptr(unsafe.Pointer(p)) { + p, q = q, p + } + p.mu.Lock() + q.mu.Lock() + return func() { + q.mu.Unlock() + p.mu.Unlock() + } +} diff --git a/config/profile_test.go b/config/profile_test.go new file mode 100644 index 00000000..f6009b40 --- /dev/null +++ b/config/profile_test.go @@ -0,0 +1,647 @@ +// Copyright 2019 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. + +package config + +import ( + "strings" + "testing" +) + +type custom struct { + x int + f float64 +} + +// paramFields is a convenient structure for testing that has fields of various +// types that we use to test parameter handling. 
+type paramFields struct { + c custom + p *custom + ch chan struct{} + a any +} + +func init() { + Register("test/custom", func(inst *Constructor[custom]) { + var c custom + inst.IntVar(&c.x, "x", -1, "the x value") + inst.FloatVar(&c.f, "f", 0, "the f value") + inst.New = func() (custom, error) { + return c, nil + } + }) + + Default("test/default", "test/custom") + + Default("test/default2", "test/default") + + Register("test/custom-ptr", func(inst *Constructor[*custom]) { + var c custom + inst.IntVar(&c.x, "x", -1, "the x value") + inst.New = func() (*custom, error) { + return &c, nil + } + }) + + Register("test/1", func(inst *Constructor[int]) { + var c custom + inst.InstanceVar(&c, "custom", "test/default", "the custom struct") + x := inst.Int("x", 123, "the x value") + inst.New = func() (int, error) { + return *x + c.x, nil + } + }) + + Register("test/custom-nil", func(inst *Constructor[*custom]) { + inst.New = func() (*custom, error) { + return (*custom)(nil), nil + } + }) + + Default("test/default-custom-nil", "test/custom-nil") + + Register("test/untyped-nil", func(inst *Constructor[any]) { + inst.New = func() (any, error) { + return nil, nil + } + }) + + Default("test/default-untyped-nil", "test/untyped-nil") + + Register("test/params/empty", func(inst *Constructor[paramFields]) { + var pf paramFields + inst.InstanceVar(&pf.p, "p", "test/custom-nil", "") + inst.InstanceVar(&pf.ch, "ch", "", "") + inst.InstanceVar(&pf.a, "a", "", "") + inst.New = func() (paramFields, error) { + return pf, nil + } + }) + + Register("test/params/nil", func(inst *Constructor[paramFields]) { + var pf paramFields + inst.InstanceVar(&pf.p, "p", "nil", "") + inst.InstanceVar(&pf.ch, "ch", "", "") + inst.InstanceVar(&pf.a, "a", "nil", "") + inst.New = func() (paramFields, error) { + return pf, nil + } + }) + + Register("test/params/nil-instance", func(inst *Constructor[paramFields]) { + var pf paramFields + inst.InstanceVar(&pf.p, "p", "test/custom-nil", "") + inst.New = func() 
(paramFields, error) { + return pf, nil + } + }) + + Register("test/params/empty-non-nilable-recovered", func(inst *Constructor[any]) { + var r any + func() { + defer func() { + r = recover() + }() + var pf paramFields + inst.InstanceVar(&pf.c, "c", "", "") + }() + inst.New = func() (any, error) { + return r, nil + } + }) + + Register("test/params/nil-non-nilable-recovered", func(inst *Constructor[any]) { + var r any + func() { + defer func() { + r = recover() + }() + var pf paramFields + inst.InstanceVar(&pf.c, "c", "nil", "") + }() + inst.New = func() (any, error) { + return r, nil + } + }) + + Register("test/chan", func(inst *Constructor[chan struct{}]) { + inst.New = func() (chan struct{}, error) { + return make(chan struct{}), nil + } + }) + + Register("test/params/non-nil", func(inst *Constructor[paramFields]) { + var pf paramFields + inst.InstanceVar(&pf.c, "c", "test/custom", "") + inst.InstanceVar(&pf.p, "p", "test/custom-ptr", "") + inst.InstanceVar(&pf.ch, "ch", "test/chan", "") + inst.InstanceVar(&pf.a, "a", "test/custom", "") + inst.New = func() (paramFields, error) { + return pf, nil + } + }) +} + +func TestProfileParamDefault(t *testing.T) { + p := New() + var x int + if err := p.Instance("test/1", &x); err != nil { + t.Fatal(err) + } + if got, want := x, 122; got != want { + t.Errorf("got %v, want %v", got, want) + } + + p = New() + if err := p.Set("test/custom.x", "-100"); err != nil { + t.Fatal(err) + } + if err := p.Instance("test/1", &x); err != nil { + t.Fatal(err) + } + if got, want := x, 23; got != want { + t.Errorf("got %v, want %v", got, want) + } +} + +func TestProfileDefaultInstance(t *testing.T) { + t.Run("basic", func(t *testing.T) { + p := New() + if err := p.Instance("test/default", nil); err != nil { + t.Fatal(err) + } + }) + t.Run("override-default", func(t *testing.T) { + p := New() + if err := p.Parse(strings.NewReader(` + instance custom13 test/custom ( + x = 13 + ) + instance test/default custom13 + `)); err != nil { + 
t.Fatal(err) + } + var c custom + if err := p.Instance("test/default", &c); err != nil { + t.Fatal(err) + } + if got, want := c.x, 13; got != want { + t.Errorf("got %v, want %v", got, want) + } + if err := p.Instance("test/default2", &c); err != nil { + t.Fatal(err) + } + if got, want := c.x, 13; got != want { + t.Errorf("got %v, want %v", got, want) + } + }) + t.Run("override-second-default", func(t *testing.T) { + p := New() + if err := p.Parse(strings.NewReader(` + instance custom13 test/custom ( + x = 13 + ) + instance test/default2 custom13 + `)); err != nil { + t.Fatal(err) + } + var c custom + if err := p.Instance("test/default", &c); err != nil { + t.Fatal(err) + } + if got, want := c.x, -1; got != want { + t.Errorf("got %v, want %v", got, want) + } + if err := p.Instance("test/default2", &c); err != nil { + t.Fatal(err) + } + if got, want := c.x, 13; got != want { + t.Errorf("got %v, want %v", got, want) + } + }) + t.Run("override-defaults-differently", func(t *testing.T) { + p := New() + if err := p.Parse(strings.NewReader(` + instance custom13 test/custom ( + x = 13 + ) + instance custom132 test/custom ( + x = 132 + ) + instance test/default custom13 + instance test/default2 custom132 + `)); err != nil { + t.Fatal(err) + } + var c custom + if err := p.Instance("test/default", &c); err != nil { + t.Fatal(err) + } + if got, want := c.x, 13; got != want { + t.Errorf("got %v, want %v", got, want) + } + if err := p.Instance("test/default2", &c); err != nil { + t.Fatal(err) + } + if got, want := c.x, 132; got != want { + t.Errorf("got %v, want %v", got, want) + } + }) + t.Run("override-default-instance-with-param", func(t *testing.T) { + t.Skip() // BXDS-2886 + p := New() + if err := p.Parse(strings.NewReader(` + instance test/default test/custom ( + x = 13 + ) + `)); err != nil { + t.Fatal(err) + } + var c custom + if err := p.Instance("test/default", &c); err != nil { + t.Fatal(err) + } + if got, want := c.x, 13; got != want { + t.Errorf("got %v, want %v", 
got, want) + } + if err := p.Instance("test/default2", &c); err != nil { + t.Fatal(err) + } + if got, want := c.x, 13; got != want { + t.Errorf("got %v, want %v", got, want) + } + }) + t.Run("set-default", func(t *testing.T) { + p := New() + if err := p.Parse(strings.NewReader(` + instance custom13 test/custom ( + x = 13 + ) + `)); err != nil { + t.Fatal(err) + } + if err := p.Set("test/default", "custom13"); err != nil { + t.Fatal(err) + } + var c custom + if err := p.Instance("test/default", &c); err != nil { + t.Fatal(err) + } + if got, want := c.x, 13; got != want { + t.Errorf("got %v, want %v", got, want) + } + if err := p.Instance("test/default2", &c); err != nil { + t.Fatal(err) + } + if got, want := c.x, 13; got != want { + t.Errorf("got %v, want %v", got, want) + } + }) + t.Run("set-second-default", func(t *testing.T) { + p := New() + if err := p.Parse(strings.NewReader(` + instance custom13 test/custom ( + x = 13 + ) + `)); err != nil { + t.Fatal(err) + } + if err := p.Set("test/default2", "custom13"); err != nil { + t.Fatal(err) + } + var c custom + if err := p.Instance("test/default", &c); err != nil { + t.Fatal(err) + } + if got, want := c.x, -1; got != want { + t.Errorf("got %v, want %v", got, want) + } + if err := p.Instance("test/default2", &c); err != nil { + t.Fatal(err) + } + if got, want := c.x, 13; got != want { + t.Errorf("got %v, want %v", got, want) + } + }) + t.Run("set-defaults-differently", func(t *testing.T) { + p := New() + if err := p.Parse(strings.NewReader(` + instance custom13 test/custom ( + x = 13 + ) + instance custom132 test/custom ( + x = 132 + ) + instance test/default custom13 + instance test/default2 custom132 + `)); err != nil { + t.Fatal(err) + } + if err := p.Set("test/default", "custom13"); err != nil { + t.Fatal(err) + } + if err := p.Set("test/default2", "custom132"); err != nil { + t.Fatal(err) + } + var c custom + if err := p.Instance("test/default", &c); err != nil { + t.Fatal(err) + } + if got, want := c.x, 13; got 
!= want { + t.Errorf("got %v, want %v", got, want) + } + if err := p.Instance("test/default2", &c); err != nil { + t.Fatal(err) + } + if got, want := c.x, 132; got != want { + t.Errorf("got %v, want %v", got, want) + } + }) +} + +func TestProfile(t *testing.T) { + p := New() + err := p.Parse(strings.NewReader(` +param test/custom ( + x = 999 +) + +param test/1 ( + custom = test/custom + x = 1 +) + +instance testx test/1 ( + x = 100 +) + +instance testf test/custom ( + f = 1 +) + +instance test/default testf +`)) + if err != nil { + t.Fatal(err) + } + + var x int + if err = p.Instance("test/1", &x); err != nil { + t.Fatal(err) + } + if got, want := x, 1000; got != want { + t.Errorf("got %v, want %v", got, want) + } + + if err = p.Instance("testx", &x); err != nil { + t.Fatal(err) + } + if got, want := x, 1099; got != want { + t.Errorf("got %v, want %v", got, want) + } + + var str string + err = p.Instance("testx", &str) + if err == nil || !strings.Contains(err.Error(), "instance \"testx\" of type int is not assignable to provided pointer element type string") { + t.Error(err) + } + + var c custom + if err = p.Instance("testf", &c); err != nil { + t.Fatal(err) + } + if got, want := c.f, 1.; got != want { + t.Errorf("got %v, want %v", got, want) + } + + // Verify that test/default derives from testf. + if err = p.Instance("test/default", &c); err != nil { + t.Fatal(err) + } + if got, want := c.f, 1.; got != want { + t.Errorf("got %v, want %v", got, want) + } +} + +// TestNilInstances verifies that we handle nil/empty instances appropriately. 
+func TestNilInstances(t *testing.T) { + var ( + mustSet = func(p *Profile, path, value string) { + t.Helper() + if err := p.Set(path, value); err != nil { + t.Error(err) + } + } + mustInstance = func(p *Profile, name string, pa any) { + t.Helper() + if err := p.Instance(name, pa); err != nil { + t.Fatal(err) + } + } + mustEqual = func(got, want any) { + t.Helper() + if got != want { + t.Errorf("got %v, want %v", got, want) + } + } + ) + + var ( + p *Profile + pc *custom + a any + pf paramFields + ) + + // Verify that top-level instances can be nil. + p = New() + mustInstance(p, "test/custom-nil", &pc) + mustEqual(pc, (*custom)(nil)) + mustInstance(p, "test/default-custom-nil", &pc) + mustEqual(pc, (*custom)(nil)) + mustInstance(p, "test/untyped-nil", &a) + mustEqual(a, nil) + mustInstance(p, "test/default-untyped-nil", &a) + mustEqual(a, nil) + + // Verify that empty InstanceVar defaults produce nil parameters. + p = New() + mustInstance(p, "test/params/empty", &pf) + mustEqual(pf.p, (*custom)(nil)) + mustEqual(pf.ch, (chan struct{})(nil)) + mustEqual(pf.a, nil) + + // Verify that nil InstanceVar defaults produce nil parameters. + p = New() + mustInstance(p, "test/params/nil", &pf) + mustEqual(pf.p, (*custom)(nil)) + mustEqual(pf.ch, (chan struct{})(nil)) + mustEqual(pf.a, nil) + + // Verify that an InstanceVar default instance whose value is nil produces + // a nil parameter. + p = New() + mustInstance(p, "test/params/nil-instance", &pf) + mustEqual(pf.p, (*custom)(nil)) + + // Verify that InstanceVar panics setting an empty value default for an + // element type that cannot be assigned nil. + p = New() + // Set c to a valid instance, so that the invalid instance error does not + // obscure the recovered panic value. 
+ mustSet(p, "test/params/empty-non-nilable-recovered.c", "test/custom") + mustInstance(p, "test/params/empty-non-nilable-recovered", &a) + if a == nil { + t.Error("expected non-nil-assignable empty default instance to panic") + } + + // Verify that InstanceVar panics setting a nil value default for an + // element type that cannot be assigned nil. + p = New() + // Set c to a valid instance, so that the invalid instance error does not + // obscure the recovered panic value. + mustSet(p, "test/params/nil-non-nilable-recovered.c", "test/custom") + mustInstance(p, "test/params/nil-non-nilable-recovered", &a) + if a == nil { + t.Error("expected non-nil-assignable nil default instance to panic") + } + + // Verify that a non-nil-assignable parameter set to an empty instance is + // invalid. + p = New() + mustSet(p, "test/params/non-nil.c", "") + if err := p.Instance("test/params/non-nil", &pf); err == nil { + t.Error("non-nil-assignable set to empty instance should return non-nil error") + } + + // Verify that a non-nil-assignable parameter set to nil is invalid. + p = New() + mustSet(p, "test/params/non-nil.c", "nil") + if err := p.Instance("test/params/non-nil", &pf); err == nil { + t.Error("non-nil-assignable set to nil should return non-nil error") + } + + // Verify that nil-assignable parameters can be set to empty, resulting in + // nil parameter values. + p = New() + mustSet(p, "test/params/non-nil.p", "") + mustSet(p, "test/params/non-nil.ch", "") + mustSet(p, "test/params/non-nil.a", "") + mustInstance(p, "test/params/non-nil", &pf) + mustEqual(pf.p, (*custom)(nil)) + mustEqual(pf.ch, (chan struct{})(nil)) + mustEqual(pf.a, nil) + + // Verify that nil-assignable parameters can be set to nil, resulting in + // nil parameter values. 
+ p = New() + mustSet(p, "test/params/non-nil.p", "nil") + mustSet(p, "test/params/non-nil.ch", "nil") + mustSet(p, "test/params/non-nil.a", "nil") + mustInstance(p, "test/params/non-nil", &pf) + mustEqual(pf.p, (*custom)(nil)) + mustEqual(pf.ch, (chan struct{})(nil)) + mustEqual(pf.a, nil) + + // Verify that a nil-assignable parameter can be set to an instance whose + // value is nil, resulting in a nil parameter value. + p = New() + mustSet(p, "test/params/non-nil.p", "test/custom-nil") + mustInstance(p, "test/params/non-nil", &pf) + mustEqual(pf.p, (*custom)(nil)) + + // Verify that top-level instances cannot be set to empty. + p = New() + if err := p.Set("test/custom", ""); err == nil { + t.Error("top-level instance set to empty should return non-nil error") + } + + // Verify that top-level instances cannot be set to nil. + p = New() + if err := p.Set("test/custom", ""); err == nil { + t.Error("top-level instance set to nil should return non-nil error") + } +} + +func TestSetGet(t *testing.T) { + p := New() + err := p.Parse(strings.NewReader(` +param test/custom ( + x = 999 +) + +param test/1 ( + custom = test/custom + x = 1 +) + +instance testx test/1 ( + x = 100 +) + +instance testy test/1 + +`)) + if err != nil { + t.Fatal(err) + } + + var ( + mustGet = func(k, want string) { + t.Helper() + got, ok := p.Get(k) + if !ok { + t.Fatalf("key %v not found", k) + } + if got != want { + t.Fatalf("key %v: got %v, want %v", k, got, want) + } + } + mustSet = func(k, v string) { + t.Helper() + if err := p.Set(k, v); err != nil { + t.Fatalf("set %v %v: %v", k, v, err) + } + } + ) + + mustGet("testy", "test/1") + mustGet("test/1.x", "1") + mustGet("testx.x", "100") + mustGet("testx.custom", "test/custom") + mustGet("testx.custom.x", "999") + mustGet("testy.x", "1") + + mustSet("testx.custom.x", "1900") + mustGet("testx.custom.x", "1900") + mustSet("testx.custom.x", "-1900") + mustGet("testx.custom.x", "-1900") + mustSet("testx.custom.f", "3.14") + 
mustGet("testx.custom.f", "3.14") + mustSet("testx.custom.f", "-3.14") + mustGet("testx.custom.f", "-3.14") + + mustSet("testy", "testx") + mustGet("testy.x", "100") + +} + +// TestInstanceNames verifies that InstanceNames returns the correct set of +// instance names. +func TestInstanceNames(t *testing.T) { + p := New() + names := p.InstanceNames() + // Because global instances can be added from anywhere, we only verify that + // the returned names contains the instances added by this file. + for _, name := range []string{ + "test/1", + "test/custom", + "test/default", + } { + if _, ok := names[name]; !ok { + t.Errorf("missing instance name=%v", name) + } + } +} diff --git a/config/testdata/profile b/config/testdata/profile new file mode 100644 index 00000000..4285e652 --- /dev/null +++ b/config/testdata/profile @@ -0,0 +1,14 @@ +param app/auth/login ( + user = "marius" + password = "supersecret" +) + +param app/database ( + credentials = app/auth/login +) + +param app/frontend ( + database = app/database + credentials = app/auth/login +) + diff --git a/config/testdata/profile_felogin_password b/config/testdata/profile_felogin_password new file mode 100644 index 00000000..78f435d5 --- /dev/null +++ b/config/testdata/profile_felogin_password @@ -0,0 +1 @@ +param test/felogin password = "abc" diff --git a/crypto/encryption/passwd/doc.go b/crypto/encryption/passwd/doc.go deleted file mode 100644 index cb6a0548..00000000 --- a/crypto/encryption/passwd/doc.go +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright 2017 GRAIL, Inc. All rights reserved. -// Use of this source code is governed by the Apache-2.0 -// license that can be found in the LICENSE file. - -// Package passwd provides an interactive password based key registry for use -// with github.com/grailbio/crypto/encryption. -// -// It uses bcrypt to generate an encryption key and sha256 of that key to -// generate the password hash. Thus bcrypt is required for any brute force -// attacks on both the key and the hash. 
-// -// The bcrypt hash (23 bytes) is used as the encryption key, or more precisely, -// the lower 16 bytes are used as an AES128 key. The sha256 of the full 23 -// bytes is used the password 'hash' and is stored in the encrypted file -// along with metadata such as the salt, bcrypt cost etc. -// -// Thus, the full set of operations in outline is: -// -// encryption/write side: -// -// key, metadata = bcrypt(password) -// hash = sha256(key) -// -// file = hash, aes128(key, plaintext) -// -// decryption/read side: -// read hash and metadata from the file -// key = bcrypt(password, metadata) -// newhash = sha256(key) -// if newhash == hash { -// key is valid -// } -// -// The ReadAndHashPassword method implements the 'write' side and -// ReadAndComparePassword the 'read' side. This routines take care to read the -// password safely from a terminal and to keep the password in memory for as -// short a time as possible. The encryption key however is kept in memory. -package passwd diff --git a/crypto/encryption/passwd/passwd.go b/crypto/encryption/passwd/passwd.go deleted file mode 100644 index a99d8a24..00000000 --- a/crypto/encryption/passwd/passwd.go +++ /dev/null @@ -1,97 +0,0 @@ -// Copyright 2017 GRAIL, Inc. All rights reserved. -// Use of this source code is governed by the Apache-2.0 -// license that can be found in the LICENSE file. - -package passwd - -import ( - "crypto/sha256" - "crypto/subtle" - "fmt" - "syscall" - - "golang.org/x/crypto/bcrypt" - "golang.org/x/crypto/ssh/terminal" -) - -// ReadAndHashPassword reads a password from stdin, taking care to not echo -// it and then immediately hashes the password using an expensive (bcrypt) -// hash function. This hash may be safely stored to disk. It also returns a -// second, stable, hash of the password that may be used for symmetric -// encryption of data. It returns the cost and salt used to generate the hash -// since they need to be available when verifying a password against the hash. 
-func ReadAndHashPassword() (hash []byte, key AESKey, salt []byte, cost, major, minor int, err error) { - fmt.Print("Enter Password: ") - password, err := terminal.ReadPassword(int(syscall.Stdin)) - fmt.Println("") - defer func() { - // Zero out the password as soon as it's hashed. - for i := range password { - password[i] = 0 - } - - }() - - if err != nil { - return - } - return hashPassword(password) -} - -func hashPassword(password []byte) (hash []byte, key AESKey, salt []byte, cost, major, minor int, err error) { - salt, err = bcrypt.GenerateSalt() - if err != nil { - return - } - - cost = bcrypt.DefaultCost - bkey, err := bcrypt.Bcrypt(password, salt, cost) - if err != nil { - return - } - - if n := copy(key[:], bkey); n < cap(key) { - err = fmt.Errorf("key too short") - return - } - sum := sha256.Sum256(key[:]) - hash = sum[:] - major = bcrypt.MajorVersion - minor = bcrypt.MinorVersion - return -} - -// ReadAndComparePassword reads a password from stdin, taking care to not echo -// it and then compares that password with the supplied hash. -func ReadAndComparePassword(hash, salt []byte, cost, major, minor int) (AESKey, error) { - fmt.Print("Enter Password: ") - password, err := terminal.ReadPassword(int(syscall.Stdin)) - fmt.Println("") - defer func() { - // Zero out the password as soon as it's hashed. 
- for i := range password { - password[i] = 0 - } - - }() - if err != nil { - return zeroKey, err - } - return comparePassword(password, hash, salt, cost, major, minor) -} - -func comparePassword(password, hash, salt []byte, cost, major, minor int) (AESKey, error) { - bkey, err := bcrypt.Bcrypt(password, salt, cost) - if err != nil { - return zeroKey, err - } - key := zeroKey - if n := copy(key[:], bkey); n < cap(key) { - return zeroKey, fmt.Errorf("key too short") - } - h := sha256.Sum256(key[:]) - if subtle.ConstantTimeCompare(h[:], hash) == 0 { - return zeroKey, fmt.Errorf("mismatched passwords") - } - return key, err -} diff --git a/crypto/encryption/passwd/passwd_test.go b/crypto/encryption/passwd/passwd_test.go deleted file mode 100644 index f255c83d..00000000 --- a/crypto/encryption/passwd/passwd_test.go +++ /dev/null @@ -1,122 +0,0 @@ -// Copyright 2017 GRAIL, Inc. All rights reserved. -// Use of this source code is governed by the Apache-2.0 -// license that can be found in the LICENSE file. 
- -package passwd - -import ( - "bytes" - "crypto/hmac" - "crypto/sha512" - "strings" - "testing" - - "github.com/grailbio/base/crypto/encryption" - "github.com/grailbio/testutil" - "github.com/grailbio/testutil/expect" - "golang.org/x/crypto/bcrypt" -) - -func TestHash(t *testing.T) { - pw := []byte("any old pw") - - hash, key, salt, cost, major, minor, err := hashPassword(pw) - if err != nil { - t.Fatal(err) - } - - cmp := func(got, want int) { - if got != want { - t.Errorf("%v: got %d, want %d", testutil.Caller(1), got, want) - } - } - cmp(cost, bcrypt.DefaultCost) - cmp(major, bcrypt.MajorVersion) - cmp(minor, bcrypt.MinorVersion) - cmp(len(hash), 32) // sha256 - cmp(len(key), 16) // aes128 - - rkey, err := comparePassword(pw, hash, salt, cost, major, minor) - if err != nil { - t.Fatal(err) - } - - if !bytes.Equal(key[:], rkey[:]) { - t.Fatalf("mismatched passwords") - } - - rkey, err = comparePassword([]byte("oops"), hash, salt, cost, major, minor) - if err == nil || !strings.Contains(err.Error(), "mismatched") { - t.Fatalf("failed to detect mismatched passwords") - } - - if bytes.Equal(key[:], rkey[:]) { - t.Fatalf("incorrectly matched passwords and yet the keys are the same") - } -} - -func TestRegistry(t *testing.T) { - reg := NewKeyRegistry() - encryption.Register("pw", reg) - - if got, want := reg.BlockSize(), 16; got != want { - t.Errorf("got %v, want %v", got, want) - } - if got, want := reg.HMACSize(), 64; got != want { - t.Errorf("got %v, want %v", got, want) - } - pw := []byte("any old pw") - - hash, key, salt, cost, major, minor, err := hashPassword(pw) - if err != nil { - t.Fatal(err) - } - id, err := reg.generateKey(hash, key, salt, cost, major, minor) - if err != nil { - t.Fatal(err) - } - - kd := encryption.KeyDescriptor{Registry: "pw", ID: id} - enc, err := encryption.NewEncrypter(kd) - if err != nil { - t.Fatal(err) - } - pt := []byte("my message") - ct := make([]byte, enc.CiphertextSize(pt)) - enc.Encrypt(pt, ct) - - if bytes.Contains(ct, 
[]byte(pt)) { - t.Fatalf("encryptiong failed: %s", ct) - } - - dec, _ := encryption.NewDecrypter(kd) - buf := make([]byte, dec.PlaintextSize(ct)) - sum, pt1, _ := dec.Decrypt(ct, buf) - - if !bytes.Equal(pt, pt1) { - t.Fatalf("encryption/decryption failed: %v != %v", pt, pt1) - } - - h := hmac.New(sha512.New, key[:]) - h.Write(pt) - if got, want := h.Sum(nil), sum; !bytes.Equal(got[:], want) { - t.Errorf("got %v, want %v", got, want) - } - - // Test errors: - decryptError := func(id string) error { - kd := encryption.KeyDescriptor{Registry: "pw", ID: []byte(id)} - dec, _ := encryption.NewDecrypter(kd) - _, _, err := dec.Decrypt(ct, buf) - return err - } - err = decryptError("{xx") - expect.HasSubstr(t, err, "failed to unmarshal ID") - - err = decryptError(`{"hash":"\t"}`) - expect.HasSubstr(t, err, "failed to decode hash") - - err = decryptError(`{"salt":"\t"}`) - expect.HasSubstr(t, err, "failed to decode salt") - -} diff --git a/crypto/encryption/passwd/registry.go b/crypto/encryption/passwd/registry.go deleted file mode 100644 index 8a6a008f..00000000 --- a/crypto/encryption/passwd/registry.go +++ /dev/null @@ -1,131 +0,0 @@ -// Copyright 2017 GRAIL, Inc. All rights reserved. -// Use of this source code is governed by the Apache-2.0 -// license that can be found in the LICENSE file. - -package passwd - -import ( - "bytes" - "crypto/aes" - "crypto/cipher" - "crypto/hmac" - "crypto/sha512" - "encoding/hex" - "encoding/json" - "fmt" - "hash" - "sync" - - "github.com/grailbio/base/crypto/encryption" -) - -// AESKey represents a 16 byte AES key for AES 128. -type AESKey [16]byte - -// AES represents a passwd based key registry that uses AES encryption. -type AES struct { - mu sync.Mutex - keys map[string]AESKey - passwd []byte -} - -// NewKeyRegistry creates a new key registry. 
-func NewKeyRegistry() *AES { - return &AES{ - keys: map[string]AESKey{}, - } -} - -func init() { - encryption.Register("passwd-aes", NewKeyRegistry()) -} - -// SetIDAndKey stores the specified ID and Key in the key registry. -func (c *AES) SetIDAndKey(ID []byte, key AESKey) { - c.mu.Lock() - c.keys[string(ID)] = key - c.mu.Unlock() -} - -// The ID field in the key registry is used to encode the metadata -// and the hash of the password. -type idPayload struct { - Hash string `json:"hash"` - Salt string `json:"salt"` - Cost int `json:"cost"` - Major int `json:"major"` - Minor int `json:"minor"` -} - -// GenerateKey implements encryption.KeyRegistry. -func (c *AES) GenerateKey() (ID []byte, err error) { - hash, key, salt, cost, major, minor, err := ReadAndHashPassword() - if err != nil { - return nil, err - } - return c.generateKey(hash, key, salt, cost, major, minor) -} - -func (c *AES) generateKey(hash []byte, key AESKey, salt []byte, cost, major, minor int) (ID []byte, err error) { - payload := idPayload{ - hex.EncodeToString(hash), - hex.EncodeToString(salt), - cost, - major, - minor, - } - id, err := json.Marshal(&payload) - if err != nil { - return nil, fmt.Errorf("failed to marshal metadata: %v", err) - } - c.SetIDAndKey(id, key) - return id, nil -} - -// BlockSize implements encryption.KeyRegistry. -func (c *AES) BlockSize() int { - return aes.BlockSize -} - -// HMACSize implements encryption.KeyRegistry. -func (c *AES) HMACSize() int { - return sha512.Size -} - -var zeroKey = AESKey{} - -// NewBlock implements encryption.KeyRegistry. 
-func (c *AES) NewBlock(ID []byte, opts ...interface{}) (hmc hash.Hash, block cipher.Block, err error) { - id := idPayload{} - if err := json.Unmarshal(ID, &id); err != nil { - return nil, nil, fmt.Errorf("failed to unmarshal ID: %v", err) - } - c.mu.Lock() - key := c.keys[string(ID)] - c.mu.Unlock() - hash, err := hex.DecodeString(id.Hash) - if err != nil { - return nil, nil, fmt.Errorf("failed to decode hash: %v", err) - } - salt, err := hex.DecodeString(id.Salt) - if err != nil { - return nil, nil, fmt.Errorf("failed to decode salt: %v", err) - } - if bytes.Equal(key[:], zeroKey[:]) { - key, err = ReadAndComparePassword(hash, salt, id.Cost, id.Major, id.Minor) - if err != nil { - return nil, nil, err - } - } - c.mu.Lock() - c.keys[string(ID)] = key - c.mu.Unlock() - hmc = hmac.New(sha512.New, key[:]) - blk, err := aes.NewCipher(key[:]) - return hmc, blk, err -} - -// NewGCM implements encryption.KeyRegistry. -func (c *AES) NewGCM(block cipher.Block, opts ...interface{}) (aead cipher.AEAD, err error) { - return nil, fmt.Errorf("not implemented yet") -} diff --git a/crypto/encryption/passwd/testfile.go b/crypto/encryption/passwd/testfile.go deleted file mode 100644 index 40938679..00000000 --- a/crypto/encryption/passwd/testfile.go +++ /dev/null @@ -1,62 +0,0 @@ -// Copyright 2017 GRAIL, Inc. All rights reserved. -// Use of this source code is governed by the Apache-2.0 -// license that can be found in the LICENSE file. - -// A simple utility to test writing and reading encrypted files. -// As such, it is intended to be run interactively and cannot be run as -// part of automated tests. It may run by executing: -// -// $ go run testfile.go -// -// +build ignore - -package main - -import ( - "os" - "os/exec" - "path/filepath" - "syscall" - - _ "github.com/grailbio/crypto/encryption/passwd" - "v.io/x/lib/gosh" -) - -const msg = ` -This is a interactive manual test for writing/reading encrypted files. -Run it and make sure that: -1. 
it prints 'hello\nsafe and secure\nworld\n' when you supply matching passwords. -2. it prints an error when the passwords don't match. - -In all cases, verify that the password is not echoed to the terminal. -` - -func main() { - sh := gosh.NewShell(nil) - defer sh.Cleanup() - tmpdir := sh.MakeTempDir() - filename := filepath.Join(tmpdir, "test.grail-rpk-kd") - - wr := filepath.Join(tmpdir, "write-grail-rpk-kd") - rd := filepath.Join(tmpdir, "read-grail-rpk-kd") - - gosh.BuildGoPkg(sh, tmpdir, "./testwrite.go", "-o", wr) - gosh.BuildGoPkg(sh, tmpdir, "./testread.go", "-o", rd) - - fd, err := syscall.Open("/dev/tty", syscall.O_RDWR, 0) - if err != nil { - panic(err) - } - - for _, name := range []string{wr, rd} { - tty := os.NewFile(uintptr(fd), "/dev/tty") - cmd := exec.Command(name, filename) - cmd.Stdin, cmd.Stdout, cmd.Stderr = tty, tty, tty - if err := cmd.Start(); err != nil { - panic(err) - } - if err := cmd.Wait(); err != nil { - panic(err) - } - } -} diff --git a/crypto/encryption/passwd/testread.go b/crypto/encryption/passwd/testread.go deleted file mode 100644 index 571ec6b9..00000000 --- a/crypto/encryption/passwd/testread.go +++ /dev/null @@ -1,50 +0,0 @@ -// Copyright 2017 GRAIL, Inc. All rights reserved. -// Use of this source code is governed by the Apache-2.0 -// license that can be found in the LICENSE file. - -// A simple utility to test reading an encrypted file. -// As such, it is intended to be run interactively and cannot be run as -// part of automated tests. It may run by executing: -// -// $ go run testread.go -// -// It prints 'hello\nsafe and secure\nworld\n' if you enter the same password -// and file used for testwrite.go. -// -// +build ignore - -package main - -import ( - "fmt" - "os" - - _ "github.com/grailbio/base/crypto/encryption/passwd" - "github.com/grailbio/base/recordio/recordioutil" -) - -const msg = ` -This is a interactive manual test for reading an encrypted file. -Run it and make sure that: -1. 
it prints 'hello\nsafe and secure\nworld\n' when you supply matching passwords. - -Also, verify that the password is not echoed to the terminal. -` - -func main() { - file, err := os.Open(os.Args[1]) - if err != nil { - panic(err) - } - opts := recordioutil.ScannerOptsFromName(os.Args[1]) - in, err := recordioutil.NewScanner(file, opts) - if err != nil { - panic(err) - } - for in.Scan() { - fmt.Printf("%s", string(in.Bytes())) - } - if err := in.Err(); err != nil { - panic(err) - } -} diff --git a/crypto/encryption/passwd/testwrite.go b/crypto/encryption/passwd/testwrite.go deleted file mode 100644 index 39968c65..00000000 --- a/crypto/encryption/passwd/testwrite.go +++ /dev/null @@ -1,57 +0,0 @@ -// Copyright 2017 GRAIL, Inc. All rights reserved. -// Use of this source code is governed by the Apache-2.0 -// license that can be found in the LICENSE file. - -// A simple utility to test writing an encrypted files. -// As such, it is intended to be run interactively and cannot be run as -// part of automated tests. It may run by executing: -// -// $ go run testwrite.go -// -// It will 'hello\nsafe and secure\nworld\n' to the encrypted file. Use -// $ go run testread.go -// to decrypt that file. -// -// +build ignore - -package main - -import ( - "os" - - "github.com/grailbio/base/crypto/encryption" - _ "github.com/grailbio/base/crypto/encryption/passwd" - "github.com/grailbio/base/recordio/recordioutil" -) - -const msg = ` -This is a interactive manual test for writing an encrypted file. Supply a -filename and password to encrypt that file with. - -Also, verify that the password is not echoed to the terminal. 
-` - -func main() { - reg, err := encryption.Lookup("passwd-aes") - if err != nil { - panic(err) - } - id, err := reg.GenerateKey() - kd := encryption.KeyDescriptor{Registry: "passwd-aes", - ID: id, - } - opts := recordioutil.WriterOpts{KeyDescriptor: &kd} - file, err := os.OpenFile(os.Args[1], os.O_CREATE|os.O_TRUNC|os.O_RDWR, 0666) - if err != nil { - panic(err) - } - out, err := recordioutil.NewWriter(file, opts) - if err != nil { - panic(err) - } - out.Write([]byte("hello\n")) - out.Write([]byte("safe and secure\n")) - out.Write([]byte("world\n")) - out.Flush() - file.Close() -} diff --git a/diagnostic/dump/default.go b/diagnostic/dump/default.go new file mode 100644 index 00000000..01dd19a8 --- /dev/null +++ b/diagnostic/dump/default.go @@ -0,0 +1,227 @@ +// Copyright 2019 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. + +package dump + +import ( + "archive/zip" + "context" + "encoding/json" + "expvar" + "fmt" + "io" + "os" + "path/filepath" + "runtime" + "runtime/pprof" + "strings" + "time" + + "github.com/grailbio/base/log" + "github.com/shirou/gopsutil/cpu" + "github.com/shirou/gopsutil/load" + "github.com/shirou/gopsutil/mem" +) + +// DefaultRegistry is a default registry that has this process's GUID as its ID. +var DefaultRegistry = NewRegistry(readExec()) + +// Register registers a new part to be included in the dump of the +// DefaultRegistry. name will become the filename of the part file in the dump +// tarball. f will be called to produce the contents of that file. +func Register(name string, f Func) { + DefaultRegistry.Register(name, f) +} + +// WriteDump writes a dump of the default registry. +func WriteDump(ctx context.Context, pfx string, zw *zip.Writer) { + DefaultRegistry.WriteDump(ctx, pfx, zw) +} + +// Name returns the name of the default registry. See (*Registry).Name. 
+func Name() string { + return DefaultRegistry.Name() +} + +// readExec returns a sanitized version of the executable name, if it can be +// determined. If not, returns "unknown". +func readExec() string { + const unknown = "unknown" + execPath, err := os.Executable() + if err != nil { + return unknown + } + rawExec := filepath.Base(execPath) + var sanitized strings.Builder + for _, r := range rawExec { + if (r == '-' || 'a' <= r && r <= 'z') || ('0' <= r && r <= '9') { + sanitized.WriteRune(r) + } + } + if sanitized.Len() == 0 { + return unknown + } + return sanitized.String() +} + +// shellQuote quotes a string to be used as an argument in an sh command line. +func shellQuote(s string) string { + // We wrap with single quotes, as they will work with any string except + // those with single quotes. We handle single quotes by tranforming them + // into "'\''" and letting the shell concatenate the strings back together. + return "'" + strings.Replace(s, "'", `'\''`, -1) + "'" +} + +// dumpCmdline writes the command-line of the current execution. It writes it +// in a format that can be directly pasted into sh to be run. 
+func dumpCmdline(ctx context.Context, w io.Writer) error { + args := make([]string, len(os.Args)) + for i := range args { + args[i] = shellQuote(os.Args[i]) + } + _, err := io.WriteString(w, strings.Join(args, " ")) + return err +} + +func dumpCpuinfo(ctx context.Context, w io.Writer) error { + info, err := cpu.InfoWithContext(ctx) + if err != nil { + return fmt.Errorf("error getting cpuinfo: %v", err) + } + s, err := json.MarshalIndent(info, "", " ") + if err != nil { + return fmt.Errorf("error marshaling cpuinfo: %v", err) + } + _, err = w.Write(s) + return err +} + +func dumpLoadinfo(ctx context.Context, w io.Writer) error { + type loadinfo struct { + Avg *load.AvgStat `json:"average"` + Misc *load.MiscStat `json:"miscellaneous"` + } + var info loadinfo + avg, err := load.AvgWithContext(ctx) + if err != nil { + return fmt.Errorf("error getting load averages: %v", err) + } + info.Avg = avg + misc, err := load.MiscWithContext(ctx) + if err != nil { + return fmt.Errorf("error getting miscellaneous load stats: %v", err) + } + info.Misc = misc + s, err := json.MarshalIndent(info, "", " ") + if err != nil { + return fmt.Errorf("error marshaling loadinfo: %v", err) + } + _, err = w.Write(s) + return err +} + +func dumpMeminfo(ctx context.Context, w io.Writer) error { + type meminfo struct { + Virtual *mem.VirtualMemoryStat `json:"virtualMemory"` + Runtime runtime.MemStats `json:"goRuntime"` + } + var info meminfo + vmem, err := mem.VirtualMemoryWithContext(ctx) + if err != nil { + return fmt.Errorf("error getting virtual memory stats: %v", err) + } + info.Virtual = vmem + runtime.ReadMemStats(&info.Runtime) + s, err := json.MarshalIndent(info, "", " ") + if err != nil { + return fmt.Errorf("error marshaling meminfo: %v", err) + } + _, err = w.Write(s) + if err != nil { + return fmt.Errorf("error writing memory stats: %v", err) + } + return nil +} + +// dumpGoroutine writes current goroutines with human-readable source +// locations. 
+func dumpGoroutine(ctx context.Context, w io.Writer) error { + p := pprof.Lookup("goroutine") + if p == nil { + panic("no goroutine profile") + } + // debug == 2 prints goroutine stacks in the same form as that printed for + // an unrecovered panic. + return p.WriteTo(w, 2) +} + +// dumpPprofHeap writes a pprof heap profile. +func dumpPprofHeap(ctx context.Context, w io.Writer) error { + p := pprof.Lookup("heap") + if p == nil { + panic("no heap profile") + } + return p.WriteTo(w, 0) +} + +// dumpPprofMutex writes a fraction of the stack traces of goroutines with +// contended mutexes. +func dumpPprofMutex(ctx context.Context, w io.Writer) error { + p := pprof.Lookup("mutex") + if p == nil { + panic("no mutex profile") + } + // debug == 1 makes use function names instead of hexadecimal addresses, so + // it can also be human-readable. + return p.WriteTo(w, 1) +} + +// dumpPprofHeap writes a pprof CPU profile sampled for 30 seconds or until the +// context is done, whichever is shorter. +func dumpPprofProfile(ctx context.Context, w io.Writer) error { + if err := pprof.StartCPUProfile(w); err != nil { + return err + } + startTime := time.Now() + defer pprof.StopCPUProfile() + select { + case <-time.After(30 * time.Second): + case <-ctx.Done(): + d := time.Since(startTime) + log.Debug.Printf("dump: CPU profile cut short to %s", d.String()) + } + return nil +} + +// dumpVars writes public variables exported by the expvar package. The output +// is equivalent to the output of the "/debug/vars" endpoint. 
+func dumpVars(ctx context.Context, w io.Writer) error { + if _, err := fmt.Fprintf(w, "{\n"); err != nil { + return err + } + var ( + err error + first = true + ) + expvar.Do(func(kv expvar.KeyValue) { + if !first { + if _, err = fmt.Fprintf(w, ",\n"); err != nil { + return + } + } + first = false + if _, err = fmt.Fprintf(w, "%q: %s", kv.Key, kv.Value); err != nil { + return + } + }) + if err != nil { + return err + } + if _, err := fmt.Fprintf(w, "\n}\n"); err != nil { + return err + } + return nil +} + +// Func is the type of a function that is registered in (*Registry).Register to diff --git a/diagnostic/dump/dump.go b/diagnostic/dump/dump.go new file mode 100644 index 00000000..61fe4e0e --- /dev/null +++ b/diagnostic/dump/dump.go @@ -0,0 +1,260 @@ +// Copyright 2019 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. + +// Package dump provides the endpoint "debug/dump", registered with +// http.DefaultServeMux, which returns a dump of useful diagnostic information +// as a tarball. The base configuration includes several useful diagnostics +// (see init). You may also register your own dump parts to be included, e.g.: +// +// Register("mystuff", func(ctx context.Context, w io.Writer) error { +// w.Write([]byte("mystuff diagnostic data")) +// return nil +// }) +// +// The endpoint responds with a gzipped tarball. The Content-Disposition of the +// response suggests a pseudo-unique filename to make it easier to deal with +// multiple dumps. Use curl flags to accept the suggested filename +// (recommended). +// +// curl -OJ http://example:1234/debug/dump +// +// Note that it will take at least 30 seconds to respond, as some of the parts +// of the base configuration are 30-second profiles. 
+package dump
+
+import (
+	"archive/zip"
+	"context"
+	"errors"
+	"fmt"
+	"io"
+	"io/ioutil"
+	"net/http"
+	"os"
+	"strings"
+	"sync"
+	"time"
+
+	"github.com/grailbio/base/log"
+	"github.com/grailbio/base/traverse"
+)
+
+// init registers commonly useful parts in the registry and configures
+// http.DefaultServeMux with the endpoint "/debug/dump.zip" for getting the
+// dump.
+func init() {
+	Register("cmdline", dumpCmdline)
+	Register("cpuinfo", dumpCpuinfo)
+	Register("loadinfo", dumpLoadinfo)
+	Register("meminfo", dumpMeminfo)
+	Register("pprof-goroutine", dumpGoroutine)
+	Register("pprof-heap", dumpPprofHeap)
+	Register("pprof-mutex", dumpPprofMutex)
+	Register("pprof-profile", dumpPprofProfile)
+	Register("vars", dumpVars)
+	http.Handle("/debug/dump.zip", DefaultRegistry)
+}
+
+// ErrSkipPart signals that we should skip a part. Return this from your
+// Func to silently ignore the part for the current dump. If your Func
+// returns anything else non-nil, it will be logged as an error. This is
+// mostly useful for keeping logs quiet for parts that are sometimes
+// unavailable for non-error reasons.
+var ErrSkipPart = errors.New("skip part")
+
+// part is one part of a dump. It is ultimately expressed as a single file that
+// is part of the zip archive dump.
+type part struct {
+	// name is the name of this part of the dump. It is used as the filename in
+	// the dump archive.
+	name string
+	// f is called to produce the contents of this part of the dump.
+	f Func
+}
+
+// Func is the function to be called when producing a dump for a part.
+type Func func(ctx context.Context, w io.Writer) error
+
+// Registry maintains the set of parts that will compose the dump.
+type Registry struct {
+	// mu guards id and parts against concurrent registration/snapshotting.
+	mu sync.Mutex
+	// id is the identifier of this registry, which eventually becomes part of
+	// the suggested filename for the dump.
+	id    string
+	parts []part
+
+	// createTime is the time at which this Registry was created with
+	// NewRegistry.
+	createTime time.Time
+}
+
+// NewRegistry returns a new registry for the parts to be included in the dump.
+func NewRegistry(id string) *Registry {
+	return &Registry{id: id, createTime: time.Now()}
+}
+
+// Name returns a name for reg that is convenient for naming dump files, as it
+// is pseudo-unique and includes the registry ID, the time at which the registry
+// was created, and the duration from that creation time.
+func (reg *Registry) Name() string {
+	sinceCreate := time.Since(reg.createTime)
+	ss := []string{reg.id, reg.createTime.Format(createTimeFormat), formatDuration(sinceCreate)}
+	return strings.Join(ss, ".")
+}
+
+// Register registers a new part to be included in the dump of reg. Name will
+// become the filename of the part file in the dump archive. Func f will be
+// called to produce the contents of that file.
+func (reg *Registry) Register(name string, f Func) {
+	reg.mu.Lock()
+	defer reg.mu.Unlock()
+	for _, part := range reg.parts {
+		if part.name == name {
+			panic(fmt.Sprintf("duplicate part name %q", name))
+		}
+	}
+	reg.parts = append(reg.parts, part{name: name, f: f})
+}
+
+// partFile is used by worker goroutines to communicate results back to the main
+// dumping thread. Only one of err and file will be non-nil.
+type partFile struct {
+	// part is the part to which this partFile applies.
+	part part
+	// err will be non-nil if there was an error producing the file of the part
+	// of the dump.
+	err error
+	// file will be non-nil in a successful result and will be the file that
+	// will be included in the dump archive.
+	file *os.File
+}
+
+// processPart is called by worker goroutines to process a single part.
+// processPart runs part's Func, capturing its output in an anonymous temp
+// file, and returns the result to be written into the dump archive.
+func processPart(ctx context.Context, part part) partFile {
+	tmpfile, err := ioutil.TempFile("", "dump")
+	if err != nil {
+		return partFile{
+			part: part,
+			err:  fmt.Errorf("error creating temp file: %v", err),
+		}
+	}
+	// Unlink the file immediately; the open descriptor keeps it usable, and
+	// the storage is reclaimed automatically even if we crash before closing.
+	// (NOTE(review): assumes POSIX unlink-while-open semantics; failure is
+	// only logged, best-effort.)
+	if err := os.Remove(tmpfile.Name()); err != nil {
+		log.Printf("dump: error removing temp file %s: %v", tmpfile.Name(), err)
+	}
+	if err := part.f(ctx, tmpfile); err != nil {
+		_ = tmpfile.Close()
+		if err == ErrSkipPart {
+			return partFile{part: part, err: err}
+		}
+		return partFile{
+			part: part,
+			err:  fmt.Errorf("error writing part contents: %v", err),
+		}
+	}
+	// Rewind so the archive writer reads the part from the beginning.
+	if _, err := tmpfile.Seek(0, io.SeekStart); err != nil {
+		_ = tmpfile.Close()
+		return partFile{
+			part: part,
+			err:  fmt.Errorf("error seeking to read temp file for dump: %v", err),
+		}
+	}
+	// The returned file will be closed downstream after its contents have been
+	// written to the dump.
+	return partFile{part: part, file: tmpfile}
+}
+
+// writeFile writes a file to zw with filename name.
+func writeFile(name string, f *os.File, zw *zip.Writer) error {
+	fi, err := f.Stat()
+	if err != nil {
+		return fmt.Errorf("error getting file stat of %q: %v", f.Name(), err)
+	}
+	hdr, err := zip.FileInfoHeader(fi)
+	if err != nil {
+		return fmt.Errorf("error building zip header of %q: %v", f.Name(), err)
+	}
+	hdr.Name = name
+	zfw, err := zw.CreateHeader(hdr)
+	if err != nil {
+		return fmt.Errorf("error writing zip header in diagnostic dump: %v", err)
+	}
+	if _, err = io.Copy(zfw, f); err != nil {
+		return fmt.Errorf("error writing diagnostic dump: %v", err)
+	}
+	return nil
+}
+
+// writePart writes a single part to zw. pfx is the path that will be prepended
+// to the part name to construct the full path of the entry in the archive.
+func writePart(pfx string, p partFile, zw *zip.Writer) (err error) { + if p.err != nil { + if p.err == ErrSkipPart { + return nil + } + return fmt.Errorf("error dumping %s: %v", p.part.name, p.err) + } + defer func() { + closeErr := p.file.Close() + if err == nil && closeErr != nil { + err = fmt.Errorf("error closing temp file %q: %v", p.file.Name(), closeErr) + } + }() + if fileErr := writeFile(pfx+"/"+p.part.name, p.file, zw); fileErr != nil { + return fmt.Errorf("error writing %s to archive: %v", p.part.name, fileErr) + } + return nil +} + +// WriteDump writes the dump to w. pfx is prepended to the names of the parts of +// the dump, e.g. if pfx == "dump-123" and part name == "cpu", "dump-123/cpu" +// will be written into the archive. It returns no error, as it is best-effort. +func (reg *Registry) WriteDump(ctx context.Context, pfx string, zw *zip.Writer) { + reg.mu.Lock() + // Snapshot reg.parts to release the lock quickly. + parts := reg.parts + reg.mu.Unlock() + const concurrency = 8 + partFileC := make(chan partFile, concurrency) + go func() { + defer close(partFileC) + err := traverse.Parallel.Each(len(parts), func(i int) error { + partCtx, partCtxCancel := context.WithTimeout(ctx, 2*time.Minute) + partFile := processPart(partCtx, parts[i]) + partCtxCancel() + partFileC <- partFile + return nil + }) + if err != nil { + log.Error.Printf("dump: error processing parts: %v", err) + return + } + }() + for p := range partFileC { + if err := writePart(pfx, p, zw); err != nil { + log.Error.Printf("dump: error processing part %s: %v", p.part.name, err) + } + } +} + +var createTimeFormat = "2006-01-02-1504" + +func formatDuration(d time.Duration) string { + d = d.Round(time.Second) + h := d / time.Hour + d -= h * time.Hour + m := d / time.Minute + d -= m * time.Minute + s := d / time.Second + return fmt.Sprintf("%02dh%02dm%02ds", h, m, s) +} + +// ServeHTTP serves the dump with a Content-Disposition set with a unique filename. 
+func (reg *Registry) ServeHTTP(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/zip") + pfx := Name() + filename := pfx + ".zip" + w.Header().Set("Content-Disposition", "attachment; filename="+filename) + zw := zip.NewWriter(w) + defer zw.Close() // nolint: errcheck + reg.WriteDump(r.Context(), pfx, zw) +} diff --git a/diagnostic/dump/dump_test.go b/diagnostic/dump/dump_test.go new file mode 100644 index 00000000..837abeaf --- /dev/null +++ b/diagnostic/dump/dump_test.go @@ -0,0 +1,156 @@ +// Copyright 2019 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. + +package dump + +import ( + "archive/zip" + "bytes" + "context" + "errors" + "io" + "io/ioutil" + "net/http" + "net/http/httptest" + "reflect" + "regexp" + "sort" + "sync" + "testing" +) + +func makeDumpConst(errC chan<- error, s string) Func { + return func(ctx context.Context, w io.Writer) error { + if _, err := w.Write([]byte(s)); err != nil { + // This should not happen, so we let the main test goroutine know. + errC <- err + } + return nil + } +} + +func makeDumpError(errC chan<- error, s string) Func { + return func(ctx context.Context, w io.Writer) error { + // Fake a partial failed write. + s := s[:len(s)/2] + if _, err := w.Write([]byte(s)); err != nil { + // This should not happen, so we let the main test goroutine know. 
+ errC <- err + } + return errors.New("dump func error") + } +} + +func dumpSkipPart(_ context.Context, _ io.Writer) error { + return ErrSkipPart +} +func TestShellQuote(t *testing.T) { + for _, c := range []struct { + s string + want string + }{ + {``, `''`}, + {`'`, `''\'''`}, + {`hello`, `'hello'`}, + {`hello world`, `'hello world'`}, + {`hello'world`, `'hello'\''world'`}, + } { + if got, want := shellQuote(c.s), c.want; got != want { + t.Errorf("got %q, want %q", got, want) + } + } +} + +func verifyDump(t *testing.T, server *httptest.Server, dumpFuncErrC chan error, wantNames []string) { + var dumpFuncErr error + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + dumpFuncErr = <-dumpFuncErrC + }() + + resp, err := http.Get(server.URL + "/dump.zip") + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if got, want := resp.StatusCode, http.StatusOK; got != want { + t.Fatalf("got %v, want %v", got, want) + } + // Read the whole body, so we can immediately make sure that our dump + // funcs worked. + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatalf("could not read dump body: %v", err) + } + close(dumpFuncErrC) + wg.Wait() + if dumpFuncErr != nil { + t.Fatalf("unexpected error writing dump part: %v", dumpFuncErr) + } + zr, err := zip.NewReader(bytes.NewReader(body), int64(len(body))) + if err != nil { + t.Fatal(err) + } + re := regexp.MustCompile(`.*/`) + var names []string + for _, entry := range zr.File { + // Strip the prefix to recover the original name. + name := re.ReplaceAllString(entry.Name, "") + names = append(names, name) + var contents bytes.Buffer + rc, err := entry.Open() + if err != nil { + t.Fatal(err) + } + if _, err := io.Copy(&contents, rc); err != nil { + t.Fatal(err) + } + if err := rc.Close(); err != nil { + t.Fatal(err) + } + // Assume contents are "-contents", matching our known + // construction of the dump contents. 
+ if got, want := contents.String(), name+"-contents"; got != want { + t.Errorf("got %v, want %v", got, want) + } + } + sort.Strings(names) + sort.Strings(wantNames) + if got, want := names, wantNames; !reflect.DeepEqual(got, want) { + t.Errorf("got %v, want %v", got, want) + } +} + +func TestServeHTTP(t *testing.T) { + reg := NewRegistry("abc") + dumpFuncErrC := make(chan error) + reg.Register("foo", makeDumpConst(dumpFuncErrC, "foo-contents")) + reg.Register("bar", makeDumpConst(dumpFuncErrC, "bar-contents")) + reg.Register("baz", makeDumpConst(dumpFuncErrC, "baz-contents")) + + mux := http.NewServeMux() + mux.Handle("/dump.zip", reg) + server := httptest.NewServer(mux) + + verifyDump(t, server, dumpFuncErrC, []string{"foo", "bar", "baz"}) +} + +func TestServeHTTPFailedParts(t *testing.T) { + reg := NewRegistry("abc") + dumpFuncErrC := make(chan error) + reg.Register("foo", makeDumpConst(dumpFuncErrC, "foo-contents")) + // Note that the following dump part funcs will return an error. + reg.Register("bar", makeDumpError(dumpFuncErrC, "bar-contents")) + reg.Register("baz", makeDumpError(dumpFuncErrC, "baz-contents")) + reg.Register("skip", dumpSkipPart) + + mux := http.NewServeMux() + mux.Handle("/dump.zip", reg) + server := httptest.NewServer(mux) + + // Verify that only the successful dump part func is in the dump. + verifyDump(t, server, dumpFuncErrC, []string{"foo"}) +} diff --git a/diagnostic/memsize/deep_size.go b/diagnostic/memsize/deep_size.go new file mode 100644 index 00000000..39b2abe1 --- /dev/null +++ b/diagnostic/memsize/deep_size.go @@ -0,0 +1,159 @@ +package memsize + +import ( + "reflect" + "unsafe" +) + +// DeepSize estimates the amount of memory used by a Go value. It's intended as a +// memory usage debugging aid. Argument must be a pointer to a value. +// +// Not thread safe. Behavior is undefined if any value reachable from the argument +// is concurrently mutated. In general, do not call this in production. 
+//
+// Behavior:
+// * Recursively descends into contained values (struct fields, slice elements,
+//   etc.), tracking visitation (by memory address) to handle cycles.
+// * Only counts slice length, not unused capacity.
+// * Only counts map key and value size, not map overhead.
+// * Does not count functions or channels.
+//
+// The implementation relies on the Go garbage collector being non-compacting (not
+// moving values in memory), due to thread non-safety noted above. This is true as
+// of Go 1.13, but could change in the future.
+func DeepSize(x interface{}) (numBytes int) {
+	if x == nil {
+		return 0
+	}
+	v := reflect.ValueOf(x)
+	if v.Kind() != reflect.Ptr {
+		panic("must be a pointer")
+	}
+	if v.IsNil() {
+		return 0
+	}
+	scanner := &memoryScanner{
+		memory:  &intervalSet{},
+		visited: make(map[memoryAndKind]struct{}),
+	}
+	// Total = bytes of addressable memory covered by scanned intervals, plus
+	// bytes of values that had no address (and so could not be interval-tracked).
+	unaddressableBytes := scanner.scan(v.Elem(), true)
+	return scanner.memory.totalCovered() + unaddressableBytes
+}
+
+// memoryAndKind identifies a scanned value by its memory interval plus its
+// reflect.Kind (structs and their first field share an address but differ in kind).
+type memoryAndKind struct {
+	interval
+	reflect.Kind
+}
+
+// getMemoryAndType returns the memory interval occupied by x together with its kind.
+// x must be addressable.
+func getMemoryAndType(x reflect.Value) memoryAndKind {
+	start := x.UnsafeAddr()
+	size := int64(x.Type().Size())
+	kind := x.Kind()
+	return memoryAndKind{
+		interval: interval{start: start, length: size},
+		Kind:     kind,
+	}
+}
+
+// memoryScanner can recursively scan memory used by a reflect.Value.
+// Not thread safe; scan should only be called once.
+type memoryScanner struct {
+	memory  *intervalSet               // memory is a set of memory locations that are used in scan()
+	visited map[memoryAndKind]struct{} // visited is a map of locations that have already been visited by scan
+}
+
+// scan recursively traverses a reflect.Value and records all reachable
+// addressable memory in s.memory.
+// x is the Value whose size is to be counted.
+// includeX indicates whether the bytes for x itself should be counted when x
+// is not addressable.
+// Returns a count of unaddressable bytes (not representable as intervals).
+func (s *memoryScanner) scan(x reflect.Value, includeX bool) (unaddressableBytes int) {
+	if x.CanAddr() {
+		memtype := getMemoryAndType(x)
+		if _, ok := s.visited[memtype]; ok {
+			// Already scanned this exact location/kind: avoid cycles and
+			// double counting.
+			return
+		}
+		s.visited[memtype] = struct{}{}
+		s.memory.add(memtype.interval)
+	} else if includeX {
+		unaddressableBytes += int(x.Type().Size())
+	}
+
+	switch x.Kind() {
+	case reflect.String:
+		// Count the string's backing bytes via its header.
+		m := x.String()
+		hdr := (*reflect.StringHeader)(unsafe.Pointer(&m))
+		s.memory.add(interval{hdr.Data, int64(hdr.Len)})
+	case reflect.Array:
+		if containsPointers(x.Type()) { // must scan each element individually
+			for i := 0; i < x.Len(); i++ {
+				unaddressableBytes += s.scan(x.Index(i), false)
+			}
+		}
+	case reflect.Slice:
+		if x.Len() > 0 {
+			if containsPointers(x.Index(0).Type()) { // must scan each element individually
+				for i := 0; i < x.Len(); i++ {
+					unaddressableBytes += s.scan(x.Index(i), true)
+				}
+			} else { // add the content of the slice to the memory counter
+				start := x.Pointer()
+				size := int64(x.Index(0).Type().Size()) * int64(x.Len())
+				s.memory.add(interval{start: start, length: size})
+			}
+		}
+	case reflect.Interface, reflect.Ptr:
+		if !x.IsNil() {
+			unaddressableBytes += s.scan(x.Elem(), true)
+		}
+	case reflect.Struct:
+		// Only pointer-bearing fields need individual scanning; the struct's
+		// own bytes were added above.
+		for _, fieldI := range structChild(x) {
+			unaddressableBytes += s.scan(fieldI, false)
+		}
+	case reflect.Map:
+		for _, key := range x.MapKeys() {
+			val := x.MapIndex(key)
+			unaddressableBytes += s.scan(key, true)
+			unaddressableBytes += s.scan(val, true)
+		}
+	case reflect.Func, reflect.Chan:
+		// Can't do better than this:
+	default:
+	}
+	return
+}
+
+// containsPointers reports whether values of type x can contain pointers
+// (directly or within array elements / struct fields).
+func containsPointers(x reflect.Type) bool {
+	switch x.Kind() {
+	case reflect.String, reflect.Slice, reflect.Map, reflect.Interface, reflect.Ptr:
+		return true
+	case reflect.Array:
+		if x.Len() > 0 {
+			return containsPointers(x.Elem())
+		}
+	case reflect.Struct:
+		for i, n := 0, x.NumField(); i < n; i++ {
+			if containsPointers(x.Field(i).Type) {
+				return true
+			}
+		}
+	}
+	return false
+}
+
+// v must
 be a struct kind.
+// returns all the fields of this struct (recursively for nested structs) that are pointer types
+func structChild(x reflect.Value) []reflect.Value {
+	var ret []reflect.Value
+	for i, n := 0, x.NumField(); i < n; i++ {
+		fieldI := x.Field(i)
+		switch fieldI.Kind() {
+		case reflect.Struct:
+			ret = append(ret, structChild(fieldI)...)
+		case reflect.Ptr, reflect.String, reflect.Interface, reflect.Slice, reflect.Map:
+			ret = append(ret, fieldI)
+		}
+	}
+	return ret
+}
diff --git a/diagnostic/memsize/deep_size_test.go b/diagnostic/memsize/deep_size_test.go
new file mode 100644
index 00000000..51bd11b3
--- /dev/null
+++ b/diagnostic/memsize/deep_size_test.go
@@ -0,0 +1,215 @@
+package memsize
+
+import (
+	"testing"
+	"unsafe"
+
+	"github.com/stretchr/testify/assert"
+)
+
+var (
+	x             int64 = 5
+	pointerSize         = int(unsafe.Sizeof(&x))
+	sliceSize           = int(unsafe.Sizeof(make([]int64, 0)))
+	stringSize          = int(unsafe.Sizeof("abcde"))
+	mapSize             = int(unsafe.Sizeof(make(map[int32]int32)))
+	temper        Temper = nil
+	interfaceSize        = int(unsafe.Sizeof(temper))
+)
+
+// TestCase pairs a pointer input for DeepSize with its expected byte count.
+type TestCase struct {
+	x        interface{}
+	expected int
+}
+
+func TestPrimitives(t *testing.T) {
+	var int8val int8 = 3
+	var uint8val uint8 = 3
+	var int16val int16 = -2
+	var uint16val uint16 = 2
+	var int32val int32 = -54
+	var uint32val uint32 = 54
+	var int64val int64 = -34
+	var uint64val uint64 = 34
+	var boolval bool = true
+	var float64val float64 = 3.29
+	var float32val float32 = 543.23
+	var int8ptr *int8
+	tests := []TestCase{
+		{nil, 0},
+		{&int64val, 8},
+		{&uint64val, 8},
+		{&int32val, 4},
+		{&uint32val, 4},
+		{&int16val, 2},
+		{&uint16val, 2},
+		{&int8val, 1},
+		{&uint8val, 1},
+		{&float64val, 8},
+		{&float32val, 4},
+		{&boolval, 1},
+		{int8ptr, 0},            // nil pointer
+		{&int8ptr, pointerSize}, // pointer pointer
+	}
+	runTests(tests, t)
+}
+
+func TestSlicesAndArray(t *testing.T) {
+	var int64val int64
+	var int64val2 int64
+	var int64ptr *int64
+	var int8slice []int8 = []int8{0, 1, 2, 3, 4, 5, 6}
+	var int8sliceB = int8slice[0:4]
+
+	type smallStruct struct { // size = 16 bytes
+		A, B int64
+	}
+
+	type smallPointerStruct struct { // size = 24 bytes + maybe 8 for ptr
+		A, B int64
+		APtr *int64
+	}
+
+	type complexStruct struct { // size = 40 bytes + maybe 8 for ptr
+		T smallStruct
+		Y smallPointerStruct
+	}
+
+	type structWithZeroArray struct { // size = 40 bytes + maybe 8 for ptr
+		M [0]int64
+		W complexStruct
+	}
+
+	tests := []TestCase{
+		{&[]int64{1, 2, 3}, sliceSize + 8*3},
+		{&[]*int64{int64ptr, &int64val}, sliceSize + pointerSize*2 + 8},
+		{&[]*int64{&int64val, &int64val}, sliceSize + pointerSize*2 + 8},
+		{&[]*int64{&int64val2, &int64val}, sliceSize + pointerSize*2 + 2*8},
+		{&[]smallStruct{{}, {}}, sliceSize + 16*2},
+		{&[]complexStruct{{}, {Y: smallPointerStruct{APtr: &int64val}}}, sliceSize + 40*2 + 8},
+		{&[][3]int64{{1, 2, 3}, {4, 5, 6}}, sliceSize + 2*24},
+		{&[...]int64{1, 2, 3}, 8 * 3},
+		{&[0]int64{}, 0},
+		{&[]structWithZeroArray{{}, {}}, sliceSize + 2*40},
+		{&[...]smallPointerStruct{{A: 1, B: 1, APtr: &int64val}, {A: 1, B: 1, APtr: nil}}, 2*24 + 8},
+		{&int8sliceB, sliceSize + 4},
+		{&[][]int8{int8slice[0:4], int8slice[0:4]}, 3*sliceSize + 4}, // overlapping memory locations
+		{&[][]int8{int8slice[0:4], int8slice[2:6]}, 3*sliceSize + 6}, // overlapping memory locations
+	}
+	runTests(tests, t)
+
+}
+
+func TestStrings(t *testing.T) {
+	var emptyString = ""
+	var abcdefgString = "abcdefg"
+	tests := []TestCase{
+		{&emptyString, stringSize},
+		{&abcdefgString, stringSize + 7},
+		{&[]string{"abcd", "defg"}, sliceSize + 2*stringSize + 2*4}, // no string interning
+		{&[]string{"abcd", "abcd"}, sliceSize + 2*stringSize + 4},   // string interning
+	}
+	runTests(tests, t)
+}
+
+func TestMap(t *testing.T) {
+	var int8val int8
+	tests := []TestCase{
+		{&map[int64]int64{2: 3}, mapSize + 8 + 8},
+		{&map[string]int32{"abc": 3}, mapSize + stringSize + 3 + 4},
+		{&map[string]*int8{"abc": &int8val, "def": &int8val}, mapSize + 2*stringSize + 2*3 + 2*pointerSize + 1},
+	}
+	runTests(tests, t)
+}
+
+type Temper interface {
+	Temp()
+}
+
+type TemperMock struct {
+	A int64
+}
+
+func (TemperMock) Temp() {}
+
+func TestStructs(t *testing.T) {
+	type struct1 struct {
+		A int64
+		B float64
+	}
+
+	type struct2 struct {
+		A      int64
+		B      float64
+		temper Temper
+	}
+
+	type recursiveType struct {
+		A   int64
+		ptr *recursiveType
+	}
+
+	type nestedType1 struct { // 8 bytes + maybe 8 bytes
+		A *int64
+	}
+	type nestedType2 struct { // 8 bytes + maybe 8 bytes
+		X nestedType1
+	}
+	type nestedType3 struct { // 8 bytes + maybe 8 bytes
+		Y nestedType2
+	}
+
+	type structWithZeroArray struct { // 8 bytes + maybe 8 bytes
+		X [0]int64
+		Y nestedType2
+	}
+
+	var int64val int64
+	var recursiveVar1 = recursiveType{A: 1}
+	var recursiveVar2 = recursiveType{A: 2}
+	recursiveVar1.ptr = &recursiveVar2
+	recursiveVar2.ptr = &recursiveVar1
+
+	tests := []TestCase{
+		{&nestedType3{Y: nestedType2{X: nestedType1{A: &int64val}}}, pointerSize + 8},
+		{&struct1{1, 1}, 16},
+		{&struct2{1, 1, TemperMock{}}, 16 + interfaceSize + 8},
+		{&struct2{1, 1, nil}, 16 + interfaceSize},
+		{&recursiveVar1, 2 * (8 + pointerSize)},
+		{&structWithZeroArray{Y: nestedType2{}, X: [0]int64{}}, 8},
+	}
+
+	runTests(tests, t)
+}
+
+func TestCornerCaseTypes(t *testing.T) {
+	var chanVar chan int
+	tests := []TestCase{
+		{&struct{ A func(x int) int }{A: func(x int) int { return x + 1 }}, pointerSize},
+		{&chanVar, pointerSize},
+	}
+	runTests(tests, t)
+}
+
+func TestPanicOnNonNil(t *testing.T) {
+	tests := []interface{}{
+		"abc",
+		5,
+		3.5,
+		struct {
+			A int
+			B int
+		}{A: 5, B: 5},
+	}
+	for i := range tests {
+		assert.Panics(t, func() { DeepSize(tests[i]) }, "should panic")
+	}
+}
+
+// runTests checks DeepSize against the expected byte count for each case.
+func runTests(tests []TestCase, t *testing.T) {
+	for i, test := range tests {
+		if got := DeepSize(test.x); got != test.expected {
+			t.Errorf("test %d: got %d, expected %d", i, got, test.expected)
+		}
+	}
+}
diff --git a/diagnostic/memsize/interval_set.go
 b/diagnostic/memsize/interval_set.go
new file mode 100644
index 00000000..def325c9
--- /dev/null
+++ b/diagnostic/memsize/interval_set.go
@@ -0,0 +1,105 @@
+package memsize
+
+import (
+	"sort"
+)
+
+// interval represents a range of integers from start (inclusive) to start+length (not inclusive)
+type interval struct {
+	start  uintptr
+	length int64
+}
+
+// intervalSet is a collection of intervals
+// new intervals can be added and the total covered size computed.
+// with some frequency, intervals will be compacted to save memory.
+type intervalSet struct {
+	data []interval
+	// nextCompact is the length data must reach before the next compact().
+	nextCompact int
+}
+
+func (r *intervalSet) add(interval interval) {
+	if interval.length == 0 {
+		return
+	}
+	r.data = append(r.data, interval)
+	if len(r.data) >= r.nextCompact {
+		r.compact()
+	}
+}
+
+// compact sorts the intervals and merges adjacent intervals if they overlap.
+func (r *intervalSet) compact() {
+	defer r.setNextCompact()
+	if len(r.data) < 2 {
+		return
+	}
+	sort.Slice(r.data, func(i, j int) bool {
+		return r.data[i].start < r.data[j].start
+	})
+	// basePtr is the last merged interval; aheadPtr scans forward merging
+	// overlapping intervals into it.
+	basePtr := 0
+	aheadPtr := 1
+	for aheadPtr < len(r.data) {
+		if overlaps(r.data[basePtr], r.data[aheadPtr]) {
+			r.data[basePtr].length = max(r.data[basePtr].length, int64(r.data[aheadPtr].start-r.data[basePtr].start)+r.data[aheadPtr].length)
+			aheadPtr++
+		} else {
+			basePtr++
+			r.data[basePtr] = r.data[aheadPtr]
+			aheadPtr++
+		}
+	}
+	r.data = r.data[0 : basePtr+1]
+
+	// if the data will fit into a much smaller backing array, then copy to a smaller backing array to save memory.
+	if len(r.data) < cap(r.data)/4 && len(r.data) > 100 {
+		dataCopy := append([]interval{}, r.data...) // copy r.data to smaller array
+		r.data = dataCopy
+	}
+}
+
+// setNextCompact sets the size that r.data must reach before the next compacting
+func (r *intervalSet) setNextCompact() {
+	r.nextCompact = int(float64(len(r.data)) * 1.2) // increase current length by at least 20%
+	if r.nextCompact < cap(r.data) { // do not compact before reaching capacity of data
+		r.nextCompact = cap(r.data)
+	}
+	if r.nextCompact < 10 { // do not compact before reaching 10 elements.
+		r.nextCompact = 10
+	}
+}
+
+// tests for overlaps between two intervals.
+// Note: adjacency counts as overlap (>=), so touching intervals merge; this is
+// correct for computing total coverage.
+// precondition: x.start <= y.start
+func overlaps(x, y interval) bool {
+	return x.start+uintptr(x.length) >= y.start
+}
+
+// totalCovered returns the total number of integers covered by the intervalSet
+func (r *intervalSet) totalCovered() int {
+	if len(r.data) == 0 {
+		return 0
+	}
+	sort.Slice(r.data, func(i, j int) bool {
+		return r.data[i].start < r.data[j].start
+	})
+	total := 0
+	curInterval := interval{start: r.data[0].start, length: 0} // zero width interval for initialization
+	for _, val := range r.data {
+		if overlaps(curInterval, val) { // extend the current interval
+			curInterval.length = max(curInterval.length, int64(val.start-curInterval.start)+val.length)
+		} else { // start a new interval
+			total += int(curInterval.length)
+			curInterval = val
+		}
	}
+	total += int(curInterval.length)
+	return total
+}
+
+func max(i, j int64) int64 {
+	if i > j {
+		return i
+	}
+	return j
+}
diff --git a/diagnostic/memsize/interval_set_test.go b/diagnostic/memsize/interval_set_test.go
new file mode 100644
index 00000000..99571bae
--- /dev/null
+++ b/diagnostic/memsize/interval_set_test.go
@@ -0,0 +1,128 @@
+package memsize
+
+import (
+	"math/rand"
+	"testing"
+)
+
+func TestIntervalSet(t *testing.T) {
+	tests := []struct {
+		data            []interval
+		expectedCovered []int
+	}{
+		{
+			data:            []interval{{1, 10}},
+			expectedCovered: []int{10},
+		},
+		{
+			data:            []interval{{1, 10}, {11, 10}},
+			expectedCovered: []int{10, 20},
+		},
+		{
+			data:            []interval{{1, 10}, {5, 10}},
+			expectedCovered: []int{10, 14},
+		},
+		{
+			data:            []interval{{1, 10}, {5, 10}, {6, 9}, {6, 10}, {100, 1}, {100, 2}, {101, 1}},
+			expectedCovered: []int{10, 14, 14, 15, 16, 17, 17},
+		},
+		{
+			data:            []interval{{100, 1}, {99, 1}, {99, 2}, {0, 10}, {10, 10}},
+			expectedCovered: []int{1, 2, 2, 12, 22},
+		},
+	}
+
+	for testId, test := range tests {
+		for numToAdd := range test.data {
+			var set intervalSet
+			for i := 0; i <= numToAdd; i++ {
+				newInterval := test.data[i]
+				set.add(newInterval)
+			}
+			got := set.totalCovered()
+			if got != test.expectedCovered[numToAdd] {
+				t.Errorf("test: %d, query: %d: got %v, expected %v", testId, numToAdd, got, test.expectedCovered[numToAdd])
+			}
+		}
+	}
+}
+
+func TestEmptySet(t *testing.T) {
+	var set intervalSet
+	if got := set.totalCovered(); got != 0 {
+		t.Errorf("empty set should have 0 coverage. got %d", got)
+	}
+}
+
+// TestRandomAddition creates a random slice of intervals and adds them in random order to the intervalSet.
+// It periodically checks TotalCovered() against another implementation of interval set.
+func TestRandomAddition(t *testing.T) {
+	tests := []struct {
+		setsize         int
+		testrepeats     int
+		size            int
+		intervalrepeats int
+		sampleProb      float64
+	}{
+		{
+			setsize:         100,
+			testrepeats:     100,
+			size:            1,
+			intervalrepeats: 2,
+			sampleProb:      .2,
+		},
+		{
+			setsize:         1000,
+			testrepeats:     100,
+			size:            3,
+			intervalrepeats: 2,
+			sampleProb:      .2,
+		},
+		{
+			setsize:         1000,
+			testrepeats:     100,
+			size:            3,
+			intervalrepeats: 2,
+			sampleProb:      0,
+		},
+	}
+
+	for _, test := range tests {
+		for testRepeat := 0; testRepeat < test.testrepeats; testRepeat++ {
+			r := rand.New(rand.NewSource(int64(testRepeat)))
+			var intervals []interval
+			for i := 0; i < test.setsize; i++ {
+				for j := 0; j < test.intervalrepeats; j++ {
+					intervals = append(intervals, interval{start: uintptr(i), length: int64(test.size)})
+				}
+			}
+			shuffle(intervals, int64(testRepeat))
+			set := intervalSet{}
+			coveredMap := make(map[uintptr]struct{})
+			for _, x := range intervals {
+				set.add(x)
+				for j := uintptr(0); j < uintptr(x.length); j++ {
+					coveredMap[x.start+j] = struct{}{}
+				}
+				if r.Float64() < test.sampleProb {
+					gotCovered := set.totalCovered()
+					if gotCovered != len(coveredMap) {
+						t.Errorf("set.Covered()=%d, len(coveredMap) = %d", gotCovered, len(coveredMap))
+					}
+				}
+			}
+
+			// The message must report the same expected value the condition
+			// checks: starts 0..setsize-1 each of length size cover
+			// size+setsize-1 integers (previously printed setsize+setsize-1).
+			if gotSize := set.totalCovered(); gotSize != test.size+test.setsize-1 {
+				t.Errorf("total covering - got: %d, expected %d", gotSize, test.size+test.setsize-1)
+			}
+		}
+	}
+}
+
+// randomly shuffle a set of intervals
+func shuffle(set []interval, seed int64) {
+	r := rand.New(rand.NewSource(seed))
+	r.Shuffle(len(set), func(i, j int) {
+		set[i], set[j] = set[j], set[i]
+	})
+}
diff --git a/diagnostic/stringintern/intern.go b/diagnostic/stringintern/intern.go
new file mode 100644
index 00000000..0096e2e3
--- /dev/null
+++ b/diagnostic/stringintern/intern.go
@@ -0,0 +1,104 @@
+package stringintern
+
+import (
+	"reflect"
+)
+
+// Intern will recursively traverse one or more objects and collapse all strings that are identical to the same pointer, saving memory.
+// Inputs must be pointer types.
+// String map keys are not interned.
+// The path to all fields must be exported. It is not possible to modify unexported fields in a safe way.
+// Example usage:
+//   var x = ... some complicated type with strings
+//   stringintern.Intern(&x)
+// Warning: This is a potentially dangerous operation.
+// Extreme care must be taken that no pointers exist to other structures that should not be modified.
+// This method is not thread safe. No other threads should be reading or writing to x while it is being interned.
+// It is safest to use this code for testing purposes to see how much memory can be saved by interning but then do the interning explicitly:
+//   sizeBefore := memsize.DeepSize(&x)
+//   stringintern.Intern(&x)
+//   sizeAfter := memsize.DeepSize(&x)
+func Intern(x ...interface{}) {
+	myinterner := interner{
+		dict:   make(map[string]string),
+		locMap: make(map[addressAndType]struct{}),
+	}
+	for _, val := range x {
+		value := reflect.ValueOf(val)
+		if value.Kind() != reflect.Ptr {
+			panic("input kind must be a pointer")
+		}
+		myinterner.intern(value)
+	}
+}
+
+type interner struct {
+	// dict stores the mapping of strings to their canonical interned version.
+	dict map[string]string
+	// locMap keeps track of which memory locations have already been scanned.
+	// it is necessary to also store type because structs and fields can have the same address and must be differentiated.
+	locMap map[addressAndType]struct{}
+}
+
+// addressAndType identifies a visited value by address plus type (a struct and
+// its first field share an address but have different types).
+type addressAndType struct {
+	address uintptr
+	tp      reflect.Type
+}
+
+// intern recursively walks x, replacing every settable string with its
+// canonical interned copy.
+func (s *interner) intern(x reflect.Value) {
+	if x.CanAddr() {
+		// Removed a leftover bare `x.Type().Name()` call here whose result
+		// was discarded: it had no effect.
+		addr := x.UnsafeAddr()
+		if _, alreadyProcessed := s.locMap[addressAndType{addr, x.Type()}]; alreadyProcessed {
+			return
+		}
+		s.locMap[addressAndType{addr, x.Type()}] = struct{}{} // mark current memory location
+	}
+	switch x.Kind() {
+	case reflect.String:
+		if x.CanSet() {
+			val := x.String()
+			s.internString(&val)
+			x.SetString(val)
+		}
+	case reflect.Float64, reflect.Float32,
+		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
+		reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
+		reflect.Complex64, reflect.Complex128,
+		reflect.Invalid, reflect.Chan, reflect.Bool, reflect.Uintptr, reflect.Func:
+		// noop. don't do anything.
+	case reflect.Struct:
+		for i := 0; i < x.NumField(); i++ {
+			s.intern(x.Field(i))
+		}
+	case reflect.Ptr, reflect.Interface:
+		if !x.IsNil() {
+			s.intern(x.Elem())
+		}
+	case reflect.Slice, reflect.Array:
+		for i := 0; i < x.Len(); i++ {
+			s.intern(x.Index(i))
+		}
+	case reflect.Map:
+		for _, key := range x.MapKeys() {
+			val := x.MapIndex(key)
+			if val.Kind() == reflect.String {
+				// Map values are not addressable; re-set the entry with the
+				// interned string instead.
+				stringVal := val.String()
+				s.internString(&stringVal)
+				x.SetMapIndex(key, reflect.ValueOf(stringVal))
+			} else {
+				s.intern(val)
+			}
+		}
+	}
+}
+
+// takes a pointer to a string. If string has previously been seen, it will change to interned version.
+// otherwise adds to dictionary of interned strings.
+func (s *interner) internString(x *string) { + if val, ok := s.dict[*x]; ok { + *x = val + } else { + s.dict[*x] = *x + } +} diff --git a/diagnostic/stringintern/intern_test.go b/diagnostic/stringintern/intern_test.go new file mode 100644 index 00000000..20a46beb --- /dev/null +++ b/diagnostic/stringintern/intern_test.go @@ -0,0 +1,103 @@ +package stringintern + +import ( + "testing" + + "github.com/grailbio/base/diagnostic/memsize" +) + +type AB struct { + A string + B string +} + +// return a string with the same content which does not share the same underlying memory. +func unintern(x string) string { + ret := "" + for _, t := range x { + ret = ret + string(t) + } + return ret +} + +type Test struct { + x []interface{} + sizeBefore, sizeAfter int +} + +func TestBasic(t *testing.T) { + var int64Var int = 3 + tests := []Test{ + { + x: []interface{}{&AB{"abc", unintern("abc")}}, + sizeBefore: 86, + sizeAfter: 83, + }, + { + x: []interface{}{&map[int]string{1: "abc", 2: unintern("abc")}}, + sizeBefore: 110, + sizeAfter: 107, + }, + { + x: []interface{}{&int64Var}, + sizeBefore: 56, + sizeAfter: 56, + }, + } + + for _, test := range tests { + runTest(test, t) + } +} + +func TestCircular(t *testing.T) { + type Circular struct { + A, B string + ptr *Circular + } + + type Nested struct { + X Circular + } + + circ := Circular{ + A: "abc", + B: unintern("abc"), + } + circ.ptr = &circ + + nested := Nested{X: Circular{ + A: "abc", + B: unintern("abc"), + }} + nested.X.ptr = &nested.X + + tests := []Test{ + { + x: []interface{}{&circ}, + sizeBefore: 94, + sizeAfter: 91, + }, + { + x: []interface{}{&nested}, + sizeBefore: 94, + sizeAfter: 91, + }, + } + + for _, test := range tests { + runTest(test, t) + } +} + +func runTest(test Test, t *testing.T) { + sizeBefore := memsize.DeepSize(&test.x) + Intern(test.x...) 
+ sizeAfter := memsize.DeepSize(&test.x) + if sizeBefore != test.sizeBefore { + t.Errorf("sizeBefore: expected=%d, got=%d", test.sizeBefore, sizeBefore) + } + if sizeAfter != test.sizeAfter { + t.Errorf("sizeAfter: expected=%d, got=%d", test.sizeAfter, sizeAfter) + } +} diff --git a/digest/digest.go b/digest/digest.go index ea5e8ed5..1b5d63ab 100644 --- a/digest/digest.go +++ b/digest/digest.go @@ -8,6 +8,7 @@ package digest import ( + "bufio" "bytes" "crypto" "crypto/rand" @@ -19,11 +20,11 @@ import ( "fmt" "hash" "io" + mathrand "math/rand" "strings" ) const maxSize = 64 // To support SHA-512 -const defaultSize = 32 // Define digestHash constants to be used during (de)serialization of Digests. // crypto.Hash values are not guaranteed to be stable over releases. @@ -53,6 +54,8 @@ const ( BLAKE2b_256 // crypto.BLAKE2b_256 BLAKE2b_384 // crypto.BLAKE2b_384 BLAKE2b_512 // crypto.BLAKE2b_512 + + zeroString = "" ) var ( @@ -165,6 +168,9 @@ func (d *Digest) GobDecode(p []byte) error { // Parse parses a string representation of Digest, as defined by // Digest.String(). func Parse(s string) (Digest, error) { + if s == "" || s == zeroString { + return Digest{}, nil + } parts := strings.Split(s, ":") if len(parts) != 2 { return Digest{}, ErrInvalidDigest @@ -195,7 +201,9 @@ func ParseHash(h crypto.Hash, hx string) (Digest, error) { return d, nil } -func newDigest(h crypto.Hash, b []byte) Digest { +// New returns a new literal digest with the provided hash and +// value. +func New(h crypto.Hash, b []byte) Digest { d := Digest{h: h} copy(d.b[:], b) return d @@ -287,20 +295,41 @@ func (d Digest) IsAbbrev() bool { return bytes.HasSuffix(d.b[:], zeros[d.h.Size()/2:]) } -// Expands tells whether digest d expands the short digest e. +// NPrefix returns the number of nonzero leading bytes in the +// digest, after which the remaining bytes are zero. 
+func (d Digest) NPrefix() int { + for i := d.h.Size() - 1; i >= 0; i-- { + if d.b[i] != 0 { + return i + 1 + } + } + return 0 +} + +// Expands tells whether digest d expands digest e. func (d Digest) Expands(e Digest) bool { - return bytes.HasPrefix(d.b[:], e.b[:4]) + n := e.NPrefix() + return bytes.HasPrefix(d.b[:], e.b[:n]) } // String returns the full string representation of the digest: the digest // name, followed by ":", followed by its hexadecimal value. func (d Digest) String() string { if d.IsZero() { - return "" + return zeroString } return fmt.Sprintf("%s:%s", name[d.h], d.Hex()) } +// ShortString returns a short representation of the digest, comprising +// the digest name and its first n bytes. +func (d Digest) ShortString(n int) string { + if d.IsZero() { + return zeroString + } + return fmt.Sprintf("%s:%s", name[d.h], d.HexN(n)) +} + func (d Digest) valid() bool { return d.h.Available() && len(d.b) >= d.h.Size() } @@ -324,11 +353,23 @@ func (d *Digest) UnmarshalJSON(b []byte) error { // Digester computes digests based on a cryptographic hash function. type Digester crypto.Hash +// New returns a new digest with the provided literal contents. New +// panics if the digest size does not match the hash function's length. +func (d Digester) New(b []byte) Digest { + if crypto.Hash(d).Size() != len(b) { + panic("digest: bad digest length") + } + return New(crypto.Hash(d), b) +} + // Parse parses a string into a Digest with the cryptographic hash of // Digester. The input string is in the form of Digest.String, except // that the hash name may be omitted--it is then instead assumed to // be the hash function associated with the Digester. 
func (d Digester) Parse(s string) (Digest, error) { + if s == "" || s == zeroString { + return Digest{h: crypto.Hash(d)}, nil + } parts := strings.Split(s, ":") switch len(parts) { default: @@ -353,7 +394,7 @@ func (d Digester) FromBytes(p []byte) Digest { if _, err := w.Write(p); err != nil { panic("hash returned error " + err.Error()) } - return newDigest(crypto.Hash(d), w.Sum(nil)) + return New(crypto.Hash(d), w.Sum(nil)) } // FromString computes a Digest from a string. @@ -365,48 +406,86 @@ func (d Digester) FromString(s string) Digest { func (d Digester) FromDigests(digests ...Digest) Digest { w := crypto.Hash(d).New() for _, d := range digests { + // TODO(saito,pknudsgaaard,schandra) + // + // grail.com/pipeline/release/internal/reference passes an empty Digest and + // fails here. We need to be more principled about the values passed here, + // so we intentionally drop errors here. WriteDigest(w, d) } - return newDigest(crypto.Hash(d), w.Sum(nil)) + return New(crypto.Hash(d), w.Sum(nil)) } -// Rand returns a random digest generated by a cryptographically -// secure random number generator. -func (d Digester) Rand() Digest { +// Rand returns a random digest generated by the random +// provided generator. If no generator is provided (r is nil), +// Rand uses the system's cryptographically secure random +// number generator. +func (d Digester) Rand(r *mathrand.Rand) Digest { dg := Digest{h: crypto.Hash(d)} - if _, err := rand.Read(dg.b[:dg.h.Size()]); err != nil { + var ( + err error + p = dg.b[:dg.h.Size()] + ) + if r != nil { + _, err = r.Read(p) + } else { + _, err = rand.Read(p) + } + if err != nil { panic(err) } return dg } -// NewWriter returns a Writer that can be used to compute -// Digests of long inputs. +// NewWriter returns a Writer that can be used to compute Digests of long inputs. 
func (d Digester) NewWriter() Writer { - return Writer{crypto.Hash(d), crypto.Hash(d).New()} + hw := crypto.Hash(d).New() + return Writer{h: crypto.Hash(d), hw: hw, w: bufio.NewWriter(hw)} +} + +const digesterBufferSize = 256 + +// NewWriterShort returns a Writer that can be used to compute Digests of short inputs (ie, order of KBs) +func (d Digester) NewWriterShort() Writer { + hw := crypto.Hash(d).New() + return Writer{h: crypto.Hash(d), hw: hw, w: bufio.NewWriterSize(hw, digesterBufferSize)} } // Writer provides an io.Writer to which digested bytes are // written and from which a Digest is produced. type Writer struct { h crypto.Hash - w hash.Hash + hw hash.Hash + w *bufio.Writer } func (d Writer) Write(p []byte) (n int, err error) { return d.w.Write(p) } +func (d Writer) WriteString(s string) (n int, err error) { + return d.w.WriteString(s) +} + // Digest produces the current Digest of the Writer. // It does not reset its internal state. func (d Writer) Digest() Digest { - return newDigest(d.h, d.w.Sum(nil)) + if err := d.w.Flush(); err != nil { + panic(fmt.Sprintf("digest.Digest.Flush: %v", err)) + } + return New(d.h, d.hw.Sum(nil)) } // WriteDigest is a convenience function to write a (binary) // Digest to an io.Writer. Its format is two bytes representing // the hash function, followed by the hash value itself. +// +// Writing a zero digest is disallowed; WriteDigest panics in +// this case. func WriteDigest(w io.Writer, d Digest) (n int, err error) { + if d.IsZero() { + panic("digest.WriteDigest: attempted to write a zero digest") + } digestHash, ok := cryptoToDigestHashes[d.h] if !ok { return n, fmt.Errorf("cannot convert %v to a digestHash", d.h) @@ -469,10 +548,3 @@ func (d *Digester) UnmarshalJSON(b []byte) error { *d = Digester(val) return nil } - -// Slice defines sort.Interface on a slice of digests. 
-type Slice []Digest - -func (s Slice) Len() int { return len(s) } -func (s Slice) Less(i, j int) bool { return s[i].Less(s[j]) } -func (s Slice) Swap(i, j int) { s[i], s[j] = s[j], s[i] } diff --git a/digest/digest_test.go b/digest/digest_test.go index 8662c8f7..29269be6 100644 --- a/digest/digest_test.go +++ b/digest/digest_test.go @@ -38,12 +38,15 @@ func TestDigest(t *testing.T) { } dd, err := dig.Parse(tc.out) if err != nil { - t.Fatal(err) + t.Fatalf("parse failed: %v", err) } if got, want := dd, d; got != want { t.Fatalf("got %v want %v", got, want) } dd, err = dig.Parse(tc.out[len(tc.name)-1:]) + if err != nil { + t.Fatalf("parse failed: %v", err) + } if got, want := dd, d; got != want { t.Fatalf("got %v want %v", got, want) } @@ -83,6 +86,7 @@ func TestReadWrite(t *testing.T) { } // Test unknown digestHash d := Digest{} + d.h = ^crypto.Hash(0) // make it nonzero but invalid so we don't hit the zero hash panic var b bytes.Buffer _, err := WriteDigest(&b, d) if err == nil { @@ -178,6 +182,20 @@ func TestTruncate(t *testing.T) { } } +func TestNPrefix(t *testing.T) { + d := Digester(crypto.SHA256) + id, err := d.Parse("9909853c8cada54314ddc5f89fe5658e139aea88cab8c1479a8c35c902b1cb49") + if err != nil { + t.Fatal(err) + } + for n := 32; n >= 0; n-- { + id.Truncate(n) + if got, want := id.NPrefix(), n; got != want { + t.Errorf("got %v, want %v for %v", got, want, id) + } + } +} + func TestGob(t *testing.T) { id, err := Parse("sha256:9909853c8cada5431400c5f89fe5658e139aea88cab8c1479a8c35c902b1cb49") if err != nil { @@ -197,3 +215,15 @@ func TestGob(t *testing.T) { t.Errorf("got %v, want %v", got, want) } } + +func TestParse(t *testing.T) { + for _, hash := range []string{"", ""} { + h, err := Parse(hash) + if err != nil { + t.Fatal(err) + } + if got, want := h, (Digest{}); got != want { + t.Errorf("got %v, want %v", got, want) + } + } +} diff --git a/digest/digestreader.go b/digest/digestreader.go index 0ee4cf3b..1cb080c8 100644 --- a/digest/digestreader.go +++ 
b/digest/digestreader.go @@ -8,7 +8,6 @@ import ( "fmt" "io" "sync" - "sync/atomic" ) @@ -61,6 +60,7 @@ func (r *readerWrap) Read(p []byte) (int, error) { } q := p[:n] + // todo(ysiato, schandra, pknudsgaard) this looks like another intentional no-error-check like digest.go:407 r.digestWriter.Write(q) return n, r.err diff --git a/digest/digestrw_test.go b/digest/digestrw_test.go index 4a447ee3..267c5cf0 100644 --- a/digest/digestrw_test.go +++ b/digest/digestrw_test.go @@ -23,6 +23,7 @@ import ( "github.com/aws/aws-sdk-go/service/s3" "github.com/aws/aws-sdk-go/service/s3/s3manager" "github.com/grailbio/base/digest" + "github.com/grailbio/base/traverse" "github.com/grailbio/testutil" "github.com/grailbio/testutil/s3test" ) @@ -45,15 +46,15 @@ func TestDigestReader(t *testing.T) { order []int64 }{ { - &testutil.FakeContentAt{t, dataSize, 0, 0}, + &testutil.FakeContentAt{T: t, SizeInBytes: dataSize, Current: 0, FailureRate: 0}, []int64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, }, { - &testutil.FakeContentAt{t, dataSize, 0, 0}, + &testutil.FakeContentAt{T: t, SizeInBytes: dataSize, Current: 0, FailureRate: 0}, []int64{1, 0, 3, 2, 5, 4, 7, 6, 9, 8}, }, { - &testutil.FakeContentAt{t, dataSize, 0, 0}, + &testutil.FakeContentAt{T: t, SizeInBytes: dataSize, Current: 0, FailureRate: 0}, []int64{9, 8, 7, 6, 5, 4, 3, 2, 1, 0}, }, } { @@ -63,22 +64,17 @@ func TestDigestReader(t *testing.T) { t.Fatal("reader does not support ReaderAt") } - wg := sync.WaitGroup{} - wg.Add(len(test.order)) - for _, i := range test.order { - go func(index int64) { - defer wg.Done() - - size := min(segmentSize, (dataSize-index)*segmentSize) - d := make([]byte, size) - _, err := readerAt.ReadAt(d, index*int64(segmentSize)) - if err != nil { - t.Fatal(err) - } - }(i) - time.Sleep(10 * time.Millisecond) + err := traverse.Each(len(test.order), func(jobIdx int) error { + time.Sleep(10 * time.Duration(jobIdx) * time.Millisecond) + index := test.order[jobIdx] + size := min(segmentSize, (dataSize-index)*segmentSize) + 
d := make([]byte, size) + _, err := readerAt.ReadAt(d, index*int64(segmentSize)) + return err + }) + if err != nil { + t.Fatal(err) } - wg.Wait() actual, err := dra.Digest() if err != nil { @@ -86,7 +82,7 @@ func TestDigestReader(t *testing.T) { } writer := digester.NewWriter() - content := &testutil.FakeContentAt{t, dataSize, 0, 0} + content := &testutil.FakeContentAt{T: t, SizeInBytes: dataSize, Current: 0, FailureRate: 0} if _, err := io.Copy(writer, content); err != nil { t.Fatal(err) } @@ -125,23 +121,19 @@ func TestDigestWriter(t *testing.T) { dwa := digester.NewWriterAt(context.Background(), output) - wg := sync.WaitGroup{} - wg.Add(len(test)) - for _, i := range test { + err = traverse.Each(len(test), func(jobIdx int) error { + time.Sleep(5 * time.Duration(jobIdx) * time.Millisecond) + i := test[jobIdx] segmentString := strings.Repeat(fmt.Sprintf("%c", 'a'+i), 100) offset := int64(i * len(segmentString)) - go func() { - _, err := dwa.WriteAt([]byte(segmentString), offset) - if err != nil { - t.Fatal(err) - } - wg.Done() - }() - time.Sleep(5 * time.Millisecond) - } - wg.Wait() + _, e := dwa.WriteAt([]byte(segmentString), offset) + return e + }) output.Close() + if err != nil { + t.Fatal(err) + } expected, err := dwa.Digest() if err != nil { @@ -203,7 +195,7 @@ func TestS3ManagerUpload(t *testing.T) { size := int64(93384620) // Completely random number. 
digester := digest.Digester(crypto.SHA256) - contentAt := &testutil.FakeContentAt{t, size, 0, 0} + contentAt := &testutil.FakeContentAt{T: t, SizeInBytes: size, Current: 0, FailureRate: 0} client.SetFileContentAt("test/test/test", contentAt, "fakesha") reader := digester.NewReader(contentAt) @@ -232,7 +224,7 @@ func TestS3ManagerUpload(t *testing.T) { } dw := digester.NewWriter() - content := &testutil.FakeContentAt{t, size, 0, 0} + content := &testutil.FakeContentAt{T: t, SizeInBytes: size, Current: 0, FailureRate: 0} if _, err := io.Copy(dw, content); err != nil { t.Fatal(err) } @@ -250,7 +242,7 @@ func TestS3ManagerDownload(t *testing.T) { size := int64(86738922) // Completely random number. digester := digest.Digester(crypto.SHA256) - contentAt := &testutil.FakeContentAt{t, size, 0, 0.001} + contentAt := &testutil.FakeContentAt{T: t, SizeInBytes: size, Current: 0, FailureRate: 0.001} client.SetFileContentAt("test/test/test", contentAt, "fakesha") writer := digester.NewWriterAt(context.Background(), contentAt) @@ -280,7 +272,7 @@ func TestS3ManagerDownload(t *testing.T) { } dw := digester.NewWriter() - content := &testutil.FakeContentAt{t, size, 0, 0} + content := &testutil.FakeContentAt{T: t, SizeInBytes: size, Current: 0, FailureRate: 0} if _, err := io.Copy(dw, content); err != nil { t.Fatal(err) } diff --git a/embedbin/binaries_test.go b/embedbin/binaries_test.go new file mode 100644 index 00000000..8ef83107 --- /dev/null +++ b/embedbin/binaries_test.go @@ -0,0 +1,259 @@ +// Copyright 2019 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. 
+ +package embedbin + +// Stripped version of C program: +// void main(){printf("hello world");} +var svelteLinuxElfBinary = []byte{ + 0x7f, 0x45, 0x4c, 0x46, 0x01, 0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x03, 0x00, 0x01, 0x00, 0x00, 0x00, + 0xf0, 0x82, 0x04, 0x08, 0x34, 0x00, 0x00, 0x00, 0x70, 0x07, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x34, 0x00, 0x20, 0x00, 0x07, 0x00, 0x28, 0x00, + 0x1b, 0x00, 0x1a, 0x00, 0x06, 0x00, 0x00, 0x00, 0x34, 0x00, 0x00, 0x00, + 0x34, 0x80, 0x04, 0x08, 0x34, 0x80, 0x04, 0x08, 0xe0, 0x00, 0x00, 0x00, + 0xe0, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, 0x14, 0x01, 0x00, 0x00, 0x14, 0x81, 0x04, 0x08, + 0x14, 0x81, 0x04, 0x08, 0x13, 0x00, 0x00, 0x00, 0x13, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x04, 0x08, 0x00, 0x80, 0x04, 0x08, + 0x70, 0x04, 0x00, 0x00, 0x70, 0x04, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00, + 0x00, 0x10, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x70, 0x04, 0x00, 0x00, + 0x70, 0x94, 0x04, 0x08, 0x70, 0x94, 0x04, 0x08, 0x0c, 0x01, 0x00, 0x00, + 0x10, 0x01, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x84, 0x04, 0x00, 0x00, 0x84, 0x94, 0x04, 0x08, + 0x84, 0x94, 0x04, 0x08, 0xd0, 0x00, 0x00, 0x00, 0xd0, 0x00, 0x00, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x28, 0x01, 0x00, 0x00, 0x28, 0x81, 0x04, 0x08, 0x28, 0x81, 0x04, 0x08, + 0x20, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x51, 0xe5, 0x74, 0x64, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x2f, 0x6c, 0x69, 0x62, 0x2f, 0x6c, 0x64, 0x2d, 0x6c, 0x69, 0x6e, 0x75, + 0x78, 0x2e, 0x73, 0x6f, 0x2e, 0x32, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x10, 0x00, 
0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x47, 0x4e, 0x55, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00, + 0x00, 0x20, 0x00, 0x20, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0xad, 0x4b, 0xe3, 0xc0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, + 0x2e, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xb2, 0x01, 0x00, 0x00, + 0x12, 0x00, 0x00, 0x00, 0x29, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x8f, 0x01, 0x00, 0x00, 0x12, 0x00, 0x00, 0x00, 0x1a, 0x00, 0x00, 0x00, + 0x5c, 0x84, 0x04, 0x08, 0x04, 0x00, 0x00, 0x00, 0x11, 0x00, 0x0f, 0x00, + 0x00, 0x5f, 0x5f, 0x67, 0x6d, 0x6f, 0x6e, 0x5f, 0x73, 0x74, 0x61, 0x72, + 0x74, 0x5f, 0x5f, 0x00, 0x6c, 0x69, 0x62, 0x63, 0x2e, 0x73, 0x6f, 0x2e, + 0x36, 0x00, 0x5f, 0x49, 0x4f, 0x5f, 0x73, 0x74, 0x64, 0x69, 0x6e, 0x5f, + 0x75, 0x73, 0x65, 0x64, 0x00, 0x70, 0x75, 0x74, 0x73, 0x00, 0x5f, 0x5f, + 0x6c, 0x69, 0x62, 0x63, 0x5f, 0x73, 0x74, 0x61, 0x72, 0x74, 0x5f, 0x6d, + 0x61, 0x69, 0x6e, 0x00, 0x47, 0x4c, 0x49, 0x42, 0x43, 0x5f, 0x32, 0x2e, + 0x30, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x02, 0x00, 0x01, 0x00, + 0x01, 0x00, 0x01, 0x00, 0x10, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x10, 0x69, 0x69, 0x0d, 0x00, 0x00, 0x02, 0x00, + 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x54, 0x95, 0x04, 0x08, + 0x06, 0x01, 0x00, 0x00, 0x64, 0x95, 0x04, 0x08, 0x07, 0x01, 0x00, 0x00, + 0x68, 0x95, 0x04, 0x08, 0x07, 0x02, 0x00, 0x00, 0x6c, 0x95, 0x04, 0x08, + 0x07, 0x03, 
0x00, 0x00, 0x55, 0x89, 0xe5, 0x53, 0x83, 0xec, 0x04, 0xe8, + 0x00, 0x00, 0x00, 0x00, 0x5b, 0x81, 0xc3, 0xd8, 0x12, 0x00, 0x00, 0x8b, + 0x93, 0xfc, 0xff, 0xff, 0xff, 0x85, 0xd2, 0x74, 0x05, 0xe8, 0x1e, 0x00, + 0x00, 0x00, 0xe8, 0xb5, 0x00, 0x00, 0x00, 0xe8, 0x70, 0x01, 0x00, 0x00, + 0x58, 0x5b, 0xc9, 0xc3, 0xff, 0x35, 0x5c, 0x95, 0x04, 0x08, 0xff, 0x25, + 0x60, 0x95, 0x04, 0x08, 0x00, 0x00, 0x00, 0x00, 0xff, 0x25, 0x64, 0x95, + 0x04, 0x08, 0x68, 0x00, 0x00, 0x00, 0x00, 0xe9, 0xe0, 0xff, 0xff, 0xff, + 0xff, 0x25, 0x68, 0x95, 0x04, 0x08, 0x68, 0x08, 0x00, 0x00, 0x00, 0xe9, + 0xd0, 0xff, 0xff, 0xff, 0xff, 0x25, 0x6c, 0x95, 0x04, 0x08, 0x68, 0x10, + 0x00, 0x00, 0x00, 0xe9, 0xc0, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x31, 0xed, 0x5e, 0x89, + 0xe1, 0x83, 0xe4, 0xf0, 0x50, 0x54, 0x52, 0x68, 0xa0, 0x83, 0x04, 0x08, + 0x68, 0xb0, 0x83, 0x04, 0x08, 0x51, 0x56, 0x68, 0x74, 0x83, 0x04, 0x08, + 0xe8, 0xb3, 0xff, 0xff, 0xff, 0xf4, 0x90, 0x90, 0x90, 0x90, 0x90, 0x90, + 0x90, 0x90, 0x90, 0x90, 0x90, 0x90, 0x90, 0x90, 0x55, 0x89, 0xe5, 0x83, + 0xec, 0x08, 0x80, 0x3d, 0x7c, 0x95, 0x04, 0x08, 0x00, 0x74, 0x0c, 0xeb, + 0x1c, 0x83, 0xc0, 0x04, 0xa3, 0x78, 0x95, 0x04, 0x08, 0xff, 0xd2, 0xa1, + 0x78, 0x95, 0x04, 0x08, 0x8b, 0x10, 0x85, 0xd2, 0x75, 0xeb, 0xc6, 0x05, + 0x7c, 0x95, 0x04, 0x08, 0x01, 0xc9, 0xc3, 0x90, 0x55, 0x89, 0xe5, 0x83, + 0xec, 0x08, 0xa1, 0x80, 0x94, 0x04, 0x08, 0x85, 0xc0, 0x74, 0x12, 0xb8, + 0x00, 0x00, 0x00, 0x00, 0x85, 0xc0, 0x74, 0x09, 0xc7, 0x04, 0x24, 0x80, + 0x94, 0x04, 0x08, 0xff, 0xd0, 0xc9, 0xc3, 0x90, 0x8d, 0x4c, 0x24, 0x04, + 0x83, 0xe4, 0xf0, 0xff, 0x71, 0xfc, 0x55, 0x89, 0xe5, 0x51, 0x83, 0xec, + 0x04, 0xc7, 0x04, 0x24, 0x60, 0x84, 0x04, 0x08, 0xe8, 0x43, 0xff, 0xff, + 0xff, 0xb8, 0x00, 0x00, 0x00, 0x00, 0x83, 0xc4, 0x04, 0x59, 0x5d, 0x8d, + 0x61, 0xfc, 0xc3, 0x90, 0x55, 0x89, 0xe5, 0x5d, 0xc3, 0x8d, 0x74, 0x26, + 0x00, 0x8d, 0xbc, 0x27, 0x00, 0x00, 0x00, 0x00, 0x55, 0x89, 0xe5, 0x57, + 0x56, 0x53, 
0xe8, 0x4f, 0x00, 0x00, 0x00, 0x81, 0xc3, 0x9d, 0x11, 0x00, + 0x00, 0x83, 0xec, 0x0c, 0xe8, 0xab, 0xfe, 0xff, 0xff, 0x8d, 0xbb, 0x18, + 0xff, 0xff, 0xff, 0x8d, 0x83, 0x18, 0xff, 0xff, 0xff, 0x29, 0xc7, 0xc1, + 0xff, 0x02, 0x85, 0xff, 0x74, 0x24, 0x31, 0xf6, 0x8b, 0x45, 0x10, 0x89, + 0x44, 0x24, 0x08, 0x8b, 0x45, 0x0c, 0x89, 0x44, 0x24, 0x04, 0x8b, 0x45, + 0x08, 0x89, 0x04, 0x24, 0xff, 0x94, 0xb3, 0x18, 0xff, 0xff, 0xff, 0x83, + 0xc6, 0x01, 0x39, 0xf7, 0x75, 0xde, 0x83, 0xc4, 0x0c, 0x5b, 0x5e, 0x5f, + 0x5d, 0xc3, 0x8b, 0x1c, 0x24, 0xc3, 0x90, 0x90, 0x55, 0x89, 0xe5, 0x53, + 0x83, 0xec, 0x04, 0xa1, 0x70, 0x94, 0x04, 0x08, 0x83, 0xf8, 0xff, 0x74, + 0x12, 0x31, 0xdb, 0xff, 0xd0, 0x8b, 0x83, 0x6c, 0x94, 0x04, 0x08, 0x83, + 0xeb, 0x04, 0x83, 0xf8, 0xff, 0x75, 0xf0, 0x83, 0xc4, 0x04, 0x5b, 0x5d, + 0xc3, 0x90, 0x90, 0x90, 0x55, 0x89, 0xe5, 0x53, 0x83, 0xec, 0x04, 0xe8, + 0x00, 0x00, 0x00, 0x00, 0x5b, 0x81, 0xc3, 0x10, 0x11, 0x00, 0x00, 0xe8, + 0xcc, 0xfe, 0xff, 0xff, 0x59, 0x5b, 0xc9, 0xc3, 0x03, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x02, 0x00, 0x48, 0x69, 0x20, 0x57, 0x6f, 0x72, 0x6c, 0x64, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff, + 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x00, 0x00, 0x74, 0x82, 0x04, 0x08, 0x0d, 0x00, 0x00, 0x00, + 0x3c, 0x84, 0x04, 0x08, 0x04, 0x00, 0x00, 0x00, 0x48, 0x81, 0x04, 0x08, + 0xf5, 0xfe, 0xff, 0x6f, 0x70, 0x81, 0x04, 0x08, 0x05, 0x00, 0x00, 0x00, + 0xe0, 0x81, 0x04, 0x08, 0x06, 0x00, 0x00, 0x00, 0x90, 0x81, 0x04, 0x08, + 0x0a, 0x00, 0x00, 0x00, 0x4a, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x15, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, 0x58, 0x95, 0x04, 0x08, 0x02, 0x00, 0x00, 0x00, + 0x18, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x11, 0x00, 0x00, 0x00, + 0x17, 0x00, 0x00, 0x00, 0x5c, 0x82, 0x04, 0x08, 0x11, 0x00, 0x00, 0x00, + 0x54, 0x82, 
0x04, 0x08, 0x12, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, + 0x13, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0xfe, 0xff, 0xff, 0x6f, + 0x34, 0x82, 0x04, 0x08, 0xff, 0xff, 0xff, 0x6f, 0x01, 0x00, 0x00, 0x00, + 0xf0, 0xff, 0xff, 0x6f, 0x2a, 0x82, 0x04, 0x08, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x84, 0x94, 0x04, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0xba, 0x82, 0x04, 0x08, 0xca, 0x82, 0x04, 0x08, 0xda, 0x82, 0x04, 0x08, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x7c, 0x94, 0x04, 0x08, + 0x00, 0x47, 0x43, 0x43, 0x3a, 0x20, 0x28, 0x47, 0x4e, 0x55, 0x29, 0x20, + 0x34, 0x2e, 0x32, 0x2e, 0x33, 0x20, 0x28, 0x55, 0x62, 0x75, 0x6e, 0x74, + 0x75, 0x20, 0x34, 0x2e, 0x32, 0x2e, 0x33, 0x2d, 0x32, 0x75, 0x62, 0x75, + 0x6e, 0x74, 0x75, 0x37, 0x29, 0x00, 0x00, 0x47, 0x43, 0x43, 0x3a, 0x20, + 0x28, 0x47, 0x4e, 0x55, 0x29, 0x20, 0x34, 0x2e, 0x32, 0x2e, 0x33, 0x20, + 0x28, 0x55, 0x62, 0x75, 0x6e, 0x74, 0x75, 0x20, 0x34, 0x2e, 0x32, 0x2e, + 0x33, 0x2d, 0x32, 0x75, 0x62, 0x75, 0x6e, 0x74, 0x75, 0x37, 0x29, 0x00, + 0x00, 0x47, 0x43, 0x43, 0x3a, 0x20, 0x28, 0x47, 0x4e, 0x55, 0x29, 0x20, + 0x34, 0x2e, 0x32, 0x2e, 0x33, 0x20, 0x28, 0x55, 0x62, 0x75, 0x6e, 0x74, + 0x75, 0x20, 0x34, 0x2e, 0x32, 0x2e, 0x33, 0x2d, 0x32, 0x75, 0x62, 0x75, + 0x6e, 0x74, 0x75, 0x37, 0x29, 0x00, 0x00, 0x47, 0x43, 0x43, 0x3a, 0x20, + 0x28, 0x47, 0x4e, 0x55, 0x29, 0x20, 0x34, 0x2e, 0x32, 0x2e, 0x33, 0x20, + 0x28, 0x55, 0x62, 0x75, 0x6e, 0x74, 0x75, 0x20, 0x34, 0x2e, 0x32, 0x2e, + 0x33, 0x2d, 0x32, 0x75, 0x62, 0x75, 0x6e, 0x74, 0x75, 0x37, 0x29, 0x00, + 0x00, 0x47, 0x43, 0x43, 0x3a, 0x20, 0x28, 0x47, 0x4e, 0x55, 0x29, 0x20, + 0x34, 0x2e, 0x32, 0x2e, 0x33, 0x20, 0x28, 0x55, 0x62, 0x75, 0x6e, 0x74, + 0x75, 0x20, 
0x34, 0x2e, 0x32, 0x2e, 0x33, 0x2d, 0x32, 0x75, 0x62, 0x75, + 0x6e, 0x74, 0x75, 0x37, 0x29, 0x00, 0x00, 0x47, 0x43, 0x43, 0x3a, 0x20, + 0x28, 0x47, 0x4e, 0x55, 0x29, 0x20, 0x34, 0x2e, 0x32, 0x2e, 0x33, 0x20, + 0x28, 0x55, 0x62, 0x75, 0x6e, 0x74, 0x75, 0x20, 0x34, 0x2e, 0x32, 0x2e, + 0x33, 0x2d, 0x32, 0x75, 0x62, 0x75, 0x6e, 0x74, 0x75, 0x37, 0x29, 0x00, + 0x00, 0x47, 0x43, 0x43, 0x3a, 0x20, 0x28, 0x47, 0x4e, 0x55, 0x29, 0x20, + 0x34, 0x2e, 0x32, 0x2e, 0x33, 0x20, 0x28, 0x55, 0x62, 0x75, 0x6e, 0x74, + 0x75, 0x20, 0x34, 0x2e, 0x32, 0x2e, 0x33, 0x2d, 0x32, 0x75, 0x62, 0x75, + 0x6e, 0x74, 0x75, 0x37, 0x29, 0x00, 0x00, 0x2e, 0x73, 0x68, 0x73, 0x74, + 0x72, 0x74, 0x61, 0x62, 0x00, 0x2e, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x70, + 0x00, 0x2e, 0x6e, 0x6f, 0x74, 0x65, 0x2e, 0x41, 0x42, 0x49, 0x2d, 0x74, + 0x61, 0x67, 0x00, 0x2e, 0x67, 0x6e, 0x75, 0x2e, 0x68, 0x61, 0x73, 0x68, + 0x00, 0x2e, 0x64, 0x79, 0x6e, 0x73, 0x79, 0x6d, 0x00, 0x2e, 0x64, 0x79, + 0x6e, 0x73, 0x74, 0x72, 0x00, 0x2e, 0x67, 0x6e, 0x75, 0x2e, 0x76, 0x65, + 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x00, 0x2e, 0x67, 0x6e, 0x75, 0x2e, 0x76, + 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x5f, 0x72, 0x00, 0x2e, 0x72, 0x65, + 0x6c, 0x2e, 0x64, 0x79, 0x6e, 0x00, 0x2e, 0x72, 0x65, 0x6c, 0x2e, 0x70, + 0x6c, 0x74, 0x00, 0x2e, 0x69, 0x6e, 0x69, 0x74, 0x00, 0x2e, 0x74, 0x65, + 0x78, 0x74, 0x00, 0x2e, 0x66, 0x69, 0x6e, 0x69, 0x00, 0x2e, 0x72, 0x6f, + 0x64, 0x61, 0x74, 0x61, 0x00, 0x2e, 0x65, 0x68, 0x5f, 0x66, 0x72, 0x61, + 0x6d, 0x65, 0x00, 0x2e, 0x63, 0x74, 0x6f, 0x72, 0x73, 0x00, 0x2e, 0x64, + 0x74, 0x6f, 0x72, 0x73, 0x00, 0x2e, 0x6a, 0x63, 0x72, 0x00, 0x2e, 0x64, + 0x79, 0x6e, 0x61, 0x6d, 0x69, 0x63, 0x00, 0x2e, 0x67, 0x6f, 0x74, 0x00, + 0x2e, 0x67, 0x6f, 0x74, 0x2e, 0x70, 0x6c, 0x74, 0x00, 0x2e, 0x64, 0x61, + 0x74, 0x61, 0x00, 0x2e, 0x62, 0x73, 0x73, 0x00, 0x2e, 0x63, 0x6f, 0x6d, + 0x6d, 0x65, 0x6e, 0x74, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x0b, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x14, 0x81, 0x04, 0x08, 0x14, 0x01, 0x00, 0x00, 0x13, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x13, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x28, 0x81, 0x04, 0x08, 0x28, 0x01, 0x00, 0x00, + 0x20, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x25, 0x00, 0x00, 0x00, + 0x05, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x48, 0x81, 0x04, 0x08, + 0x48, 0x01, 0x00, 0x00, 0x28, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x21, 0x00, 0x00, 0x00, 0xf6, 0xff, 0xff, 0x6f, 0x02, 0x00, 0x00, 0x00, + 0x70, 0x81, 0x04, 0x08, 0x70, 0x01, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, + 0x05, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x2b, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x90, 0x81, 0x04, 0x08, 0x90, 0x01, 0x00, 0x00, + 0x50, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x33, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0xe0, 0x81, 0x04, 0x08, + 0xe0, 0x01, 0x00, 0x00, 0x4a, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x3b, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0x6f, 0x02, 0x00, 0x00, 0x00, + 0x2a, 0x82, 0x04, 0x08, 0x2a, 0x02, 0x00, 0x00, 0x0a, 0x00, 0x00, 0x00, + 0x05, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x48, 0x00, 0x00, 0x00, 0xfe, 0xff, 0xff, 0x6f, + 0x02, 0x00, 0x00, 0x00, 0x34, 0x82, 0x04, 0x08, 0x34, 0x02, 0x00, 0x00, + 0x20, 0x00, 
0x00, 0x00, 0x06, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x57, 0x00, 0x00, 0x00, + 0x09, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x54, 0x82, 0x04, 0x08, + 0x54, 0x02, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, + 0x60, 0x00, 0x00, 0x00, 0x09, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x5c, 0x82, 0x04, 0x08, 0x5c, 0x02, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, + 0x05, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, 0x69, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x74, 0x82, 0x04, 0x08, 0x74, 0x02, 0x00, 0x00, + 0x30, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x64, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, 0xa4, 0x82, 0x04, 0x08, + 0xa4, 0x02, 0x00, 0x00, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x6f, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, + 0xf0, 0x82, 0x04, 0x08, 0xf0, 0x02, 0x00, 0x00, 0x4c, 0x01, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x75, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x3c, 0x84, 0x04, 0x08, 0x3c, 0x04, 0x00, 0x00, + 0x1c, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x7b, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x58, 0x84, 0x04, 0x08, + 0x58, 0x04, 0x00, 0x00, 0x11, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x83, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x6c, 0x84, 0x04, 0x08, 0x6c, 0x04, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x00, 0x00, 
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x8d, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, 0x70, 0x94, 0x04, 0x08, 0x70, 0x04, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x94, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x78, 0x94, 0x04, 0x08, + 0x78, 0x04, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x9b, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, + 0x80, 0x94, 0x04, 0x08, 0x80, 0x04, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0xa0, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, 0x84, 0x94, 0x04, 0x08, 0x84, 0x04, 0x00, 0x00, + 0xd0, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0xa9, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x54, 0x95, 0x04, 0x08, + 0x54, 0x05, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0xae, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, + 0x58, 0x95, 0x04, 0x08, 0x58, 0x05, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0xb7, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, 0x70, 0x95, 0x04, 0x08, 0x70, 0x05, 0x00, 0x00, + 0x0c, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xbd, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x7c, 0x95, 0x04, 0x08, + 0x7c, 0x05, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 
0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0xc2, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x7c, 0x05, 0x00, 0x00, 0x26, 0x01, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xa2, 0x06, 0x00, 0x00, + 0xcb, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, +} diff --git a/embedbin/create.go b/embedbin/create.go new file mode 100644 index 00000000..e144741f --- /dev/null +++ b/embedbin/create.go @@ -0,0 +1,87 @@ +// Copyright 2019 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. + +package embedbin + +import ( + "archive/zip" + "io" + "os" +) + +// WriteOpt is an option to NewWriter. +type WriteOpt func(*Writer) + +// Deflate compresses embedded files. +var Deflate WriteOpt = func(w *Writer) { + w.embedMethod = zip.Deflate +} + +// Writer is used to append embedbin files to an existing binary. +type Writer struct { + w io.Writer + + embedOffset int64 + embedZ *zip.Writer + embedMethod uint16 // no compression by default +} + +// NewFileWriter returns a writer that can be used to append embedbin +// files to the binary represented by the provided file. +// NewFileWriter removes any existing embedbin files that may be +// attached to the binary. It relies on content sniffing (see Sniff) +// to determine its offset. 
+func NewFileWriter(file *os.File) (*Writer, error) { + info, err := file.Stat() + if err != nil { + return nil, err + } + embedOffset, err := Sniff(file, info.Size()) + if err != nil { + return nil, err + } + if err = file.Truncate(embedOffset); err != nil { + return nil, err + } + _, err = file.Seek(0, io.SeekEnd) + if err != nil { + return nil, err + } + return NewWriter(file, embedOffset), nil +} + +// NewWriter returns a writer that may be used to append embedbin +// files to the writer w. The writer should be positioned at the end +// of the base binary image. +func NewWriter(w io.Writer, embedOffset int64, opts ...WriteOpt) *Writer { + ew := Writer{w: w, embedOffset: embedOffset, embedZ: zip.NewWriter(w)} + for _, opt := range opts { + opt(&ew) + } + return &ew +} + +// Create returns a Writer into which the named file should be written. +// The image's contents must be written before the next call to Create or Close. +func (w *Writer) Create(name string) (io.Writer, error) { + return w.embedZ.CreateHeader(&zip.FileHeader{ + Name: name, + Method: w.embedMethod, + }) +} + +// Flush flushes the unwritten data to the underlying file. +func (w *Writer) Flush() error { + return w.embedZ.Flush() +} + +// Close should be called after all embedded files have been written. +// No more files can be written after a call to Close. +func (w *Writer) Close() error { + if err := w.embedZ.Close(); err != nil { + return err + } + _, err := writeFooter(w.w, w.embedOffset) + return err +} diff --git a/embedbin/embedbin.go b/embedbin/embedbin.go new file mode 100644 index 00000000..11963341 --- /dev/null +++ b/embedbin/embedbin.go @@ -0,0 +1,171 @@ +// Copyright 2019 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. 
+ +package embedbin + +import ( + "archive/zip" + "errors" + "fmt" + "io" + "io/ioutil" + "os" + "sync" +) + +var ( + selfOnce sync.Once + self *Reader + selfErr error +) + +var ( + // ErrNoSuchFile is returned when the embedbin does not contain an + // embedded file with the requested name. + ErrNoSuchFile = errors.New("embedded file does not exist") + // ErrCorruptedImage is returned when the embedbin image has been + // corrupted. + ErrCorruptedImage = errors.New("corrupted embedbin image") +) + +// Info provides information for an embedded file. +type Info struct { + Name string + Size int64 +} + +func (info Info) String() string { + return fmt.Sprintf("%s: %d", info.Name, info.Size) +} + +// Reader reads images from an embedbin. +type Reader struct { + base io.ReaderAt + + embedOffset int64 + embedZ *zip.Reader +} + +// Self reads the currently executing binary image as an embedbin and +// returns a reader to it. +func Self() (*Reader, error) { + selfOnce.Do(func() { + filename, err := os.Executable() + if err != nil { + selfErr = err + return + } + f, err := os.Open(filename) + if err != nil { + selfErr = err + return + } + info, err := f.Stat() + if err != nil { + selfErr = err + return + } + embedOffset, err := Sniff(f, info.Size()) + if err != nil { + selfErr = err + return + } + self, selfErr = NewReader(f, embedOffset, info.Size()) + }) + return self, selfErr +} + +// OpenFile parses the provided ReaderAt with the provided size. The +// file's contents are parsed to determine the offset of the embedbin's +// archive. OpenFile returns an error if the file is not an embedbin. +func OpenFile(r io.ReaderAt, size int64) (*Reader, error) { + offset, err := Sniff(r, size) + if err != nil { + return nil, err + } + return NewReader(r, offset, size) +} + +// NewReader returns a new embedbin reader from the provided reader. 
+func NewReader(r io.ReaderAt, embedOffset, totalSize int64) (*Reader, error) { + rd := &Reader{ + base: io.NewSectionReader(r, 0, embedOffset), + embedOffset: embedOffset, + } + if embedOffset == totalSize { + return rd, nil + } + var err error + rd.embedZ, err = zip.NewReader(io.NewSectionReader(r, embedOffset, totalSize-embedOffset), totalSize-embedOffset) + if err != nil { + return nil, err + } + return rd, nil +} + +// List returns information about embedded files. +func (r *Reader) List() []Info { + if r.embedZ == nil { + return nil + } + infos := make([]Info, len(r.embedZ.File)) + for i, f := range r.embedZ.File { + infos[i] = Info{ + Name: f.Name, + Size: int64(f.UncompressedSize64), + } + } + return infos +} + +// OpenBase returns a ReadCloser for the original executable, without appended +// embedded files. +func (r *Reader) OpenBase() (io.ReadCloser, error) { + return ioutil.NopCloser(io.NewSectionReader(r.base, 0, 1<<63-1)), nil +} + +// Open returns a ReadCloser for the named embedded file. +// Open returns ErrNoSuchFile if the embedbin does not contain the file. +func (r *Reader) Open(name string) (io.ReadCloser, error) { + if r.embedZ == nil { + return nil, ErrNoSuchFile + } + for _, f := range r.embedZ.File { + if f.Name == name { + return f.Open() + } + } + return nil, ErrNoSuchFile +} + +// StatBase returns the information for the base image. +func (r *Reader) StatBase() Info { + return Info{Size: r.embedOffset} +} + +// Stat returns the information for the named embedded file. +// It returns a boolean indicating whether the requested file was found. +func (r *Reader) Stat(name string) (info Info, ok bool) { + info.Name = name + for _, f := range r.embedZ.File { + if f.Name == name { + info.Size = int64(f.UncompressedSize64) + ok = true + return + } + } + return +} + +// Sniff sniffs a binary's embedbin offset. 
Sniff returns errors +// returned by the provided reader, or ErrCorruptedImage if the binary is identified +// as an embedbin image with a checksum mismatch. +func Sniff(r io.ReaderAt, size int64) (offset int64, err error) { + offset, err = readFooter(r, size) + if err == errNoFooter { + err = nil + offset = size + } + return +} diff --git a/embedbin/embedbin_test.go b/embedbin/embedbin_test.go new file mode 100644 index 00000000..9b8746ed --- /dev/null +++ b/embedbin/embedbin_test.go @@ -0,0 +1,157 @@ +// Copyright 2019 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. + +package embedbin + +import ( + "bytes" + "io" + "io/ioutil" + "os" + "testing" +) + +func TestEmbedbin(t *testing.T) { + filename, err := os.Executable() + if err != nil { + t.Fatal(err) + } + body, err := ioutil.ReadFile(filename) + if err != nil { + t.Fatal(err) + } + + self, err := Self() + if err != nil { + t.Fatal(err) + } + r, err := self.OpenBase() + if err != nil { + t.Fatal(err) + } + defer r.Close() + embedded, err := ioutil.ReadAll(r) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(body, embedded) { + t.Error("content mismatch") + } +} + +func TestEmbedbinNonExist(t *testing.T) { + self, err := Self() + if err != nil { + t.Fatal(err) + } + _, err = self.Open("nonexistent") + if got, want := err, ErrNoSuchFile; got != want { + t.Errorf("got %v, want %v", got, want) + } +} + +func TestSniff(t *testing.T) { + filename, err := os.Executable() + if err != nil { + t.Fatal(err) + } + f, err := os.Open(filename) + if err != nil { + t.Fatal(err) + } + defer f.Close() + info, err := f.Stat() + if err != nil { + t.Fatal(err) + } + + size, err := Sniff(f, info.Size()) + if err != nil { + t.Fatal(err) + } + if got, want := size, info.Size(); got != want { + t.Errorf("got %v, want %v", got, want) + } +} + +func TestCreate(t *testing.T) { + f, err := ioutil.TempFile("", "") + must(t, err) + _, err = 
f.Write(svelteLinuxElfBinary) + must(t, err) + w, err := NewFileWriter(f) + must(t, err) + dw, err := w.Create("darwin/amd64") + must(t, err) + _, err = dw.Write([]byte("darwin/amd64")) + must(t, err) + dw, err = w.Create("darwin/386") + must(t, err) + _, err = dw.Write([]byte("darwin/386")) + must(t, err) + must(t, w.Close()) + info, err := f.Stat() + must(t, err) + r, err := OpenFile(f, info.Size()) + must(t, err) + + cases := []struct { + base bool + name string + body []byte + }{ + {base: true, body: svelteLinuxElfBinary}, + {name: "darwin/amd64", body: []byte("darwin/amd64")}, + {name: "darwin/386", body: []byte("darwin/386")}, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + var rc io.ReadCloser + if c.base { + rc, err = r.OpenBase() + } else { + rc, err = r.Open(c.name) + } + if err != nil { + t.Fatal(err) + } + mustBytes(t, rc, c.body) + must(t, rc.Close()) + if c.base { + return + } + info, ok := r.Stat(c.name) + if !ok { + t.Errorf("%s/%t: not found", c.name, c.base) + return + } + if got, want := info.Size, int64(len(c.body)); got != want { + t.Errorf("%s: got %v, want %v", c.name, got, want) + } + }) + } + + _, err = r.Open("nonexistent") + if got, want := err, ErrNoSuchFile; got != want { + t.Errorf("got %v, want %v", got, want) + } +} + +func must(t *testing.T, err error) { + t.Helper() + if err != nil { + t.Fatal(err) + } +} + +func mustBytes(t *testing.T, r io.Reader, want []byte) { + t.Helper() + got, err := ioutil.ReadAll(r) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(got, want) { + t.Errorf("got %s, want %s", got, want) + } +} diff --git a/embedbin/footer.go b/embedbin/footer.go new file mode 100644 index 00000000..2ed69962 --- /dev/null +++ b/embedbin/footer.go @@ -0,0 +1,49 @@ +// Copyright 2019 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. 
+ +package embedbin + +import ( + "encoding/binary" + "errors" + "io" + + "github.com/cespare/xxhash" +) + +const magic uint32 = 0xec90a479 +const headersz = 20 + +var ( + errNoFooter = errors.New("binary contains no footer") + + bin = binary.LittleEndian +) + +func writeFooter(w io.Writer, offset int64) (int, error) { + var p [headersz]byte + bin.PutUint64(p[:8], uint64(offset)) + bin.PutUint32(p[8:12], magic) + bin.PutUint64(p[12:20], xxhash.Sum64(p[:12])) + return w.Write(p[:]) +} + +func readFooter(r io.ReaderAt, size int64) (offset int64, err error) { + if size < headersz { + return 0, errNoFooter + } + var p [headersz]byte + _, err = r.ReadAt(p[:], size-headersz) + if err != nil { + return 0, err + } + if bin.Uint32(p[8:12]) != magic { + return 0, errNoFooter + } + offset = int64(bin.Uint64(p[:8])) + if xxhash.Sum64(p[:12]) != bin.Uint64(p[12:20]) { + return 0, ErrCorruptedImage + } + return +} diff --git a/embedbin/footer_test.go b/embedbin/footer_test.go new file mode 100644 index 00000000..b4e4a8b8 --- /dev/null +++ b/embedbin/footer_test.go @@ -0,0 +1,79 @@ +// Copyright 2019 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. 
+ +package embedbin + +import ( + "bytes" + "io" + "testing" +) + +func TestReadWriteFooter(t *testing.T) { + for _, sz := range []int64{0, 12, 1e12, 1e13 + 4} { + var b bytes.Buffer + if _, err := writeFooter(&b, sz); err != nil { + t.Error(err) + continue + } + off, err := readFooter(bytes.NewReader(b.Bytes()), int64(b.Len())) + if err != nil { + t.Error(err) + continue + } + if got, want := off, sz; got != want { + t.Errorf("got %v, want %v", got, want) + } + + padded := paddedReaderAt{bytes.NewReader(b.Bytes()), int64(sz) * 100} + off, err = readFooter(padded, int64(sz)*100+int64(b.Len())) + if err != nil { + t.Error(err) + continue + } + if got, want := off, sz; got != want { + t.Errorf("got %v, want %v", got, want) + } + } +} + +func TestCorruptedFooter(t *testing.T) { + var b bytes.Buffer + if _, err := writeFooter(&b, 1234); err != nil { + t.Fatal(err) + } + n := b.Len() + for i := 0; i < n; i++ { + if i >= n-12 && i < n-8 { + continue //skip magic + } + p := make([]byte, b.Len()) + copy(p, b.Bytes()) + p[i]++ + _, err := readFooter(bytes.NewReader(p), int64(len(p))) + if got, want := err, ErrCorruptedImage; got != want { + t.Errorf("got %v, want %v", got, want) + } + } +} + +type paddedReaderAt struct { + io.ReaderAt + N int64 +} + +func (r paddedReaderAt) ReadAt(p []byte, off int64) (n int, err error) { + off -= r.N + for i := range p { + p[i] = 0 + } + switch { + case off < -int64(len(p)): + return len(p), nil + case off < 0: + p = p[-off:] + off = 0 + } + return r.ReaderAt.ReadAt(p, off) +} diff --git a/errors/clean_up.go b/errors/clean_up.go new file mode 100644 index 00000000..75802b40 --- /dev/null +++ b/errors/clean_up.go @@ -0,0 +1,41 @@ +package errors + +import ( + "context" + "fmt" +) + +// CleanUp is defer-able syntactic sugar that calls f and reports an error, if any, +// to *err. Pass the caller's named return error. 
Example usage: +// +// func processFile(filename string) (_ int, err error) { +// f, err := os.Open(filename) +// if err != nil { ... } +// defer errors.CleanUp(f.Close, &err) +// ... +// } +// +// If the caller returns with its own error, any error from cleanUp will be chained. +func CleanUp(cleanUp func() error, dst *error) { + addErr(cleanUp(), dst) +} + +// CleanUpCtx is CleanUp for a context-ful cleanUp. +func CleanUpCtx(ctx context.Context, cleanUp func(context.Context) error, dst *error) { + addErr(cleanUp(ctx), dst) +} + +func addErr(err2 error, dst *error) { + if err2 == nil { + return + } + if *dst == nil { + *dst = err2 + return + } + // Note: We don't chain err2 as *dst's cause because *dst may already have a meaningful cause. + // Also, even if *dst didn't, err2 may be something entirely different, and suggesting it's + // the cause could be misleading. + // TODO: Consider using a standardized multiple-errors representation like sync/multierror's. + *dst = E(*dst, fmt.Sprintf("second error in Close: %v", err2)) +} diff --git a/errors/clean_up_test.go b/errors/clean_up_test.go new file mode 100644 index 00000000..b2527c87 --- /dev/null +++ b/errors/clean_up_test.go @@ -0,0 +1,62 @@ +package errors + +import ( + "context" + "errors" + "strconv" + "testing" + + "github.com/stretchr/testify/assert" +) + +type errCallable struct{ error } + +func (e errCallable) Func() error { return e.error } +func (e errCallable) FuncCtx(context.Context) error { return e.error } + +func TestCleanUp(t *testing.T) { + const ( + closeMsg = "close [seuozr]" + returnMsg = "return [mntbnb]" + ) + + for callIdx, call := range []func(errCallable, *error){ + func(e errCallable, err *error) { CleanUp(e.Func, err) }, + func(e errCallable, err *error) { CleanUpCtx(context.Background(), e.FuncCtx, err) }, + } { + t.Run(strconv.Itoa(callIdx), func(t *testing.T) { + // No return error, no close error. 
+ gotErr := func() (err error) { + e := errCallable{} + defer call(e, &err) + return nil + }() + assert.NoError(t, gotErr) + + // No return error, close error. + gotErr = func() (err error) { + e := errCallable{errors.New(closeMsg)} + defer call(e, &err) + return nil + }() + assert.Equal(t, gotErr.Error(), closeMsg) + + // Return error, no close error. + gotErr = func() (err error) { + e := errCallable{} + defer call(e, &err) + return errors.New(returnMsg) + }() + assert.Equal(t, gotErr.Error(), returnMsg) + + // Return error, close error. + gotErr = func() (err error) { + e := errCallable{errors.New(closeMsg)} + defer call(e, &err) + return errors.New(returnMsg) + }() + assert.Contains(t, gotErr.Error(), returnMsg) + assert.Contains(t, gotErr.Error(), closeMsg) + }) + } +} diff --git a/errors/errors.go b/errors/errors.go index 958767f8..e737b7b0 100644 --- a/errors/errors.go +++ b/errors/errors.go @@ -24,8 +24,11 @@ import ( "fmt" "os" "runtime" + "strings" + "syscall" "github.com/grailbio/base/log" + "v.io/v23/verror" ) func init() { @@ -68,24 +71,45 @@ const ( TooManyTries // Precondition indicates that a precondition was not met. Precondition + // OOM indicates that an OOM condition was encountered. + OOM + // Remote indicates an error returned by an RPC, as distinct from errors in + // the machinery to execute the RPC, e.g. network issues, machine health, + // etc. + Remote + // ResourcesExhausted indicates that there were insufficient resources. 
+ ResourcesExhausted maxKind ) var kinds = map[Kind]string{ - Other: "unknown error", - Canceled: "operation was canceled", - Timeout: "operation timed out", - NotExist: "resource does not exist", - NotAllowed: "access denied", - NotSupported: "operation not supported", - Exists: "resource already exists", - Integrity: "integrity error", - Unavailable: "resource unavailable", - Invalid: "invalid argument", - Net: "network error", - TooManyTries: "too many tries", - Precondition: "precondition failed", + Other: "unknown error", + Canceled: "operation was canceled", + Timeout: "operation timed out", + NotExist: "resource does not exist", + NotAllowed: "access denied", + NotSupported: "operation not supported", + Exists: "resource already exists", + Integrity: "integrity error", + Unavailable: "resource unavailable", + Invalid: "invalid argument", + Net: "network error", + TooManyTries: "too many tries", + Precondition: "precondition failed", + OOM: "out of memory", + Remote: "remote error", + ResourcesExhausted: "resources exhausted", +} + +// kindStdErrs maps some Kinds to the standard library's equivalent. +var kindStdErrs = map[Kind]error{ + Canceled: context.Canceled, + Timeout: context.DeadlineExceeded, + NotExist: os.ErrNotExist, + NotAllowed: os.ErrPermission, + Exists: os.ErrExist, + Invalid: os.ErrInvalid, } // String returns a human-readable explanation of the error kind k. @@ -93,6 +117,33 @@ func (k Kind) String() string { return kinds[k] } +var kindErrnos = map[Kind]syscall.Errno{ + Canceled: syscall.EINTR, + Timeout: syscall.ETIMEDOUT, + NotExist: syscall.ENOENT, + NotAllowed: syscall.EACCES, + // We map to ENOTSUP instead of ENOSYS, as ENOTSUP is more granular, + // signifying that there may be configurations of a functionality that may + // be supported. ENOSYS, in contrast, signals that an entire + // function(ality) is not supported. If we need to express the distinction + // in the future, we can add a new kind. 
+ NotSupported: syscall.ENOTSUP, + Exists: syscall.EEXIST, + Unavailable: syscall.EAGAIN, + Invalid: syscall.EINVAL, + Net: syscall.ENETUNREACH, + TooManyTries: syscall.EINVAL, + Precondition: syscall.EAGAIN, + OOM: syscall.ENOMEM, + Remote: syscall.EREMOTE, +} + +// Errno maps k to an equivalent Errno or returns false if there's no good match. +func (k Kind) Errno() (syscall.Errno, bool) { + errno, ok := kindErrnos[k] + return errno, ok +} + // Severity defines an Error's severity. An Error's severity determines // whether an error-producing operation may be retried or not. type Severity int @@ -153,7 +204,8 @@ type Error struct { // // - Kind: sets the Error's kind // - Severity: set the Error's severity -// - string: sets the Error's message +// - string: sets the Error's message; multiple strings are +// separated by a single space // - *Error: copies the error and sets the error's cause // - error: sets the Error's cause // @@ -179,6 +231,7 @@ func E(args ...interface{}) error { panic("no args") } e := new(Error) + var msg strings.Builder for _, arg := range args { switch arg := arg.(type) { case Kind: @@ -186,7 +239,10 @@ func E(args ...interface{}) error { case Severity: e.Severity = arg case string: - e.Message = arg + if msg.Len() > 0 { + msg.WriteString(" ") + } + msg.WriteString(arg) case *Error: copy := *arg if len(args) == 1 { @@ -206,6 +262,7 @@ func E(args ...interface{}) error { } } } + e.Message = msg.String() if e.Err == nil { return e } @@ -231,19 +288,46 @@ func E(args ...interface{}) error { if e.Kind != Other { break } - if os.IsNotExist(e.Err) { - e.Kind = NotExist - } else if e.Err == context.Canceled { - e.Kind = Canceled - } else if err, ok := e.Err.(interface { - Timeout() bool - }); ok && err.Timeout() { + // Note: Loop over kind instead of kindStdErrs for determinism. 
+ for kind := Kind(0); kind < maxKind; kind++ { + stdErr := kindStdErrs[kind] + if stdErr != nil && errors.Is(e.Err, stdErr) { + e.Kind = kind + break + } + } + if e.Kind != Other { + break + } + // Interpret verror errors. + if err, ok := asVerrorE(e.Err); ok { + // TODO: Kill this workaround for chained ErrNoAccess errors. See + // https://github.com/vanadium/core/pull/282 . Once we upgrade + // verror to a version whose error.Is supports matching of chained + // errors, we can kill the string check. + // + // Separately, we can consider expanding to map more verror error + // types. + if errors.Is(err, verror.ErrNoAccess) || + strings.Contains(err.Error(), string(verror.ErrNoAccess.ID)) { + e.Kind = NotAllowed + } + } + if e.Kind != Other { + break + } + if isTimeoutErr(e.Err) { e.Kind = Timeout } } return e } +func isTimeoutErr(err error) bool { + t, ok := err.(interface{ Timeout() bool }) + return ok && t.Timeout() +} + // Recover recovers any error into an *Error. If the passed-in Error is already // an error, it is simply returned; otherwise it is wrapped in an error. func Recover(err error) *Error { @@ -305,6 +389,36 @@ func (e *Error) Temporary() bool { return e.Severity <= Temporary } +// Unwrap returns e's cause, if any, or nil. It lets the standard library's +// errors.Unwrap work with *Error. +func (e *Error) Unwrap() error { + return e.Err +} + +// Is tells whether e.Kind is equivalent to err. +// +// This implements interoperability with the standard library's errors.Is: +// errors.Is(e, errors.Canceled) +// works if e.Kind corresponds (in this example, Canceled). This is useful when +// passing *Error to third-party libraries, for example. Users should still +// prefer this package's Is for their own tests because it's less prone to error +// (type checking disallows accidentally swapped arguments). +// +// Note: This match does not recurse into err's cause, if any; see the standard +// library's errors.Is for how this is used. 
+func (e *Error) Is(err error) bool { + if err == nil { + return false + } + if err == kindStdErrs[e.Kind] { + return true + } + if e.Kind == Timeout && isTimeoutErr(err) { + return true + } + return false +} + type gobError struct { Kind Kind Severity Severity @@ -367,6 +481,12 @@ func (e *Error) GobDecode(p []byte) error { // Is tells whether an error has a specified kind, except for the // indeterminate kind Other. In the case an error has kind Other, the // chain is traversed until a non-Other error is encountered. +// +// This is similar to the standard library's errors.Is(err, target). That +// traverses err's chain looking for one that matches target, where target may +// be os.ErrNotExist, etc. *Error has an explicit Kind instead of an error-typed +// target, but (*Error).Is defines an error -> Kind relationship to allow +// interoperability. func Is(kind Kind, err error) bool { if err == nil { return false @@ -423,6 +543,21 @@ func Match(err1, err2 error) bool { return true } +// Visit calls the given function for every error object in the chain, including +// itself. Recursion stops after the function finds an error object of type +// other than *Error. +func Visit(err error, callback func(err error)) { + callback(err) + for { + next, ok := err.(*Error) + if !ok { + break + } + err = next.Err + callback(err) + } +} + // New is synonymous with errors.New, and is provided here so that // users need only import one errors package. 
func New(msg string) error { @@ -435,3 +570,15 @@ func pad(b *bytes.Buffer, s string) { } b.WriteString(s) } + +func asVerrorE(err error) (verror.E, bool) { + switch e := err.(type) { + case verror.E: + return e, true + case *verror.E: + if e != nil { + return *e, true + } + } + return verror.E{}, false +} diff --git a/errors/errors_test.go b/errors/errors_test.go index fdb1e6d5..579a3509 100644 --- a/errors/errors_test.go +++ b/errors/errors_test.go @@ -11,10 +11,14 @@ import ( goerrors "errors" "fmt" "os" + "strconv" "testing" + "time" fuzz "github.com/google/gofuzz" "github.com/grailbio/base/errors" + "github.com/grailbio/base/vcontext" + "v.io/v23/verror" ) // generate random errors and test encoding, etc. (fuzz) @@ -127,3 +131,139 @@ func TestGobEncodingFuzz(t *testing.T) { } } } + +func TestMessage(t *testing.T) { + for _, c := range []struct { + err error + message string + }{ + {errors.E("hello"), "hello"}, + {errors.E("hello", "world"), "hello world"}, + } { + if got, want := c.err.Error(), c.message; got != want { + t.Errorf("got %v, want %v", got, want) + } + } +} + +func TestStdInterop(t *testing.T) { + tests := []struct { + name string + makeErr func() (cleanUp func(), _ error) + kind errors.Kind + target error + }{ + { + "not exist", + func() (cleanUp func(), _ error) { + _, err := os.Open("/dev/notexist") + return func() {}, err + }, + errors.NotExist, + os.ErrNotExist, + }, + { + "canceled", + func() (cleanUp func(), _ error) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + <-ctx.Done() + return func() {}, ctx.Err() + }, + errors.Canceled, + context.Canceled, + }, + { + "timeout", + func() (cleanUp func(), _ error) { + ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(-time.Minute)) + <-ctx.Done() + return cancel, ctx.Err() + }, + errors.Timeout, + context.DeadlineExceeded, + }, + { + "timeout interface", + func() (cleanUp func(), _ error) { + return func() {}, apparentTimeoutError{} + }, + 
errors.Timeout, + nil, // Doesn't match a stdlib error. + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + cleanUp, err := test.makeErr() + defer cleanUp() + for errIdx, err := range []error{ + err, + errors.E(err), + errors.E(err, "wrapped", errors.Fatal), + } { + t.Run(strconv.Itoa(errIdx), func(t *testing.T) { + if got, want := errors.Is(test.kind, err), true; got != want { + t.Errorf("got %v, want %v", got, want) + } + if test.target != nil { + if got, want := goerrors.Is(err, test.target), true; got != want { + t.Errorf("got %v, want %v", got, want) + } + } + // err should not match wrapped target. + if got, want := goerrors.Is(err, fmt.Errorf("%w", test.target)), false; got != want { + t.Errorf("got %v, want %v", got, want) + } + }) + } + }) + } +} + +func TestVerrorInterop(t *testing.T) { + err := errors.E(verror.ErrNoAccess.Errorf(vcontext.Background(), "test error")) + if got, want := errors.Recover(err).Kind, errors.NotAllowed; got != want { + t.Errorf("got %v, want %v", got, want) + } +} + +type apparentTimeoutError struct{} + +func (e apparentTimeoutError) Error() string { return "timeout" } +func (e apparentTimeoutError) Timeout() bool { return true } + +// TestEKindDeterminism ensures that errors.E's Kind detection (based on the +// cause chain of the input error) is deterministic. That is, if the input +// error has multiple causes (according to goerrors.Is), E chooses one +// consistently. User code that handles errors based on Kind will behave +// predictably. +// +// This is a regression test for an issue found while introducing (*Error).Is +// (D65766) which makes it easier for an error chain to match multiple causes. +func TestEKindDeterminism(t *testing.T) { + const N = 100 + numKind := make(map[errors.Kind]int) + for i := 0; i < N; i++ { + // Construct err with a cause chain that matches Canceled due to a + // Kind and NotExist by wrapping the stdlib error. 
+ err := errors.E( + fmt.Errorf("%w", + errors.E("canceled", errors.Canceled, + fmt.Errorf("%w", os.ErrNotExist)))) + // Sanity check: err is detected as both targets. + if got, want := goerrors.Is(err, os.ErrNotExist), true; got != want { + t.Errorf("got %v, want %v", got, want) + } + if got, want := goerrors.Is(err, context.Canceled), true; got != want { + t.Errorf("got %v, want %v", got, want) + } + numKind[err.(*errors.Error).Kind]++ + } + // Now, ensure the assigned Kind is Canceled, the lower number. + if got, want := len(numKind), 1; got != want { + t.Errorf("got %v, want %v", got, want) + } + if got, want := numKind[errors.Canceled], N; got != want { + t.Errorf("got %v, want %v", got, want) + } +} diff --git a/errorreporter/errorreporter.go b/errors/once.go similarity index 66% rename from errorreporter/errorreporter.go rename to errors/once.go index 7b5c5baa..0232a0a0 100644 --- a/errorreporter/errorreporter.go +++ b/errors/once.go @@ -2,9 +2,7 @@ // Use of this source code is governed by the Apache-2.0 // license that can be found in the LICENSE file. -// Package errorreporter is used to accumulate errors from -// multiple threads. -package errorreporter +package errors import ( "sync" @@ -12,12 +10,15 @@ import ( "unsafe" ) -// T accumulates errors across multiple threads. Thread safe. +// Once captures at most one error. Errors are safely set across +// multiple goroutines. +// +// A zero Once is ready to use. // // Example: -// e := errorreporter.T{} -// e.Set(errors.New("test error 0")) -type T struct { +// var e errors.Once +// e.Set(errors.New("test error 0")) +type Once struct { // Ignored is a list of errors that will be dropped in Set(). Ignored // typically includes io.EOF. Ignored []error @@ -25,9 +26,9 @@ type T struct { err unsafe.Pointer // stores *error } -// Err returns the first non-nil error passed to Set. Calling Err is cheap -// (~1ns). -func (e *T) Err() error { +// Err returns the first non-nil error passed to Set. 
Calling Err is +// cheap (~1ns). +func (e *Once) Err() error { p := atomic.LoadPointer(&e.err) // Acquire load if p == nil { return nil @@ -35,9 +36,9 @@ func (e *T) Err() error { return *(*error)(p) } -// Set sets an error. If called multiple times, only the first error is -// remembered. -func (e *T) Set(err error) { +// Set sets this instance's error to err. Only the first error +// is set; subsequent calls are ignored. +func (e *Once) Set(err error) { if err != nil { for _, ignored := range e.Ignored { if err == ignored { diff --git a/errorreporter/errorreporter_test.go b/errors/once_test.go similarity index 81% rename from errorreporter/errorreporter_test.go rename to errors/once_test.go index 6c160cd3..3b0ca45a 100644 --- a/errorreporter/errorreporter_test.go +++ b/errors/once_test.go @@ -2,20 +2,19 @@ // Use of this source code is governed by the Apache-2.0 // license that can be found in the LICENSE file. -package errorreporter_test +package errors_test import ( - "errors" "fmt" "runtime" "testing" - "github.com/grailbio/base/errorreporter" + "github.com/grailbio/base/errors" "github.com/stretchr/testify/require" ) -func TestError(t *testing.T) { - e := errorreporter.T{} +func TestOnce(t *testing.T) { + e := errors.Once{} require.NoError(t, e.Err()) e.Set(errors.New("testerror")) @@ -27,7 +26,7 @@ func TestError(t *testing.T) { } func BenchmarkReadNoError(b *testing.B) { - e := errorreporter.T{} + e := errors.Once{} for i := 0; i < b.N; i++ { if e.Err() != nil { require.Fail(b, "err") @@ -36,7 +35,7 @@ func BenchmarkReadNoError(b *testing.B) { } func BenchmarkReadError(b *testing.B) { - e := errorreporter.T{} + e := errors.Once{} e.Set(errors.New("testerror")) for i := 0; i < b.N; i++ { if e.Err() == nil { @@ -46,15 +45,15 @@ func BenchmarkReadError(b *testing.B) { } func BenchmarkSet(b *testing.B) { - e := errorreporter.T{} + e := errors.Once{} err := errors.New("testerror") for i := 0; i < b.N; i++ { e.Set(err) } } -func ExampleErrorReporter() { - e := 
errorreporter.T{} +func ExampleOnce() { + e := errors.Once{} fmt.Printf("Error: %v\n", e.Err()) e.Set(errors.New("test error 0")) fmt.Printf("Error: %v\n", e.Err()) diff --git a/eventlog/cloudwatch/cloudwatch.go b/eventlog/cloudwatch/cloudwatch.go new file mode 100644 index 00000000..5e21db9c --- /dev/null +++ b/eventlog/cloudwatch/cloudwatch.go @@ -0,0 +1,325 @@ +// Copyright 2020 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. + +package cloudwatch + +import ( + "context" + "fmt" + "os" + "path/filepath" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/awserr" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/cloudwatchlogs" + "github.com/aws/aws-sdk-go/service/cloudwatchlogs/cloudwatchlogsiface" + "github.com/grailbio/base/config" + "github.com/grailbio/base/errors" + "github.com/grailbio/base/eventlog/internal/marshal" + "github.com/grailbio/base/log" + "github.com/grailbio/base/must" +) + +// maxBatchSize is the maximum batch size of the events that we send to CloudWatch Logs, "calculated +// as the sum of all event messages in UTF-8, plus 26 bytes for each log event". See: +// https://docs.aws.amazon.com/sdk-for-go/api/service/cloudwatchlogs/#CloudWatchLogs.PutLogEvents +const maxBatchSize = 1048576 + +// maxSingleMessageSize is the maximum size, in bytes, of a single message +// string of CloudWatch Log event. +const maxSingleMessageSize = maxBatchSize - 26 + +// syncInterval is the maximum duration that we will wait before sending any +// buffered messages to CloudWatch. +const syncInterval = 1 * time.Second + +// eventBufferSize is size of the channel buffer used to process events. When +// this buffer is full, new events are dropped. 
+const eventBufferSize = 32768 + +func init() { + config.Register("eventer/cloudwatch", func(constr *config.Constructor[*Eventer]) { + var sess *session.Session + constr.InstanceVar(&sess, "aws", "aws", "AWS configuration for all CloudWatch calls") + var group string + constr.StringVar(&group, "group", "eventlog", "the CloudWatch log group of the stream to which events will be sent") + var stream string + constr.StringVar(&stream, "stream", "", "the CloudWatch log stream to which events will be sent") + constr.Doc = "eventer/cloudwatch configures an eventer that sends events to a CloudWatch log stream" + constr.New = func() (*Eventer, error) { + cw := cloudwatchlogs.New(sess) + if stream == "" { + // All the information of RFC 3339 but without colons, as stream + // names cannot have colons. + const layout = "20060102T150405.000-0700" + // The default stream name incorporates the current executable + // name and time for some measure of uniqueness and usefulness. + stream = strings.Join([]string{readExec(), time.Now().Format(layout)}, "~") + } + return NewEventer(cw, group, stream), nil + } + }) +} + +type ( + // Eventer logs events to CloudWatch Logs. + Eventer struct { + client cloudwatchlogsiface.CloudWatchLogsAPI + group string + stream string + + cancel func() + + // eventc is used to send events that are batched and sent to CloudWatch + // Logs. + eventc chan event + + // syncc is used to force syncing of events to CloudWatch Logs. + syncc chan struct{} + // syncDonec is used to block for syncing. + syncDonec chan struct{} + + // donec is used to signal that the event processing loop is done. + donec chan struct{} + + initOnce sync.Once + initErr error + + sequenceToken *string + + // loggedFullBuffer prevents consecutive printing of "buffer full" log + // messages inside Event. We get a full buffer when we are overloaded with + // many messages. Logging each dropped message is very noisy, so we suppress + // consecutive logging. 
+ loggedFullBuffer int32 + + now func() time.Time + } + + opts struct { + now func() time.Time + } + EventerOption func(*opts) +) + +type event struct { + timestamp time.Time + typ string + fieldPairs []interface{} +} + +// OptNow configures NewEventer to obtain timestamps from now. now must be non-nil. +// +// Example use case: reducing precision to make it more difficult to correlate which events +// likely came from the same user (who may have done a few things in one minute, etc.). +func OptNow(now func() time.Time) EventerOption { + must.True(now != nil) + return func(o *opts) { o.now = now } +} + +// NewEventer returns a *CloudWatchLogger. It does create the group or +// stream until the first event is logged. +func NewEventer( + client cloudwatchlogsiface.CloudWatchLogsAPI, group, stream string, options ...EventerOption, +) *Eventer { + opts := opts{now: time.Now} + for _, option := range options { + option(&opts) + } + eventer := &Eventer{ + client: client, + group: group, + stream: stream, + eventc: make(chan event, eventBufferSize), + syncc: make(chan struct{}), + syncDonec: make(chan struct{}), + donec: make(chan struct{}), + now: opts.now, + } + var ctx context.Context + ctx, eventer.cancel = context.WithCancel(context.Background()) + go eventer.loop(ctx) + return eventer +} + +func (c *Eventer) String() string { + return fmt.Sprintf("CloudWatch Logs: %s/%s", c.group, c.stream) +} + +// Event implements Eventer. +func (c *Eventer) Event(typ string, fieldPairs ...interface{}) { + select { + case c.eventc <- event{timestamp: c.now(), typ: typ, fieldPairs: fieldPairs}: + atomic.StoreInt32(&c.loggedFullBuffer, 0) + default: + if atomic.LoadInt32(&c.loggedFullBuffer) == 0 { + log.Error.Printf("Eventer: dropping log events: buffer full") + atomic.StoreInt32(&c.loggedFullBuffer, 1) + } + } +} + +// Init initializes the group and stream used by c. It will only attempt +// initialization once, subsequently returning the result of that attempt. 
+func (c *Eventer) Init(ctx context.Context) error { + // TODO: Initialize with loadingcache.Value so concurrent Init()s each respect their own + // context's cancellation. + c.initOnce.Do(func() { + defer func() { + if c.initErr != nil { + log.Error.Printf("Eventer: failed to initialize event log: %v", c.initErr) + } + }() + var err error + _, err = c.client.CreateLogGroupWithContext(ctx, &cloudwatchlogs.CreateLogGroupInput{ + LogGroupName: aws.String(c.group), + }) + if err != nil { + aerr, ok := err.(awserr.Error) + if !ok || aerr.Code() != cloudwatchlogs.ErrCodeResourceAlreadyExistsException { + c.initErr = errors.E(fmt.Sprintf("could not create CloudWatch log group %s", c.group), err) + return + } + } + _, err = c.client.CreateLogStreamWithContext(ctx, &cloudwatchlogs.CreateLogStreamInput{ + LogGroupName: aws.String(c.group), + LogStreamName: aws.String(c.stream), + }) + if err != nil { + aerr, ok := err.(awserr.Error) + if ok && aerr.Code() != cloudwatchlogs.ErrCodeResourceAlreadyExistsException { + c.initErr = errors.E(fmt.Sprintf("could not create CloudWatch log stream %s", c.stream), err) + return + } + } + }) + return c.initErr +} + +// sync syncs all buffered events to CloudWatch. This is mostly useful for +// testing. 
+func (c *Eventer) sync() { + c.syncc <- struct{}{} + <-c.syncDonec +} + +func (c *Eventer) Close() error { + c.cancel() + <-c.donec + return nil +} + +func (c *Eventer) loop(ctx context.Context) { + var ( + syncTimer = time.NewTimer(syncInterval) + inputLogEvents []*cloudwatchlogs.InputLogEvent + batchSize int + ) + sync := func(drainTimer bool) { + defer func() { + inputLogEvents = nil + batchSize = 0 + if !syncTimer.Stop() && drainTimer { + <-syncTimer.C + } + syncTimer.Reset(syncInterval) + }() + if len(inputLogEvents) == 0 { + return + } + if err := c.Init(ctx); err != nil { + return + } + response, err := c.client.PutLogEventsWithContext(ctx, &cloudwatchlogs.PutLogEventsInput{ + LogEvents: inputLogEvents, + LogGroupName: aws.String(c.group), + LogStreamName: aws.String(c.stream), + SequenceToken: c.sequenceToken, + }) + if err != nil { + log.Error.Printf("Eventer: PutLogEvents error: %v", err) + if aerr, ok := err.(*cloudwatchlogs.InvalidSequenceTokenException); ok { + c.sequenceToken = aerr.ExpectedSequenceToken + } + return + } + c.sequenceToken = response.NextSequenceToken + } + process := func(e event) { + s, err := marshal.Marshal(e.typ, e.fieldPairs) + if err != nil { + log.Error.Printf("Eventer: dropping log event: %v", err) + return + } + if len(s) > maxSingleMessageSize { + log.Error.Printf("Eventer: dropping log event: message too large") + return + } + newBatchSize := batchSize + len(s) + 26 + if newBatchSize > maxBatchSize { + sync(true) + } + inputLogEvents = append(inputLogEvents, &cloudwatchlogs.InputLogEvent{ + Message: aws.String(s), + Timestamp: aws.Int64(e.timestamp.UnixNano() / 1000000), + }) + } + drainEvents := func() { + drainLoop: + for { + select { + case e := <-c.eventc: + process(e) + default: + break drainLoop + } + } + } + for { + select { + case <-c.syncc: + drainEvents() + sync(false) + c.syncDonec <- struct{}{} + case <-syncTimer.C: + sync(false) + case e := <-c.eventc: + process(e) + case <-ctx.Done(): + close(c.eventc) + for 
e := range c.eventc { + process(e) + } + sync(true) + close(c.donec) + return + } + } +} + +// readExec returns a sanitized version of the executable name, if it can be +// determined. If not, returns "unknown". +func readExec() string { + const unknown = "unknown" + execPath, err := os.Executable() + if err != nil { + return unknown + } + rawExec := filepath.Base(execPath) + var sanitized strings.Builder + for _, r := range rawExec { + if (r == '-' || 'a' <= r && r <= 'z') || ('0' <= r && r <= '9') { + sanitized.WriteRune(r) + } + } + if sanitized.Len() == 0 { + return unknown + } + return sanitized.String() +} diff --git a/eventlog/cloudwatch/cloudwatch_test.go b/eventlog/cloudwatch/cloudwatch_test.go new file mode 100644 index 00000000..213d0b0c --- /dev/null +++ b/eventlog/cloudwatch/cloudwatch_test.go @@ -0,0 +1,223 @@ +// Copyright 2020 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. + +package cloudwatch + +import ( + "context" + "encoding/json" + "fmt" + "sync" + "testing" + "time" + + "github.com/aws/aws-sdk-go/aws/request" + "github.com/aws/aws-sdk-go/service/cloudwatchlogs" + "github.com/aws/aws-sdk-go/service/cloudwatchlogs/cloudwatchlogsiface" + "github.com/grailbio/base/eventlog/internal/marshal" +) + +const testGroup = "testGroup" +const testStream = "testStream" +const typ = "testEventType" +const k = "testFieldKey" + +type logsAPIFake struct { + cloudwatchlogsiface.CloudWatchLogsAPI + + groupInput *cloudwatchlogs.CreateLogGroupInput + streamInput *cloudwatchlogs.CreateLogStreamInput + eventsInputs []*cloudwatchlogs.PutLogEventsInput + + sequenceMu sync.Mutex + sequence int +} + +func (f *logsAPIFake) CreateLogGroupWithContext(ctx context.Context, + input *cloudwatchlogs.CreateLogGroupInput, + opts ...request.Option) (*cloudwatchlogs.CreateLogGroupOutput, error) { + + f.groupInput = input + return nil, nil +} + +func (f *logsAPIFake) 
CreateLogStreamWithContext(ctx context.Context, + input *cloudwatchlogs.CreateLogStreamInput, + opts ...request.Option) (*cloudwatchlogs.CreateLogStreamOutput, error) { + + f.streamInput = input + return nil, nil +} + +func (f *logsAPIFake) PutLogEventsWithContext(ctx context.Context, + input *cloudwatchlogs.PutLogEventsInput, + opts ...request.Option) (*cloudwatchlogs.PutLogEventsOutput, error) { + + var ts *int64 + for _, event := range input.LogEvents { + if ts != nil && *event.Timestamp < *ts { + return nil, &cloudwatchlogs.InvalidParameterException{} + } + ts = event.Timestamp + } + + nextSequenceToken, err := func() (*string, error) { + f.sequenceMu.Lock() + defer f.sequenceMu.Unlock() + if f.sequence != 0 { + sequenceToken := fmt.Sprintf("%d", f.sequence) + if input.SequenceToken == nil || sequenceToken != *input.SequenceToken { + return nil, &cloudwatchlogs.InvalidSequenceTokenException{ + ExpectedSequenceToken: &sequenceToken, + } + } + } + f.sequence++ + nextSequenceToken := fmt.Sprintf("%d", f.sequence) + return &nextSequenceToken, nil + }() + if err != nil { + return nil, err + } + + f.eventsInputs = append(f.eventsInputs, input) + return &cloudwatchlogs.PutLogEventsOutput{ + NextSequenceToken: nextSequenceToken, + }, nil +} + +func (f *logsAPIFake) logEvents() []*cloudwatchlogs.InputLogEvent { + var events []*cloudwatchlogs.InputLogEvent + for _, input := range f.eventsInputs { + events = append(events, input.LogEvents...) + } + return events +} + +func (f *logsAPIFake) incrNextSequence() { + f.sequenceMu.Lock() + defer f.sequenceMu.Unlock() + f.sequence++ +} + +// TestEvent verifies that logged events are sent to CloudWatch correctly. +func TestEvent(t *testing.T) { + const N = 1000 + + if eventBufferSize < N { + panic("keep N <= eventBufferSize to make sure no events are dropped") + } + + // Note: Access to nowUnixMillis is unsynchronized because now() is only called in Event(), + // not in any background or asynchronous goroutine. 
+ var nowUnixMillis int64 = 1600000000000 // Arbitrary time in 2020. + now := func() time.Time { + return time.UnixMilli(nowUnixMillis) + } + + // Log events. + cw := &logsAPIFake{} + e := NewEventer(cw, testGroup, testStream, OptNow(now)) + wantTimestamps := make([]time.Time, N) + for i := 0; i < N; i++ { + k := fmt.Sprintf("k%d", i) + e.Event(typ, k, i) + wantTimestamps[i] = now() + nowUnixMillis += time.Hour.Milliseconds() + } + e.Close() + + // Make sure events get to CloudWatch with the right contents and in order. + events := cw.logEvents() + if got, want := len(events), N; got != want { + t.Errorf("got %v, want %v", got, want) + } + for i, event := range events { + k := fmt.Sprintf("k%d", i) + m, err := marshal.Marshal(typ, []interface{}{k, i}) + if err != nil { + t.Fatalf("error marshaling event: %v", err) + } + if got, want := *event.Message, m; got != want { + t.Errorf("got %v, want %v", got, want) + continue + } + if got, want := time.UnixMilli(*event.Timestamp), wantTimestamps[i]; !want.Equal(got) { + t.Errorf("got %v, want %v", got, want) + continue + } + } +} + +// TestBufferFull verifies that exceeding the event buffer leads to, at worst, +// dropped events. Events that are not dropped should still be logged in order. +func TestBufferFull(t *testing.T) { + const N = 100 * 1000 + + // Log many events, overwhelming buffer. + cw := &logsAPIFake{} + e := NewEventer(cw, testGroup, testStream) + for i := 0; i < N; i++ { + e.Event(typ, k, i) + } + e.Close() + + events := cw.logEvents() + if N < len(events) { + t.Fatalf("more events sent to CloudWatch than were logged: %d < %d", N, len(events)) + } + assertOrdered(t, events) +} + +// TestInvalidSequenceToken verifies that we recover if our sequence token gets +// out of sync. This should not happen, as we should be the only thing writing +// to a given log stream, but we try to recover anyway. 
+func TestInvalidSequenceToken(t *testing.T) { + cw := &logsAPIFake{} + e := NewEventer(cw, testGroup, testStream) + + e.Event(typ, k, 0) + e.sync() + cw.incrNextSequence() + e.Event(typ, k, 1) + e.sync() + e.Event(typ, k, 2) + e.sync() + e.Close() + + events := cw.logEvents() + if 3 < len(events) { + t.Fatalf("more events sent to CloudWatch than were logged: 3 < %d", len(events)) + } + if len(events) < 2 { + t.Errorf("did not successfully re-sync sequence token") + } + assertOrdered(t, events) +} + +// assertOrdered asserts that the values of field k are increasing for events. +// This is how we construct events sent to the Eventer, so we use this +// verify that the events sent to the CloudWatch Logs API are ordered correctly. +func assertOrdered(t *testing.T, events []*cloudwatchlogs.InputLogEvent) { + t.Helper() + last := -1 + for _, event := range events { + var m map[string]interface{} + if err := json.Unmarshal([]byte(*event.Message), &m); err != nil { + t.Fatalf("could not unmarshal event message: %v", err) + } + v, ok := m[k] + if !ok { + t.Errorf("event message does not contain test key %q: %s", k, *event.Message) + continue + } + // All numeric values are unmarshaled as float64, so we need to convert + // back to int. + vi := int(v.(float64)) + if vi <= last { + t.Errorf("event out of order; expected %d < %d", last, vi) + continue + } + } +} diff --git a/eventlog/eventlog.go b/eventlog/eventlog.go new file mode 100644 index 00000000..30425165 --- /dev/null +++ b/eventlog/eventlog.go @@ -0,0 +1,91 @@ +// Copyright 2020 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. + +// Package eventlog provides logging of semi-structured events, particularly in +// service of downstream analysis, e.g. when machines are started, when a user +// issues a command, when failures happen. 
+// +// For example, you can log events to CloudWatch Logs: +// +// sess := session.NewSession() +// cw := cloudwatchlogs.New(sess) +// e := cloudwatch.NewEventer(cw, "myLogGroup", "myLogStream") +// e.Event("rpcRetry", "org", "GRAIL", "retry", 0, "maxRetry", 10) +// e.Event("machineStopped", "addr", "192.168.1.1", "duration", 3600.0, "startTime": 1584140534) +// +// These events can now be analyzed and monitored using CloudWatch Logs tooling. +package eventlog + +import ( + "bytes" + + "github.com/grailbio/base/config" + "github.com/grailbio/base/log" +) + +func init() { + config.Register("eventer/nop", func(constr *config.Constructor[Nop]) { + constr.Doc = "eventer/nop configures a no-op event logger" + constr.New = func() (Nop, error) { + return Nop{}, nil + } + }) + config.Register("eventer/log-info", func(constr *config.Constructor[Log]) { + constr.Doc = "eventer/log-info configures an eventer that writes events at log level info" + constr.New = func() (Log, error) { + return Log(log.Info), nil + } + }) + // We do the most conservative thing by default, making event logging a + // no-op. + config.Default("eventer", "eventer/nop") +} + +// Eventer is called to log events. +type Eventer interface { + // Event logs an event of typ with (key string, value interface{}) fields given in fieldPairs + // as k0, v0, k1, v1, ...kn, vn. For example: + // + // s.Event("machineStart", "addr", "192.168.1.2", "time", time.Now().Unix()) + // + // The value will be serialized as JSON. + // + // The key "eventType" is reserved. Field keys must be unique. Any violation will result + // in the event being dropped and logged. + // + // Implementations must be safe for concurrent use. + Event(typ string, fieldPairs ...interface{}) +} + +// Nop is a no-op Eventer. +type Nop struct{} + +var _ Eventer = Nop{} + +func (Nop) String() string { + return "disabled" +} + +// Event implements Eventer. 
+func (Nop) Event(_ string, _ ...interface{}) {} + +// Log is an Eventer that writes events to the logger. It's intended for debugging/development. +type Log log.Level + +var _ Eventer = Log(log.Debug) + +// Event implements Eventer. +func (l Log) Event(typ string, fieldPairs ...interface{}) { + f := bytes.NewBufferString("eventlog: %s {") + for i := range fieldPairs { + f.WriteString("%v") + if i%2 == 0 { + f.WriteString(": ") + } else if i < len(fieldPairs)-1 { + f.WriteString(", ") + } + } + f.WriteByte('}') + log.Level(l).Printf(f.String(), append([]interface{}{typ}, fieldPairs...)...) +} diff --git a/eventlog/internal/marshal/marshal.go b/eventlog/internal/marshal/marshal.go new file mode 100644 index 00000000..55258785 --- /dev/null +++ b/eventlog/internal/marshal/marshal.go @@ -0,0 +1,44 @@ +// Copyright 2020 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. + +package marshal + +import ( + "encoding/json" + "fmt" +) + +// eventTypeFieldKey is the name of the JSON member holding the event type, +// passed as typ, in the JSON string returned by marshal. It is a reserved field +// key name. +const eventTypeFieldKey = "eventType" + +// marshal marshal event information into a JSON string. Field keys must be +// unique, otherwise marshal returns an error. 
+func Marshal(typ string, fieldPairs []interface{}) (string, error) { + if len(fieldPairs)%2 != 0 { + return "", fmt.Errorf("len(fieldPairs) must be even; %d is not even", len(fieldPairs)) + } + fields := make(map[string]interface{}) + for i := 0; i < len(fieldPairs); i++ { + key, isString := fieldPairs[i].(string) + if !isString { + return "", fmt.Errorf("field key at fieldPairs[%d] must be a string: %v", i, fieldPairs[i]) + } + if key == eventTypeFieldKey { + return "", fmt.Errorf("field key at fieldPairs[%d] is '%s'; '%s' is reserved", i, eventTypeFieldKey, eventTypeFieldKey) + } + if _, dupKey := fields[key]; dupKey { + return "", fmt.Errorf("key %q at fieldPairs[%d] already used; duplicate keys not allowed", key, i) + } + i++ + fields[key] = fieldPairs[i] + } + fields[eventTypeFieldKey] = typ + bs, err := json.Marshal(fields) + if err != nil { + return "", fmt.Errorf("error marshaling fields to JSON: %v", err) + } + return string(bs), nil +} diff --git a/eventlog/internal/marshal/marshal_test.go b/eventlog/internal/marshal/marshal_test.go new file mode 100644 index 00000000..3c0f138d --- /dev/null +++ b/eventlog/internal/marshal/marshal_test.go @@ -0,0 +1,101 @@ +// Copyright 2020 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. + +package marshal + +import ( + "encoding/json" + "strings" + "testing" +) + +// TestMarshal verifies that Marshal behaves properly in both success and +// failure cases. For success cases, it roundtrips the marshaled string and +// verifies the result. For failure cases, it checks for expected error +// messages. +func TestMarshal(t *testing.T) { + for _, c := range []struct { + name string + fieldPairs []interface{} + // errNeedle is "" if we expect no error. Otherwise, it is a string that + // we expect to see in the resulting err.Error(). 
+ errNeedle string + }{ + { + "no fields", + []interface{}{}, + "", + }, + { + "simple", + []interface{}{"k0", "v0"}, + "", + }, + { + "mixed value types", + // Numeric types turn into float64s in JSON. + []interface{}{"k0", "v0", "k1", float64(1), "k2", true}, + "", + }, + { + "odd field pairs", + []interface{}{"k0", "v0", "k1"}, + "even", + }, + { + "non-string key", + []interface{}{0, "v0"}, + "string", + }, + { + "duplicate keys", + []interface{}{"k0", "v0", "k0", "v1"}, + "duplicate", + }, + } { + t.Run(c.name, func(t *testing.T) { + marshalOK := c.errNeedle == "" + s, err := Marshal(c.name, c.fieldPairs) + if got, want := err == nil, marshalOK; got != want { + t.Fatalf("got %v, want %v", got, want) + } + if !marshalOK { + if !strings.Contains(err.Error(), c.errNeedle) { + t.Errorf("error %q does not contain expected substring %q", err.Error(), c.errNeedle) + } + return + } + var m map[string]interface{} + err = json.Unmarshal([]byte(s), &m) + if err != nil { + t.Fatalf("unmarshaling failed: %v", err) + } + // The +1 is for the eventType. + if got, want := len(m), (len(c.fieldPairs)/2)+1; got != want { + t.Errorf("got %v, want %v", got, want) + } + typ, ok := m[eventTypeFieldKey] + if ok { + if got, want := typ, c.name; got != want { + t.Errorf("got %v, want %v", got, want) + } + } else { + t.Errorf("eventType field not marshaled") + } + for i := 0; i < len(c.fieldPairs); i++ { + key := c.fieldPairs[i].(string) + i++ + value := c.fieldPairs[i] + mvalue, ok := m[key] + if !ok { + t.Errorf("field with key %q not marshaled", key) + continue + } + if got, want := mvalue, value; got != want { + t.Errorf("got %v(%T), want %v(%T)", got, got, want, want) + } + } + }) + } +} diff --git a/fatbin/binaries_test.go b/fatbin/binaries_test.go new file mode 100644 index 00000000..a0c47a1d --- /dev/null +++ b/fatbin/binaries_test.go @@ -0,0 +1,259 @@ +// Copyright 2019 GRAIL, Inc. All rights reserved. 
+// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. + +package fatbin + +// Stripped version of C program: +// void main(){printf("hello world");} +var svelteLinuxElfBinary = []byte{ + 0x7f, 0x45, 0x4c, 0x46, 0x01, 0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x03, 0x00, 0x01, 0x00, 0x00, 0x00, + 0xf0, 0x82, 0x04, 0x08, 0x34, 0x00, 0x00, 0x00, 0x70, 0x07, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x34, 0x00, 0x20, 0x00, 0x07, 0x00, 0x28, 0x00, + 0x1b, 0x00, 0x1a, 0x00, 0x06, 0x00, 0x00, 0x00, 0x34, 0x00, 0x00, 0x00, + 0x34, 0x80, 0x04, 0x08, 0x34, 0x80, 0x04, 0x08, 0xe0, 0x00, 0x00, 0x00, + 0xe0, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, 0x14, 0x01, 0x00, 0x00, 0x14, 0x81, 0x04, 0x08, + 0x14, 0x81, 0x04, 0x08, 0x13, 0x00, 0x00, 0x00, 0x13, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x04, 0x08, 0x00, 0x80, 0x04, 0x08, + 0x70, 0x04, 0x00, 0x00, 0x70, 0x04, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00, + 0x00, 0x10, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x70, 0x04, 0x00, 0x00, + 0x70, 0x94, 0x04, 0x08, 0x70, 0x94, 0x04, 0x08, 0x0c, 0x01, 0x00, 0x00, + 0x10, 0x01, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x84, 0x04, 0x00, 0x00, 0x84, 0x94, 0x04, 0x08, + 0x84, 0x94, 0x04, 0x08, 0xd0, 0x00, 0x00, 0x00, 0xd0, 0x00, 0x00, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x28, 0x01, 0x00, 0x00, 0x28, 0x81, 0x04, 0x08, 0x28, 0x81, 0x04, 0x08, + 0x20, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x51, 0xe5, 0x74, 0x64, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x2f, 0x6c, 0x69, 0x62, 0x2f, 0x6c, 0x64, 0x2d, 0x6c, 
0x69, 0x6e, 0x75, + 0x78, 0x2e, 0x73, 0x6f, 0x2e, 0x32, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x47, 0x4e, 0x55, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00, + 0x00, 0x20, 0x00, 0x20, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0xad, 0x4b, 0xe3, 0xc0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, + 0x2e, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xb2, 0x01, 0x00, 0x00, + 0x12, 0x00, 0x00, 0x00, 0x29, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x8f, 0x01, 0x00, 0x00, 0x12, 0x00, 0x00, 0x00, 0x1a, 0x00, 0x00, 0x00, + 0x5c, 0x84, 0x04, 0x08, 0x04, 0x00, 0x00, 0x00, 0x11, 0x00, 0x0f, 0x00, + 0x00, 0x5f, 0x5f, 0x67, 0x6d, 0x6f, 0x6e, 0x5f, 0x73, 0x74, 0x61, 0x72, + 0x74, 0x5f, 0x5f, 0x00, 0x6c, 0x69, 0x62, 0x63, 0x2e, 0x73, 0x6f, 0x2e, + 0x36, 0x00, 0x5f, 0x49, 0x4f, 0x5f, 0x73, 0x74, 0x64, 0x69, 0x6e, 0x5f, + 0x75, 0x73, 0x65, 0x64, 0x00, 0x70, 0x75, 0x74, 0x73, 0x00, 0x5f, 0x5f, + 0x6c, 0x69, 0x62, 0x63, 0x5f, 0x73, 0x74, 0x61, 0x72, 0x74, 0x5f, 0x6d, + 0x61, 0x69, 0x6e, 0x00, 0x47, 0x4c, 0x49, 0x42, 0x43, 0x5f, 0x32, 0x2e, + 0x30, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x02, 0x00, 0x01, 0x00, + 0x01, 0x00, 0x01, 0x00, 0x10, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x10, 0x69, 0x69, 0x0d, 0x00, 0x00, 0x02, 0x00, + 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x54, 0x95, 0x04, 0x08, + 0x06, 0x01, 0x00, 0x00, 0x64, 0x95, 0x04, 0x08, 0x07, 
0x01, 0x00, 0x00, + 0x68, 0x95, 0x04, 0x08, 0x07, 0x02, 0x00, 0x00, 0x6c, 0x95, 0x04, 0x08, + 0x07, 0x03, 0x00, 0x00, 0x55, 0x89, 0xe5, 0x53, 0x83, 0xec, 0x04, 0xe8, + 0x00, 0x00, 0x00, 0x00, 0x5b, 0x81, 0xc3, 0xd8, 0x12, 0x00, 0x00, 0x8b, + 0x93, 0xfc, 0xff, 0xff, 0xff, 0x85, 0xd2, 0x74, 0x05, 0xe8, 0x1e, 0x00, + 0x00, 0x00, 0xe8, 0xb5, 0x00, 0x00, 0x00, 0xe8, 0x70, 0x01, 0x00, 0x00, + 0x58, 0x5b, 0xc9, 0xc3, 0xff, 0x35, 0x5c, 0x95, 0x04, 0x08, 0xff, 0x25, + 0x60, 0x95, 0x04, 0x08, 0x00, 0x00, 0x00, 0x00, 0xff, 0x25, 0x64, 0x95, + 0x04, 0x08, 0x68, 0x00, 0x00, 0x00, 0x00, 0xe9, 0xe0, 0xff, 0xff, 0xff, + 0xff, 0x25, 0x68, 0x95, 0x04, 0x08, 0x68, 0x08, 0x00, 0x00, 0x00, 0xe9, + 0xd0, 0xff, 0xff, 0xff, 0xff, 0x25, 0x6c, 0x95, 0x04, 0x08, 0x68, 0x10, + 0x00, 0x00, 0x00, 0xe9, 0xc0, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x31, 0xed, 0x5e, 0x89, + 0xe1, 0x83, 0xe4, 0xf0, 0x50, 0x54, 0x52, 0x68, 0xa0, 0x83, 0x04, 0x08, + 0x68, 0xb0, 0x83, 0x04, 0x08, 0x51, 0x56, 0x68, 0x74, 0x83, 0x04, 0x08, + 0xe8, 0xb3, 0xff, 0xff, 0xff, 0xf4, 0x90, 0x90, 0x90, 0x90, 0x90, 0x90, + 0x90, 0x90, 0x90, 0x90, 0x90, 0x90, 0x90, 0x90, 0x55, 0x89, 0xe5, 0x83, + 0xec, 0x08, 0x80, 0x3d, 0x7c, 0x95, 0x04, 0x08, 0x00, 0x74, 0x0c, 0xeb, + 0x1c, 0x83, 0xc0, 0x04, 0xa3, 0x78, 0x95, 0x04, 0x08, 0xff, 0xd2, 0xa1, + 0x78, 0x95, 0x04, 0x08, 0x8b, 0x10, 0x85, 0xd2, 0x75, 0xeb, 0xc6, 0x05, + 0x7c, 0x95, 0x04, 0x08, 0x01, 0xc9, 0xc3, 0x90, 0x55, 0x89, 0xe5, 0x83, + 0xec, 0x08, 0xa1, 0x80, 0x94, 0x04, 0x08, 0x85, 0xc0, 0x74, 0x12, 0xb8, + 0x00, 0x00, 0x00, 0x00, 0x85, 0xc0, 0x74, 0x09, 0xc7, 0x04, 0x24, 0x80, + 0x94, 0x04, 0x08, 0xff, 0xd0, 0xc9, 0xc3, 0x90, 0x8d, 0x4c, 0x24, 0x04, + 0x83, 0xe4, 0xf0, 0xff, 0x71, 0xfc, 0x55, 0x89, 0xe5, 0x51, 0x83, 0xec, + 0x04, 0xc7, 0x04, 0x24, 0x60, 0x84, 0x04, 0x08, 0xe8, 0x43, 0xff, 0xff, + 0xff, 0xb8, 0x00, 0x00, 0x00, 0x00, 0x83, 0xc4, 0x04, 0x59, 0x5d, 0x8d, + 0x61, 0xfc, 0xc3, 0x90, 0x55, 0x89, 0xe5, 0x5d, 0xc3, 
0x8d, 0x74, 0x26, + 0x00, 0x8d, 0xbc, 0x27, 0x00, 0x00, 0x00, 0x00, 0x55, 0x89, 0xe5, 0x57, + 0x56, 0x53, 0xe8, 0x4f, 0x00, 0x00, 0x00, 0x81, 0xc3, 0x9d, 0x11, 0x00, + 0x00, 0x83, 0xec, 0x0c, 0xe8, 0xab, 0xfe, 0xff, 0xff, 0x8d, 0xbb, 0x18, + 0xff, 0xff, 0xff, 0x8d, 0x83, 0x18, 0xff, 0xff, 0xff, 0x29, 0xc7, 0xc1, + 0xff, 0x02, 0x85, 0xff, 0x74, 0x24, 0x31, 0xf6, 0x8b, 0x45, 0x10, 0x89, + 0x44, 0x24, 0x08, 0x8b, 0x45, 0x0c, 0x89, 0x44, 0x24, 0x04, 0x8b, 0x45, + 0x08, 0x89, 0x04, 0x24, 0xff, 0x94, 0xb3, 0x18, 0xff, 0xff, 0xff, 0x83, + 0xc6, 0x01, 0x39, 0xf7, 0x75, 0xde, 0x83, 0xc4, 0x0c, 0x5b, 0x5e, 0x5f, + 0x5d, 0xc3, 0x8b, 0x1c, 0x24, 0xc3, 0x90, 0x90, 0x55, 0x89, 0xe5, 0x53, + 0x83, 0xec, 0x04, 0xa1, 0x70, 0x94, 0x04, 0x08, 0x83, 0xf8, 0xff, 0x74, + 0x12, 0x31, 0xdb, 0xff, 0xd0, 0x8b, 0x83, 0x6c, 0x94, 0x04, 0x08, 0x83, + 0xeb, 0x04, 0x83, 0xf8, 0xff, 0x75, 0xf0, 0x83, 0xc4, 0x04, 0x5b, 0x5d, + 0xc3, 0x90, 0x90, 0x90, 0x55, 0x89, 0xe5, 0x53, 0x83, 0xec, 0x04, 0xe8, + 0x00, 0x00, 0x00, 0x00, 0x5b, 0x81, 0xc3, 0x10, 0x11, 0x00, 0x00, 0xe8, + 0xcc, 0xfe, 0xff, 0xff, 0x59, 0x5b, 0xc9, 0xc3, 0x03, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x02, 0x00, 0x48, 0x69, 0x20, 0x57, 0x6f, 0x72, 0x6c, 0x64, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff, + 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x0c, 0x00, 0x00, 0x00, 0x74, 0x82, 0x04, 0x08, 0x0d, 0x00, 0x00, 0x00, + 0x3c, 0x84, 0x04, 0x08, 0x04, 0x00, 0x00, 0x00, 0x48, 0x81, 0x04, 0x08, + 0xf5, 0xfe, 0xff, 0x6f, 0x70, 0x81, 0x04, 0x08, 0x05, 0x00, 0x00, 0x00, + 0xe0, 0x81, 0x04, 0x08, 0x06, 0x00, 0x00, 0x00, 0x90, 0x81, 0x04, 0x08, + 0x0a, 0x00, 0x00, 0x00, 0x4a, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x00, 0x00, 0x15, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, 0x58, 0x95, 0x04, 0x08, 0x02, 0x00, 0x00, 0x00, + 0x18, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x11, 
0x00, 0x00, 0x00, + 0x17, 0x00, 0x00, 0x00, 0x5c, 0x82, 0x04, 0x08, 0x11, 0x00, 0x00, 0x00, + 0x54, 0x82, 0x04, 0x08, 0x12, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, + 0x13, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0xfe, 0xff, 0xff, 0x6f, + 0x34, 0x82, 0x04, 0x08, 0xff, 0xff, 0xff, 0x6f, 0x01, 0x00, 0x00, 0x00, + 0xf0, 0xff, 0xff, 0x6f, 0x2a, 0x82, 0x04, 0x08, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x84, 0x94, 0x04, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0xba, 0x82, 0x04, 0x08, 0xca, 0x82, 0x04, 0x08, 0xda, 0x82, 0x04, 0x08, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x7c, 0x94, 0x04, 0x08, + 0x00, 0x47, 0x43, 0x43, 0x3a, 0x20, 0x28, 0x47, 0x4e, 0x55, 0x29, 0x20, + 0x34, 0x2e, 0x32, 0x2e, 0x33, 0x20, 0x28, 0x55, 0x62, 0x75, 0x6e, 0x74, + 0x75, 0x20, 0x34, 0x2e, 0x32, 0x2e, 0x33, 0x2d, 0x32, 0x75, 0x62, 0x75, + 0x6e, 0x74, 0x75, 0x37, 0x29, 0x00, 0x00, 0x47, 0x43, 0x43, 0x3a, 0x20, + 0x28, 0x47, 0x4e, 0x55, 0x29, 0x20, 0x34, 0x2e, 0x32, 0x2e, 0x33, 0x20, + 0x28, 0x55, 0x62, 0x75, 0x6e, 0x74, 0x75, 0x20, 0x34, 0x2e, 0x32, 0x2e, + 0x33, 0x2d, 0x32, 0x75, 0x62, 0x75, 0x6e, 0x74, 0x75, 0x37, 0x29, 0x00, + 0x00, 0x47, 0x43, 0x43, 0x3a, 0x20, 0x28, 0x47, 0x4e, 0x55, 0x29, 0x20, + 0x34, 0x2e, 0x32, 0x2e, 0x33, 0x20, 0x28, 0x55, 0x62, 0x75, 0x6e, 0x74, + 0x75, 0x20, 0x34, 0x2e, 0x32, 0x2e, 0x33, 0x2d, 0x32, 0x75, 0x62, 0x75, + 0x6e, 0x74, 0x75, 0x37, 0x29, 0x00, 0x00, 0x47, 0x43, 0x43, 0x3a, 0x20, + 0x28, 0x47, 0x4e, 0x55, 0x29, 0x20, 0x34, 0x2e, 0x32, 0x2e, 0x33, 0x20, + 0x28, 0x55, 0x62, 0x75, 0x6e, 0x74, 0x75, 0x20, 0x34, 0x2e, 0x32, 0x2e, + 0x33, 0x2d, 0x32, 0x75, 0x62, 0x75, 0x6e, 0x74, 0x75, 0x37, 0x29, 0x00, + 0x00, 0x47, 0x43, 0x43, 0x3a, 0x20, 0x28, 0x47, 0x4e, 
0x55, 0x29, 0x20, + 0x34, 0x2e, 0x32, 0x2e, 0x33, 0x20, 0x28, 0x55, 0x62, 0x75, 0x6e, 0x74, + 0x75, 0x20, 0x34, 0x2e, 0x32, 0x2e, 0x33, 0x2d, 0x32, 0x75, 0x62, 0x75, + 0x6e, 0x74, 0x75, 0x37, 0x29, 0x00, 0x00, 0x47, 0x43, 0x43, 0x3a, 0x20, + 0x28, 0x47, 0x4e, 0x55, 0x29, 0x20, 0x34, 0x2e, 0x32, 0x2e, 0x33, 0x20, + 0x28, 0x55, 0x62, 0x75, 0x6e, 0x74, 0x75, 0x20, 0x34, 0x2e, 0x32, 0x2e, + 0x33, 0x2d, 0x32, 0x75, 0x62, 0x75, 0x6e, 0x74, 0x75, 0x37, 0x29, 0x00, + 0x00, 0x47, 0x43, 0x43, 0x3a, 0x20, 0x28, 0x47, 0x4e, 0x55, 0x29, 0x20, + 0x34, 0x2e, 0x32, 0x2e, 0x33, 0x20, 0x28, 0x55, 0x62, 0x75, 0x6e, 0x74, + 0x75, 0x20, 0x34, 0x2e, 0x32, 0x2e, 0x33, 0x2d, 0x32, 0x75, 0x62, 0x75, + 0x6e, 0x74, 0x75, 0x37, 0x29, 0x00, 0x00, 0x2e, 0x73, 0x68, 0x73, 0x74, + 0x72, 0x74, 0x61, 0x62, 0x00, 0x2e, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x70, + 0x00, 0x2e, 0x6e, 0x6f, 0x74, 0x65, 0x2e, 0x41, 0x42, 0x49, 0x2d, 0x74, + 0x61, 0x67, 0x00, 0x2e, 0x67, 0x6e, 0x75, 0x2e, 0x68, 0x61, 0x73, 0x68, + 0x00, 0x2e, 0x64, 0x79, 0x6e, 0x73, 0x79, 0x6d, 0x00, 0x2e, 0x64, 0x79, + 0x6e, 0x73, 0x74, 0x72, 0x00, 0x2e, 0x67, 0x6e, 0x75, 0x2e, 0x76, 0x65, + 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x00, 0x2e, 0x67, 0x6e, 0x75, 0x2e, 0x76, + 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x5f, 0x72, 0x00, 0x2e, 0x72, 0x65, + 0x6c, 0x2e, 0x64, 0x79, 0x6e, 0x00, 0x2e, 0x72, 0x65, 0x6c, 0x2e, 0x70, + 0x6c, 0x74, 0x00, 0x2e, 0x69, 0x6e, 0x69, 0x74, 0x00, 0x2e, 0x74, 0x65, + 0x78, 0x74, 0x00, 0x2e, 0x66, 0x69, 0x6e, 0x69, 0x00, 0x2e, 0x72, 0x6f, + 0x64, 0x61, 0x74, 0x61, 0x00, 0x2e, 0x65, 0x68, 0x5f, 0x66, 0x72, 0x61, + 0x6d, 0x65, 0x00, 0x2e, 0x63, 0x74, 0x6f, 0x72, 0x73, 0x00, 0x2e, 0x64, + 0x74, 0x6f, 0x72, 0x73, 0x00, 0x2e, 0x6a, 0x63, 0x72, 0x00, 0x2e, 0x64, + 0x79, 0x6e, 0x61, 0x6d, 0x69, 0x63, 0x00, 0x2e, 0x67, 0x6f, 0x74, 0x00, + 0x2e, 0x67, 0x6f, 0x74, 0x2e, 0x70, 0x6c, 0x74, 0x00, 0x2e, 0x64, 0x61, + 0x74, 0x61, 0x00, 0x2e, 0x62, 0x73, 0x73, 0x00, 0x2e, 0x63, 0x6f, 0x6d, + 0x6d, 0x65, 0x6e, 0x74, 0x00, 0x00, 0x00, 0x00, 0x00, 
0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x0b, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x14, 0x81, 0x04, 0x08, 0x14, 0x01, 0x00, 0x00, 0x13, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x13, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x28, 0x81, 0x04, 0x08, 0x28, 0x01, 0x00, 0x00, + 0x20, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x25, 0x00, 0x00, 0x00, + 0x05, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x48, 0x81, 0x04, 0x08, + 0x48, 0x01, 0x00, 0x00, 0x28, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x21, 0x00, 0x00, 0x00, 0xf6, 0xff, 0xff, 0x6f, 0x02, 0x00, 0x00, 0x00, + 0x70, 0x81, 0x04, 0x08, 0x70, 0x01, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, + 0x05, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x2b, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x90, 0x81, 0x04, 0x08, 0x90, 0x01, 0x00, 0x00, + 0x50, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x33, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0xe0, 0x81, 0x04, 0x08, + 0xe0, 0x01, 0x00, 0x00, 0x4a, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x3b, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0x6f, 0x02, 0x00, 0x00, 0x00, + 0x2a, 0x82, 0x04, 0x08, 0x2a, 0x02, 0x00, 0x00, 0x0a, 0x00, 0x00, 0x00, + 0x05, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x02, 0x00, 0x00, 0x00, 0x48, 0x00, 0x00, 0x00, 0xfe, 
0xff, 0xff, 0x6f, + 0x02, 0x00, 0x00, 0x00, 0x34, 0x82, 0x04, 0x08, 0x34, 0x02, 0x00, 0x00, + 0x20, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x57, 0x00, 0x00, 0x00, + 0x09, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x54, 0x82, 0x04, 0x08, + 0x54, 0x02, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, + 0x60, 0x00, 0x00, 0x00, 0x09, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, + 0x5c, 0x82, 0x04, 0x08, 0x5c, 0x02, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, + 0x05, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, 0x69, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x74, 0x82, 0x04, 0x08, 0x74, 0x02, 0x00, 0x00, + 0x30, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x64, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, 0xa4, 0x82, 0x04, 0x08, + 0xa4, 0x02, 0x00, 0x00, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x6f, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, + 0xf0, 0x82, 0x04, 0x08, 0xf0, 0x02, 0x00, 0x00, 0x4c, 0x01, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x75, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x06, 0x00, 0x00, 0x00, 0x3c, 0x84, 0x04, 0x08, 0x3c, 0x04, 0x00, 0x00, + 0x1c, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x7b, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x58, 0x84, 0x04, 0x08, + 0x58, 0x04, 0x00, 0x00, 0x11, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x83, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x02, 
0x00, 0x00, 0x00, + 0x6c, 0x84, 0x04, 0x08, 0x6c, 0x04, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x8d, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, 0x70, 0x94, 0x04, 0x08, 0x70, 0x04, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x94, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x78, 0x94, 0x04, 0x08, + 0x78, 0x04, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x9b, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, + 0x80, 0x94, 0x04, 0x08, 0x80, 0x04, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0xa0, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, 0x84, 0x94, 0x04, 0x08, 0x84, 0x04, 0x00, 0x00, + 0xd0, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0xa9, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x54, 0x95, 0x04, 0x08, + 0x54, 0x05, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0xae, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, + 0x58, 0x95, 0x04, 0x08, 0x58, 0x05, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0xb7, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, 0x70, 0x95, 0x04, 0x08, 0x70, 0x05, 0x00, 0x00, + 0x0c, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xbd, 0x00, 0x00, 0x00, + 0x08, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x7c, 
0x95, 0x04, 0x08, + 0x7c, 0x05, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0xc2, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x7c, 0x05, 0x00, 0x00, 0x26, 0x01, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xa2, 0x06, 0x00, 0x00, + 0xcb, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, +} diff --git a/fatbin/create.go b/fatbin/create.go new file mode 100644 index 00000000..4cffebbb --- /dev/null +++ b/fatbin/create.go @@ -0,0 +1,75 @@ +// Copyright 2019 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. + +package fatbin + +import ( + "archive/zip" + "io" + "os" +) + +// Writer is used to append fatbin images to an existing binary. +type Writer struct { + z *zip.Writer + w io.Writer + off int64 + goos, goarch string +} + +// NewFileWriter returns a writer that can be used to append fatbin +// images to the binary represented by the provided file. +// NewFileWriter removes any existing fatbin images that may be +// attached to the binary. It relies on content sniffing (see Sniff) +// to determine its offset. +func NewFileWriter(file *os.File) (*Writer, error) { + info, err := file.Stat() + if err != nil { + return nil, err + } + goos, goarch, offset, err := Sniff(file, info.Size()) + if err != nil { + return nil, err + } + if err := file.Truncate(offset); err != nil { + return nil, err + } + _, err = file.Seek(0, io.SeekEnd) + if err != nil { + return nil, err + } + return NewWriter(file, offset, goos, goarch), nil +} + +// NewWriter returns a writer that may be used to append fatbin +// images to the writer w. 
The writer should be positioned at the end +// of the base binary image. +func NewWriter(w io.Writer, offset int64, goos, goarch string) *Writer { + return &Writer{z: zip.NewWriter(w), w: w, off: offset, goos: goos, goarch: goarch} +} + +// Create returns a Writer into which the image for the provided goos +// and goarch should be written. The image's contents must be written +// before the next call to Create or Close. +func (w *Writer) Create(goos, goarch string) (io.Writer, error) { + return w.z.Create(goos + "/" + goarch) +} + +// Flush flushes the unwritten data to the underlying file. +func (w *Writer) Flush() error { + return w.z.Flush() +} + +// Close should be called after all images have been written. No more +// images can be written after a call to Close. +func (w *Writer) Close() error { + if err := w.z.SetComment(w.goos + "/" + w.goarch); err != nil { + return err + } + if err := w.z.Close(); err != nil { + return err + } + _, err := writeFooter(w.w, w.off) + return err +} diff --git a/fatbin/fatbin.go b/fatbin/fatbin.go new file mode 100644 index 00000000..84951cbc --- /dev/null +++ b/fatbin/fatbin.go @@ -0,0 +1,293 @@ +// Copyright 2019 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. + +// Package fatbin implements a simple fat binary format, and provides +// facilities for creating fat binaries and accessing its variants. +// +// A fatbin binary is a base binary with a zip archive appended, +// containing copies of the same binary targeted to different +// GOOS/GOARCH combinations. The zip archive contains one entry for +// each supported architecture and operating system combination. +// At the end of a fatbin image is a footer, storing the offset of the +// zip archive as well as a magic constant used to identify fatbin +// images: +// +// [8]offset[4]magic[8]checksum +// +// The checksum is a 64-bit xxhash checksum of the offset and +// magic fields. 
The magic value is 0x5758ba2c. +package fatbin + +import ( + "archive/zip" + "debug/elf" + "debug/macho" + "errors" + "fmt" + "io" + "io/ioutil" + "os" + "runtime" + "strings" + "sync" + + "github.com/grailbio/base/log" +) + +var ( + selfOnce sync.Once + self *Reader + selfErr error +) + +var ( + // ErrNoSuchImage is returned when the fatbin does not contain an + // image for the requested GOOS/GOARCH combination. + ErrNoSuchImage = errors.New("image does not exist") + // ErrCorruptedImage is returned when the fatbin image has been + // corrupted. + ErrCorruptedImage = errors.New("corrupted fatbin image") +) + +// Info provides information for an embedded binary. +type Info struct { + Goos, Goarch string + Size int64 +} + +func (info Info) String() string { + return fmt.Sprintf("%s/%s: %d", info.Goos, info.Goarch, info.Size) +} + +// Reader reads images from a fatbin. +type Reader struct { + self io.ReaderAt + goos, goarch string + offset int64 + + z *zip.Reader +} + +// Self reads the currently executing binary image as a fatbin and +// returns a reader to it. +func Self() (*Reader, error) { + selfOnce.Do(func() { + filename, err := os.Executable() + if err != nil { + selfErr = err + return + } + f, err := os.Open(filename) + if err != nil { + selfErr = err + return + } + info, err := f.Stat() + if err != nil { + selfErr = err + return + } + _, _, offset, err := Sniff(f, info.Size()) + if err != nil { + selfErr = err + return + } + self, selfErr = NewReader(f, offset, info.Size(), runtime.GOOS, runtime.GOARCH) + }) + return self, selfErr +} + +// OpenFile parses the provided ReaderAt with the provided size. The +// file's contents is parsed to determine the offset of the fatbin's +// archive. OpenFile returns an error if the file is not a fatbin. 
+func OpenFile(r io.ReaderAt, size int64) (*Reader, error) { + goos, goarch, offset, err := Sniff(r, size) + if err != nil { + return nil, err + } + return NewReader(r, offset, size, goos, goarch) +} + +// NewReader returns a new fatbin reader from the provided reader. +// The offset should be the offset of the fatbin archive; size is the +// total file size. The provided goos and goarch are that of the base +// binary. +func NewReader(r io.ReaderAt, offset, size int64, goos, goarch string) (*Reader, error) { + rd := &Reader{ + self: io.NewSectionReader(r, 0, offset), + goos: goos, + goarch: goarch, + offset: offset, + } + if offset == size { + return rd, nil + } + var err error + rd.z, err = zip.NewReader(io.NewSectionReader(r, offset, size-offset), size-offset) + if err != nil { + return nil, err + } + return rd, nil +} + +// GOOS returns the base binary GOOS. +func (r *Reader) GOOS() string { return r.goos } + +// GOARCH returns the base binary GOARCH. +func (r *Reader) GOARCH() string { return r.goarch } + +// List returns information about embedded binary images. +func (r *Reader) List() []Info { + infos := make([]Info, len(r.z.File)) + for i, f := range r.z.File { + elems := strings.SplitN(f.Name, "/", 2) + if len(elems) != 2 { + log.Error.Printf("invalid fatbin: found name %s", f.Name) + continue + } + infos[i] = Info{ + Goos: elems[0], + Goarch: elems[1], + Size: int64(f.UncompressedSize64), + } + } + return infos +} + +// Open returns a ReadCloser from which the binary with the provided +// goos and goarch can be read. Open returns ErrNoSuchImage if the +// fatbin does not contain an image for the requested goos and +// goarch. 
+func (r *Reader) Open(goos, goarch string) (io.ReadCloser, error) { + if goos == r.goos && goarch == r.goarch { + sr := io.NewSectionReader(r.self, 0, 1<<63-1) + return ioutil.NopCloser(sr), nil + } + + if r.z == nil { + return nil, ErrNoSuchImage + } + + look := goos + "/" + goarch + for _, f := range r.z.File { + if f.Name == look { + return f.Open() + } + } + return nil, ErrNoSuchImage +} + +// Stat returns the information for the image identified by the +// provided GOOS and GOARCH. It returns a boolean indicating +// whether the requested image was found. +func (r *Reader) Stat(goos, goarch string) (info Info, ok bool) { + info.Goos = goos + info.Goarch = goarch + if goos == r.goos && goarch == r.goarch { + info.Size = r.offset + ok = true + return + } + look := goos + "/" + goarch + for _, f := range r.z.File { + if f.Name == look { + info.Size = int64(f.UncompressedSize64) + ok = true + return + } + } + return +} + +func sectionEndAligned(s *elf.Section) int64 { + return int64(((s.Offset + s.FileSize) + (s.Addralign - 1)) & -s.Addralign) +} + +// Sniff sniffs a binary's goos, goarch, and fatbin offset. Sniff returns errors +// returned by the provided reader, or ErrCorruptedImage if the binary is identified +// as a fatbin image with a checksum mismatch. 
+func Sniff(r io.ReaderAt, size int64) (goos, goarch string, offset int64, err error) {
+	// Try each known executable-format probe in turn; the first one
+	// that recognizes the format supplies goos/goarch.
+	for _, s := range sniffers {
+		var ok bool
+		goos, goarch, ok = s(r)
+		if ok {
+			break
+		}
+	}
+	if goos == "" {
+		goos = "unknown"
+	}
+	if goarch == "" {
+		goarch = "unknown"
+	}
+	offset, err = readFooter(r, size)
+	if err == errNoFooter {
+		// Not a fatbin: report the archive offset as the file size,
+		// i.e. the whole file is the base image and there is no
+		// appended archive.
+		err = nil
+		offset = size
+	}
+	return
+}
+
+// sniffer probes the executable read from r, reporting its
+// goos/goarch and whether the format was recognized.
+type sniffer func(r io.ReaderAt) (goos, goarch string, ok bool)
+
+// sniffers lists the format probes tried, in order, by Sniff.
+var sniffers = []sniffer{sniffElf, sniffMacho}
+
+// sniffElf reports the goos/goarch of an ELF image, or ok=false if r
+// does not contain a parseable ELF file.
+func sniffElf(r io.ReaderAt) (goos, goarch string, ok bool) {
+	file, err := elf.NewFile(r)
+	if err != nil {
+		return
+	}
+	ok = true
+	switch file.OSABI {
+	default:
+		goos = "unknown"
+	// ELFOSABI_NONE is also mapped to "linux" — presumably because
+	// Linux toolchains commonly emit NONE; confirm if other
+	// NONE-emitting platforms need to be distinguished.
+	case elf.ELFOSABI_NONE, elf.ELFOSABI_LINUX:
+		goos = "linux"
+	case elf.ELFOSABI_NETBSD:
+		goos = "netbsd"
+	case elf.ELFOSABI_OPENBSD:
+		goos = "openbsd"
+	}
+	switch file.Machine {
+	default:
+		goarch = "unknown"
+	case elf.EM_386:
+		goarch = "386"
+	case elf.EM_X86_64:
+		goarch = "amd64"
+	case elf.EM_ARM:
+		goarch = "arm"
+	case elf.EM_AARCH64:
+		goarch = "arm64"
+	}
+	return
+}
+
+// sniffMacho reports the goos/goarch of a Mach-O image, or ok=false
+// if r does not contain a parseable Mach-O file.
+func sniffMacho(r io.ReaderAt) (goos, goarch string, ok bool) {
+	file, err := macho.NewFile(r)
+	if err != nil {
+		return
+	}
+	ok = true
+	// We assume mach-o is only used in Darwin. This is not exposed
+	// by the mach-o files.
+	goos = "darwin"
+	switch file.Cpu {
+	default:
+		goarch = "unknown"
+	case macho.Cpu386:
+		goarch = "386"
+	case macho.CpuAmd64:
+		goarch = "amd64"
+	case macho.CpuArm:
+		goarch = "arm"
+	case macho.CpuArm64:
+		goarch = "arm64"
+	case macho.CpuPpc:
+		goarch = "ppc"
+	case macho.CpuPpc64:
+		goarch = "ppc64"
+	}
+	return
+}
diff --git a/fatbin/fatbin_test.go b/fatbin/fatbin_test.go
new file mode 100644
index 00000000..b47958a8
--- /dev/null
+++ b/fatbin/fatbin_test.go
@@ -0,0 +1,178 @@
+// Copyright 2019 GRAIL, Inc. All rights reserved.
+// Use of this source code is governed by the Apache 2.0
+// license that can be found in the LICENSE file.
+
+package fatbin
+
+import (
+	"bytes"
+	"io"
+	"io/ioutil"
+	"os"
+	"runtime"
+	"testing"
+)
+
+// TestFatbin checks that the running test binary can be opened as a
+// fatbin and that the image served for the host GOOS/GOARCH is
+// byte-identical to the binary itself.
+func TestFatbin(t *testing.T) {
+	filename, err := os.Executable()
+	if err != nil {
+		t.Fatal(err)
+	}
+	body, err := ioutil.ReadFile(filename)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	self, err := Self()
+	if err != nil {
+		t.Fatal(err)
+	}
+	r, err := self.Open(runtime.GOOS, runtime.GOARCH)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer r.Close()
+	embedded, err := ioutil.ReadAll(r)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if !bytes.Equal(body, embedded) {
+		t.Error("content mismatch")
+	}
+
+	info, ok := self.Stat(runtime.GOOS, runtime.GOARCH)
+	if !ok {
+		t.Fatal(runtime.GOOS, "/", runtime.GOARCH, ": not found")
+	}
+	if got, want := info.Size, int64(len(embedded)); got != want {
+		t.Errorf("got %v, want %v", got, want)
+	}
+}
+
+// TestFatbinNonExist checks that requesting an unknown GOOS/GOARCH
+// pair yields ErrNoSuchImage.
+func TestFatbinNonExist(t *testing.T) {
+	self, err := Self()
+	if err != nil {
+		t.Fatal(err)
+	}
+	_, err = self.Open("nonexistent", "nope")
+	if got, want := err, ErrNoSuchImage; got != want {
+		t.Errorf("got %v, want %v", got, want)
+	}
+}
+
+// TestSniff checks that Sniff identifies the running test binary's
+// GOOS/GOARCH and, absent an appended archive, reports the archive
+// offset as the file size.
+func TestSniff(t *testing.T) {
+	filename, err := os.Executable()
+	if err != nil {
+		t.Fatal(err)
+	}
+	f, err := os.Open(filename)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer f.Close()
+	info, err := f.Stat()
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	goos, goarch, size, err := Sniff(f, info.Size())
+	if err != nil {
+		t.Fatal(err)
+	}
+	if got, want := goarch, runtime.GOARCH; got != want {
+		t.Errorf("got %v, want %v", got, want)
+	}
+	if got, want := goos, runtime.GOOS; got != want {
+		t.Errorf("got %v, want %v", got, want)
+	}
+	if got, want := size, info.Size(); got != want {
+		t.Errorf("got %v, want %v", got, want)
+	}
+}
+
+// TestLinuxElf sniffs the checked-in 32-bit Linux ELF fixture.
+func TestLinuxElf(t *testing.T) {
+	r := bytes.NewReader(svelteLinuxElfBinary)
+	goos, goarch, size, err := Sniff(r, int64(r.Len()))
+	if err != nil {
+		t.Fatal(err)
+	}
+	if got, want := goos, "linux"; got != want {
+		t.Errorf("got %v, want %v", got, want)
+	}
+	if got, want := goarch, "386"; got != want {
+		t.Errorf("got %v, want %v", got, want)
+	}
+	// Convert want to int64 up front so got and want share a type,
+	// matching the got/want convention used by every other check
+	// (the original compared `got != int64(want)` with mixed types).
+	if got, want := size, int64(len(svelteLinuxElfBinary)); got != want {
+		t.Errorf("got %v, want %v", got, want)
+	}
+}
+
+// TestCreate appends two images to the ELF fixture with a Writer and
+// verifies that all three images can be read back and stat'ed.
+func TestCreate(t *testing.T) {
+	f, err := ioutil.TempFile("", "")
+	must(t, err)
+	// Clean up the temp file; the original test leaked it. Errors are
+	// ignored: cleanup is best-effort.
+	defer func() {
+		_ = f.Close()
+		_ = os.Remove(f.Name())
+	}()
+	_, err = f.Write(svelteLinuxElfBinary)
+	must(t, err)
+	w, err := NewFileWriter(f)
+	must(t, err)
+	dw, err := w.Create("darwin", "amd64")
+	must(t, err)
+	_, err = dw.Write([]byte("darwin/amd64"))
+	must(t, err)
+	dw, err = w.Create("darwin", "386")
+	must(t, err)
+	_, err = dw.Write([]byte("darwin/386"))
+	must(t, err)
+	must(t, w.Close())
+	info, err := f.Stat()
+	must(t, err)
+	r, err := OpenFile(f, info.Size())
+	must(t, err)
+
+	cases := []struct {
+		goos, goarch string
+		body         []byte
+	}{
+		{"linux", "386", svelteLinuxElfBinary},
+		{"darwin", "amd64", []byte("darwin/amd64")},
+		{"darwin", "386", []byte("darwin/386")},
+	}
+	for _, c := range cases {
+		rc, err := r.Open(c.goos, c.goarch)
+		if err != nil {
+			t.Fatal(err)
+		}
+		mustBytes(t, rc, c.body)
+		must(t, rc.Close())
+		info, ok := r.Stat(c.goos, c.goarch)
+		if !ok {
+			t.Error(c.goos, "/", c.goarch, ": not found")
+			continue
+		}
+		if got, want := info.Size, int64(len(c.body)); got != want {
+			t.Errorf("%s/%s: got %v, want %v", c.goos, c.goarch, got, want)
+		}
+	}
+
+	_, err = r.Open("test", "nope")
+	if got, want := err, ErrNoSuchImage; got != want {
+		t.Errorf("got %v, want %v", got, want)
+	}
+}
+
+// must fails the test immediately on a non-nil error.
+func must(t *testing.T, err error) {
+	t.Helper()
+	if err != nil {
+		t.Fatal(err)
+	}
+}
+
+// mustBytes asserts that reading r to EOF yields exactly want.
+func mustBytes(t *testing.T, r io.Reader, want []byte) {
+	t.Helper()
+	got, err := ioutil.ReadAll(r)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if !bytes.Equal(got, want) {
+		t.Errorf("got %s, want %s", got, want)
+	}
+}
diff --git a/fatbin/footer.go b/fatbin/footer.go
new file mode 100644
index 00000000..f6a86321
--- /dev/null
+++ b/fatbin/footer.go
@@ -0,0 +1,51 @@
+// Copyright 2019 GRAIL, Inc. All rights reserved.
+// Use of this source code is governed by the Apache 2.0
+// license that can be found in the LICENSE file.
+
+package fatbin
+
+import (
+	"encoding/binary"
+	"errors"
+	"io"
+
+	"github.com/cespare/xxhash"
+)
+
+// The footer occupies headersz bytes at the very end of a fatbin: an
+// 8-byte little-endian archive offset, the 4-byte magic constant,
+// and an 8-byte xxhash checksum of the preceding 12 bytes.
+// NOTE(review): headersz names the *footer* size; consider renaming
+// to footersz for consistency with writeFooter/readFooter.
+const (
+	magic    uint32 = 0x5758ba2c
+	headersz        = 20
+)
+
+var (
+	errNoFooter = errors.New("binary contains no footer")
+
+	bin = binary.LittleEndian
+)
+
+// writeFooter writes a footer encoding the given archive offset to
+// w, returning the number of bytes written and any write error.
+func writeFooter(w io.Writer, offset int64) (int, error) {
+	var p [headersz]byte
+	bin.PutUint64(p[:8], uint64(offset))
+	bin.PutUint32(p[8:12], magic)
+	// Checksum covers the offset and magic fields.
+	bin.PutUint64(p[12:20], xxhash.Sum64(p[:12]))
+	return w.Write(p[:])
+}
+
+// readFooter decodes the footer at the end of the size-byte binary
+// read from r. It returns errNoFooter when the binary is too small
+// to hold a footer or the magic constant is absent, and
+// ErrCorruptedImage when the magic matches but the checksum does
+// not. A corrupted magic is thus indistinguishable from "no footer".
+func readFooter(r io.ReaderAt, size int64) (offset int64, err error) {
+	if size < headersz {
+		return 0, errNoFooter
+	}
+	var p [headersz]byte
+	_, err = r.ReadAt(p[:], size-headersz)
+	if err != nil {
+		return 0, err
+	}
+	if bin.Uint32(p[8:12]) != magic {
+		return 0, errNoFooter
+	}
+	offset = int64(bin.Uint64(p[:8]))
+	if xxhash.Sum64(p[:12]) != bin.Uint64(p[12:20]) {
+		return 0, ErrCorruptedImage
+	}
+	return
+}
diff --git a/fatbin/footer_test.go b/fatbin/footer_test.go
new file mode 100644
index 00000000..bb77aa68
--- /dev/null
+++ b/fatbin/footer_test.go
@@ -0,0 +1,79 @@
+// Copyright 2019 GRAIL, Inc. All rights reserved.
+// Use of this source code is governed by the Apache 2.0
+// license that can be found in the LICENSE file.
+ +package fatbin + +import ( + "bytes" + "io" + "testing" +) + +func TestReadWriteFooter(t *testing.T) { + for _, sz := range []int64{0, 12, 1e12, 1e13 + 4} { + var b bytes.Buffer + if _, err := writeFooter(&b, sz); err != nil { + t.Error(err) + continue + } + off, err := readFooter(bytes.NewReader(b.Bytes()), int64(b.Len())) + if err != nil { + t.Error(err) + continue + } + if got, want := off, sz; got != want { + t.Errorf("got %v, want %v", got, want) + } + + padded := paddedReaderAt{bytes.NewReader(b.Bytes()), int64(sz) * 100} + off, err = readFooter(padded, int64(sz)*100+int64(b.Len())) + if err != nil { + t.Error(err) + continue + } + if got, want := off, sz; got != want { + t.Errorf("got %v, want %v", got, want) + } + } +} + +func TestCorruptedFooter(t *testing.T) { + var b bytes.Buffer + if _, err := writeFooter(&b, 1234); err != nil { + t.Fatal(err) + } + n := b.Len() + for i := 0; i < n; i++ { + if i >= n-12 && i < n-8 { + continue //skip magic + } + p := make([]byte, b.Len()) + copy(p, b.Bytes()) + p[i]++ + _, err := readFooter(bytes.NewReader(p), int64(len(p))) + if got, want := err, ErrCorruptedImage; got != want { + t.Errorf("got %v, want %v", got, want) + } + } +} + +type paddedReaderAt struct { + io.ReaderAt + N int64 +} + +func (r paddedReaderAt) ReadAt(p []byte, off int64) (n int, err error) { + off -= r.N + for i := range p { + p[i] = 0 + } + switch { + case off < -int64(len(p)): + return len(p), nil + case off < 0: + p = p[-off:] + off = 0 + } + return r.ReaderAt.ReadAt(p, off) +} diff --git a/file/addfs/per_node.go b/file/addfs/per_node.go new file mode 100644 index 00000000..aed49b99 --- /dev/null +++ b/file/addfs/per_node.go @@ -0,0 +1,206 @@ +package addfs + +import ( + "context" + "fmt" + "io/fs" + "time" + + "github.com/grailbio/base/file/fsnode" + "github.com/grailbio/base/ioctx/fsctx" + "github.com/grailbio/base/log" +) + +type ( + // PerNodeFunc computes nodes to add to a directory tree, for example to present alternate views + // of 
raw data, expand archive files, etc. It operates on a single node at a time. If it returns + // any "addition" nodes, ApplyPerNodeFuncs will place them under a sibling directory called + // "...". For example, suppose we have an input directory: + // parent/ + // └─dir1/ + // ├─fileA + // ├─fileB + // └─dir2/ + // and we call ApplyPerNodeFuncs(parent/, ourFns). The resulting directory tree will be + // parent/ + // ├─.../ + // │ └─dir1/ + // │ └─[ nodes returned by PerNodeFunc.Apply(_, dir1/) for all ourFns ] + // └─dir1/ + // ├─.../ + // │ ├─fileA/ + // │ │ └─[ nodes returned by PerNodeFunc.Apply(_, fileA) for all ourFns ] + // │ ├─fileB/ + // │ │ └─[ nodes returned by PerNodeFunc.Apply(_, fileB) for all ourFns ] + // │ └─dir2/ + // │ └─[ nodes returned by PerNodeFunc.Apply(_, dir2/) for all ourFns ] + // ├─fileA + // ├─fileB + // └─dir2/ + // └─.../ + // Users browsing this resulting tree can work with just the original files and ourFns won't + // be invoked. However, they can also navigate into any of the .../s if interested and then + // use the additional views generated by ourFns. If they're interested in our_view for + // /path/to/a/file, they just need to prepend .../, like /path/to/a/.../file/our_view. + // (Perhaps it'd be more intuitive to "append", like /path/to/a/file/our_view, but then the + // file name would conflict with the view-containing directory.) + // + // Funcs that need to list the children of a fsnode.Parent should be careful: they may want to + // set an upper limit on number of entries to read, and otherwise default to empty, to avoid + // performance problems (resulting in bad UX) for very large directories. + // + // Funcs that simply look at filenames and declare derived outputs may want to place their + // children directly under /.../file/ for convenient access. 
However, Funcs that are expensive, + // for example reading some file contents, etc., may want to separate themselves under their own + // subdirectory, like .../file/func_name/. This lets users browsing the tree "opt-in" to seeing + // the results of the expensive computation by navigating to .../file/func_name/. + // + // If the input tree has any "..." that conflict with the added ones, the added ones override. + // The originals will simply not be accessible. + PerNodeFunc interface { + Apply(context.Context, fsnode.T) (adds []fsnode.T, _ error) + } + perNodeFunc func(context.Context, fsnode.T) (adds []fsnode.T, _ error) +) + +func NewPerNodeFunc(fn func(context.Context, fsnode.T) ([]fsnode.T, error)) PerNodeFunc { + return perNodeFunc(fn) +} +func (f perNodeFunc) Apply(ctx context.Context, n fsnode.T) ([]fsnode.T, error) { return f(ctx, n) } + +const addsDirName = "..." + +// perNodeImpl extends the original Parent with the .../ child. +type perNodeImpl struct { + original fsnode.Parent + fns []PerNodeFunc + adds fsnode.Parent +} + +var ( + _ fsnode.Parent = (*perNodeImpl)(nil) + _ fsnode.Cacheable = (*perNodeImpl)(nil) +) + +// ApplyPerNodeFuncs returns a new Parent that contains original's nodes plus any added by fns. +// See PerNodeFunc's for more documentation on how this works. +// Later fns's added nodes will overwrite earlier ones, if any names conflict. +func ApplyPerNodeFuncs(original fsnode.Parent, fns ...PerNodeFunc) fsnode.Parent { + fns = append([]PerNodeFunc{}, fns...) + adds := perNodeAdds{ + FileInfo: fsnode.CopyFileInfo(original.Info()). + WithName(addsDirName). + // ... directory is not writable. 
+ WithModePerm(original.Info().Mode().Perm() & 0555), + original: original, + fns: fns, + } + return &perNodeImpl{original, fns, &adds} +} + +func (n *perNodeImpl) FSNodeT() {} +func (n *perNodeImpl) Info() fs.FileInfo { return n.original.Info() } +func (n *perNodeImpl) CacheableFor() time.Duration { return fsnode.CacheableFor(n.original) } +func (n *perNodeImpl) Child(ctx context.Context, name string) (fsnode.T, error) { + if name == addsDirName { + return n.adds, nil + } + child, err := n.original.Child(ctx, name) + if err != nil { + return nil, err + } + return perNodeRecurse(child, n.fns), nil +} +func (n *perNodeImpl) Children() fsnode.Iterator { + return fsnode.NewConcatIterator( + // TODO: Consider omitting .../ if the directory has no other children. + fsnode.NewIterator(n.adds), + // TODO: Filter out any conflicting ... to be consistent with Child. + fsnode.MapIterator(n.original.Children(), func(_ context.Context, child fsnode.T) (fsnode.T, error) { + return perNodeRecurse(child, n.fns), nil + }), + ) +} +func (n *perNodeImpl) AddChildLeaf(ctx context.Context, name string, flags uint32) (fsnode.Leaf, fsctx.File, error) { + return n.original.AddChildLeaf(ctx, name, flags) +} +func (n *perNodeImpl) AddChildParent(ctx context.Context, name string) (fsnode.Parent, error) { + p, err := n.original.AddChildParent(ctx, name) + if err != nil { + return nil, err + } + return ApplyPerNodeFuncs(p, n.fns...), nil +} +func (n *perNodeImpl) RemoveChild(ctx context.Context, name string) error { + return n.original.RemoveChild(ctx, name) +} + +// perNodeAdds is the .../ Parent. It has a child (directory) for each original child (both +// directories and files). The children contain the PerNodeFunc.Apply outputs. 
+type perNodeAdds struct { + fsnode.ParentReadOnly + fsnode.FileInfo + original fsnode.Parent + fns []PerNodeFunc +} + +var ( + _ fsnode.Parent = (*perNodeAdds)(nil) + _ fsnode.Cacheable = (*perNodeAdds)(nil) +) + +func (n *perNodeAdds) Child(ctx context.Context, name string) (fsnode.T, error) { + child, err := n.original.Child(ctx, name) + if err != nil { + return nil, err + } + return n.newAddsForChild(child), nil +} +func (n *perNodeAdds) Children() fsnode.Iterator { + // TODO: Filter out any conflicting ... to be consistent with Child. + return fsnode.MapIterator(n.original.Children(), func(_ context.Context, child fsnode.T) (fsnode.T, error) { + return n.newAddsForChild(child), nil + }) +} +func (n *perNodeAdds) FSNodeT() {} + +func (n *perNodeAdds) newAddsForChild(original fsnode.T) fsnode.Parent { + originalInfo := original.Info() + return fsnode.NewParent( + fsnode.NewDirInfo(originalInfo.Name()). + WithModTime(originalInfo.ModTime()). + // Derived directory must be executable to be usable, even if original file wasn't. + WithModePerm(originalInfo.Mode().Perm()|0111). + WithCacheableFor(fsnode.CacheableFor(original)), + fsnode.FuncChildren(func(ctx context.Context) ([]fsnode.T, error) { + adds := make(map[string]fsnode.T) + for _, fn := range n.fns { + fnAdds, err := fn.Apply(ctx, original) + if err != nil { + return nil, fmt.Errorf("addfs: error running func %v: %w", fn, err) + } + for _, add := range fnAdds { + name := add.Info().Name() + if _, exists := adds[name]; exists { + // TODO: Consider returning an error here. Or merging the added trees? 
+ log.Error.Printf("addfs %s: conflict for added name: %s", originalInfo.Name(), name) + } + adds[name] = add + } + } + wrapped := make([]fsnode.T, 0, len(adds)) + for _, add := range adds { + wrapped = append(wrapped, perNodeRecurse(add, n.fns)) + } + return wrapped, nil + }), + ) +} + +func perNodeRecurse(node fsnode.T, fns []PerNodeFunc) fsnode.T { + parent, ok := node.(fsnode.Parent) + if !ok { + return node + } + return ApplyPerNodeFuncs(parent, fns...) +} diff --git a/file/addfs/per_node_test.go b/file/addfs/per_node_test.go new file mode 100644 index 00000000..355df8e6 --- /dev/null +++ b/file/addfs/per_node_test.go @@ -0,0 +1,119 @@ +package addfs + +import ( + "context" + "fmt" + "sort" + "strings" + "testing" + + "github.com/grailbio/base/file/fsnode" + . "github.com/grailbio/base/file/fsnode/fsnodetesting" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestPerNodeFuncs(t *testing.T) { + ctx := context.Background() + root := func() Parent { + return Parent{ + "dir0": Parent{}, + "dir1": Parent{ + "dir10": Parent{ + "a": []byte("content dir10/a"), + "b": []byte("content dir10/b"), + }, + "a": []byte("content dir1/a"), + "b": []byte("content dir1/b"), + }, + } + } + t.Run("basic", func(t *testing.T) { + root := root() + n := MakeT(t, "", root).(fsnode.Parent) + n = ApplyPerNodeFuncs(n, + NewPerNodeFunc( + func(ctx context.Context, node fsnode.T) ([]fsnode.T, error) { + switch n := node.(type) { + case fsnode.Parent: + iter := n.Children() + defer func() { assert.NoError(t, iter.Close(ctx)) }() + children, err := fsnode.IterateAll(ctx, iter) + assert.NoError(t, err) + var names []string + for _, child := range children { + names = append(names, child.Info().Name()) + } + sort.Strings(names) + return []fsnode.T{ + fsnode.ConstLeaf(fsnode.NewRegInfo("children names"), []byte(strings.Join(names, ","))), + }, nil + case fsnode.Leaf: + return []fsnode.T{ + fsnode.ConstLeaf(fsnode.NewRegInfo("copy"), nil), // Will be 
overwritten. + }, nil + } + require.Failf(t, "invalid node type", "node: %T", node) + panic("unreachable") + }, + ), + NewPerNodeFunc( + func(ctx context.Context, node fsnode.T) ([]fsnode.T, error) { + switch n := node.(type) { + case fsnode.Parent: + return nil, nil + case fsnode.Leaf: + return []fsnode.T{ + fsnode.ConstLeaf(fsnode.NewRegInfo("copy"), LeafReadAll(ctx, t, n)), + }, nil + } + require.Failf(t, "invalid node type", "node: %T", node) + panic("unreachable") + }, + ), + ) + got := Walker{}.WalkContents(ctx, t, n) + want := Parent{ + "...": Parent{ + "dir0": Parent{"children names": []byte("")}, + "dir1": Parent{"children names": []byte("a,b,dir10")}, + }, + "dir0": Parent{ + "...": Parent{}, + }, + "dir1": Parent{ + "...": Parent{ + "dir10": Parent{"children names": []byte("a,b")}, + "a": Parent{"copy": []byte("content dir1/a")}, + "b": Parent{"copy": []byte("content dir1/b")}, + }, + "dir10": Parent{ + "...": Parent{ + "a": Parent{"copy": []byte("content dir10/a")}, + "b": Parent{"copy": []byte("content dir10/b")}, + }, + "a": []byte("content dir10/a"), + "b": []byte("content dir10/b"), + }, + "a": []byte("content dir1/a"), + "b": []byte("content dir1/b"), + }, + } + assert.Equal(t, want, got) + }) + t.Run("lazy", func(t *testing.T) { + root := root() + n := MakeT(t, "", root).(fsnode.Parent) + n = ApplyPerNodeFuncs(n, NewPerNodeFunc( + func(_ context.Context, node fsnode.T) ([]fsnode.T, error) { + return nil, fmt.Errorf("func was called: %q", node.Info().Name()) + }, + )) + got := Walker{ + IgnoredNames: map[string]struct{}{ + addsDirName: struct{}{}, + }, + }.WalkContents(ctx, t, n) + assert.Equal(t, root, got) + }) +} diff --git a/file/addfs/per_subtree.go b/file/addfs/per_subtree.go new file mode 100644 index 00000000..cc904222 --- /dev/null +++ b/file/addfs/per_subtree.go @@ -0,0 +1,7 @@ +package addfs + +// TODO: Implement PerSubtreeFunc. +// A PerNodeFunc is applied independently to each node in an entire directory tree. 
It may be +// useful to define funcs that are contextual. For example if an fsnode.Parent called base/ has a +// child called .git, we may want to define git-repository-aware views for each descendent node, +// like base/file/addfs/.../per_subtree.go/git/log.txt containing history. diff --git a/file/addfs/unzipfs/unzipfs.go b/file/addfs/unzipfs/unzipfs.go new file mode 100644 index 00000000..a08fe8f6 --- /dev/null +++ b/file/addfs/unzipfs/unzipfs.go @@ -0,0 +1,251 @@ +package unzipfs + +import ( + "archive/zip" + "compress/flate" + "context" + stderrors "errors" + "fmt" + "io" + "io/fs" + "path" + "runtime" + + "github.com/grailbio/base/errors" + "github.com/grailbio/base/file/addfs" + "github.com/grailbio/base/file/fsnode" + "github.com/grailbio/base/grail/biofs/biofseventlog" + "github.com/grailbio/base/ioctx" + "github.com/grailbio/base/ioctx/fsctx" + "github.com/grailbio/base/log" + "github.com/grailbio/base/morebufio" + "github.com/grailbio/base/sync/loadingcache" +) + +type unzipFunc struct{} + +// Func is an addfs.PerNodeFunc that presents zip file contents as a subdirectory tree. +// Users can access contents in .../myfile.zip/unzip/, for example. +// +// The file need not have extension .zip. Func.Apply reads the file header and if it's not +// a supported zip file the unzip/ directory is omitted. 
+var Func unzipFunc + +var _ addfs.PerNodeFunc = Func + +func (unzipFunc) Apply(ctx context.Context, node fsnode.T) ([]fsnode.T, error) { + zipLeaf, ok := node.(fsnode.Leaf) + if !ok { + return nil, nil + } + info := fsnode.NewDirInfo("unzip").WithCacheableFor(fsnode.CacheableFor(zipLeaf)) + parent, err := parentFromLeaf(ctx, info, zipLeaf) + if err != nil { + return nil, err + } + if parent == nil { + return nil, nil + } + return []fsnode.T{parent}, nil +} + +type readerHandle struct { + *zip.Reader + ioctx.Closer + leaf fsnode.Leaf +} + +func finalizeHandle(h *readerHandle) { + if err := h.Close(context.Background()); err != nil { + log.Error.Printf("unzipfs: error closing handle: %v", err) + } +} + +// parentFromLeaf opens zipLeaf to determine if it's a zip file and returns a Parent if so. +// Returns nil, nil in cases where the file is not supported (like, not a zip file). +// TODO: Consider exposing more public APIs like this fsnode.Leaf -> fsnode.Parent and/or +// *zip.Reader -> fsnode.Parent. +func parentFromLeaf(ctx context.Context, parentInfo fsnode.FileInfo, zipLeaf fsnode.Leaf) (fsnode.Parent, error) { + zipFile, err := fsnode.Open(ctx, zipLeaf) + if err != nil { + return nil, errors.E(err, "opening for unzip") + } + handle := readerHandle{Closer: zipFile, leaf: zipLeaf} + // TODO: More reliable/explicit cleanup. Refcount? + runtime.SetFinalizer(&handle, finalizeHandle) + info, err := zipFile.Stat(ctx) + if err != nil { + return nil, errors.E(err, "stat-ing for unzip") + } + rAt, ok := zipFile.(ioctx.ReaderAt) + if !ok { + log.Info.Printf("zipfs: random access not supported: %s, returning empty dir", zipLeaf.Info().Name()) + // TODO: Some less efficient fallback path? Try seeking? + return nil, nil + } + // Buffer makes header read much faster when underlying file is high latency, like S3. + // Of course there's a tradeoff where very small zip files (with much smaller headers) will + // not be read as lazily, but the speedup is significant for S3. 
+ rAt = morebufio.NewReaderAtSize(rAt, 1024*1024) + handle.Reader, err = zip.NewReader(ioctx.ToStdReaderAt(ctx, rAt), info.Size()) + if err != nil { + if stderrors.Is(err, zip.ErrFormat) || + stderrors.Is(err, zip.ErrAlgorithm) || + stderrors.Is(err, zip.ErrChecksum) { + log.Info.Printf("zipfs: not a valid zip file: %s, returning empty dir", zipLeaf.Info().Name()) + return nil, nil + } + return nil, errors.E(err, "initializing zip reader") + } + return fsnode.NewParent(parentInfo, &handleChildGen{r: &handle, pathPrefix: "."}), nil +} + +type handleChildGen struct { + r *readerHandle + pathPrefix string + children loadingcache.Value +} + +func (g *handleChildGen) GenerateChildren(ctx context.Context) ([]fsnode.T, error) { + biofseventlog.UsedFeature("unzipfs.children") + var children []fsnode.T + err := g.children.GetOrLoad(ctx, &children, func(ctx context.Context, opts *loadingcache.LoadOpts) error { + entries, err := fs.ReadDir(g.r, g.pathPrefix) + if err != nil { + return err + } + children = make([]fsnode.T, len(entries)) + cacheFor := fsnode.CacheableFor(g.r.leaf) + for i, entry := range entries { + stat, err := entry.Info() // Immediate (no additional file read) as of go1.17. 
+ if err != nil { + return errors.E(err, fmt.Sprintf("stat: %s", entry.Name())) + } + childInfo := fsnode.CopyFileInfo(stat).WithCacheableFor(cacheFor) + fullName := path.Join(g.pathPrefix, entry.Name()) + if entry.IsDir() { + children[i] = fsnode.NewParent(childInfo, &handleChildGen{r: g.r, pathPrefix: fullName}) + } else { + children[i] = zipFileLeaf{g.r, childInfo, fullName} + } + } + opts.CacheFor(cacheFor) + return nil + }) + if err != nil { + return nil, errors.E(err, fmt.Sprintf("listing path: %s", g.pathPrefix)) + } + return children, nil +} + +type zipFileLeaf struct { + r *readerHandle + fsnode.FileInfo + zipName string +} + +var _ fsnode.Leaf = (*zipFileLeaf)(nil) + +func (z zipFileLeaf) FSNodeT() {} + +type zipFileLeafFile struct { + info fsnode.FileInfo + + // semaphore guards all subsequent fields. It's used to serialize operations. + semaphore chan struct{} + // stdRAt translates context-less ReadAt requests into context-ful ones. We serialize Read + // requests (with semaphore) and then set stdRAt.Ctx temporarily to allow cancellation. + stdRAt ioctx.StdReaderAt + // stdRC wraps stdRAt. Its operations don't directly accept a context but are subject to + // cancellation indirectly via the inner stdRAt. + stdRC io.ReadCloser + // fileCloser cleans up. 
+ fileCloser ioctx.Closer +} + +func (z zipFileLeaf) OpenFile(ctx context.Context, flag int) (fsctx.File, error) { + biofseventlog.UsedFeature("unzipfs.open") + var fileEntry *zip.File + for _, f := range z.r.File { + if f.Name == z.zipName { + fileEntry = f + break + } + } + if fileEntry == nil { + return nil, errors.E(errors.NotExist, + fmt.Sprintf("internal inconsistency: entry %q not found in zip metadata", z.zipName)) + } + dataOffset, err := fileEntry.DataOffset() + if err != nil { + return nil, errors.E(err, fmt.Sprintf("could not get data offset for %s", fileEntry.Name)) + } + var makeDecompressor func(r io.Reader) io.ReadCloser + switch fileEntry.Method { + case zip.Store: + // TODO: Consider returning a ReaderAt in this case for user convenience. + makeDecompressor = io.NopCloser + case zip.Deflate: + makeDecompressor = flate.NewReader + default: + return nil, errors.E(errors.NotSupported, + fmt.Sprintf("unsupported method: %d for: %s", fileEntry.Method, fileEntry.Name)) + } + zipFile, err := fsnode.Open(ctx, z.r.leaf) + if err != nil { + return nil, err + } + rAt, ok := zipFile.(ioctx.ReaderAt) + if !ok { + err := errors.E(errors.NotSupported, fmt.Sprintf("not ReaderAt: %v", zipFile)) + errors.CleanUpCtx(ctx, zipFile.Close, &err) + return nil, err + } + f := zipFileLeafFile{ + info: z.FileInfo, + semaphore: make(chan struct{}, 1), + stdRAt: ioctx.StdReaderAt{Ctx: ctx, ReaderAt: rAt}, + fileCloser: zipFile, + } + defer func() { f.stdRAt.Ctx = nil }() + f.stdRC = makeDecompressor( + io.NewSectionReader(&f.stdRAt, dataOffset, int64(fileEntry.CompressedSize64))) + return &f, nil +} + +func (f *zipFileLeafFile) Stat(context.Context) (fs.FileInfo, error) { return f.info, nil } + +func (f *zipFileLeafFile) Read(ctx context.Context, dst []byte) (int, error) { + select { + case f.semaphore <- struct{}{}: + defer func() { _ = <-f.semaphore }() + case <-ctx.Done(): + return 0, ctx.Err() + } + + f.stdRAt.Ctx = ctx + defer func() { f.stdRAt.Ctx = nil }() + return 
f.stdRC.Read(dst) +} + +func (f *zipFileLeafFile) Close(ctx context.Context) error { + select { + case f.semaphore <- struct{}{}: + defer func() { _ = <-f.semaphore }() + case <-ctx.Done(): + return ctx.Err() + } + + f.stdRAt.Ctx = ctx + defer func() { f.stdRAt = ioctx.StdReaderAt{} }() + var err error + if f.stdRC != nil { + errors.CleanUp(f.stdRC.Close, &err) + f.stdRC = nil + } + if f.fileCloser != nil { + errors.CleanUpCtx(ctx, f.fileCloser.Close, &err) + f.fileCloser = nil + } + return err +} diff --git a/file/addfs/unzipfs/unzipfs_test.go b/file/addfs/unzipfs/unzipfs_test.go new file mode 100644 index 00000000..63af50f4 --- /dev/null +++ b/file/addfs/unzipfs/unzipfs_test.go @@ -0,0 +1,210 @@ +package unzipfs + +import ( + "archive/zip" + "bytes" + "context" + "io" + "log" + "strings" + "sync" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/grailbio/base/file/fsnode" + . "github.com/grailbio/base/file/fsnode/fsnodetesting" + "github.com/grailbio/base/ioctx" + "github.com/grailbio/base/ioctx/fsctx" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestParent(t *testing.T) { + ctx := context.Background() + baseTime := time.Unix(1_600_000_000, 0) + + var zipBytes bytes.Buffer + zipW := zip.NewWriter(&zipBytes) + + a0Info := fsnode.NewRegInfo("0.txt").WithModTime(baseTime).WithModePerm(0600) + a0Content := "a0" + addFile(t, zipW, "a/", &a0Info, a0Content, true) + + a00Info := fsnode.NewRegInfo("0.exe").WithModTime(baseTime.Add(time.Hour)).WithModePerm(0755) + a00Content := "a00" + addFile(t, zipW, "a/0/", &a00Info, a00Content, true) + + b0Info := fsnode.NewRegInfo("0.txt").WithModTime(baseTime.Add(2 * time.Hour)).WithModePerm(0644) + b0Content := "b0" + addFile(t, zipW, "b/", &b0Info, b0Content, false) + + topInfo := fsnode.NewRegInfo("0.txt").WithModTime(baseTime.Add(3 * time.Hour)).WithModePerm(0600) + topContent := "top" + addFile(t, zipW, "", &topInfo, topContent, false) + + require.NoError(t, 
zipW.Close()) + + parentInfo := fsnode.NewDirInfo("unzip") + parent, err := parentFromLeaf(ctx, parentInfo, fsnode.ConstLeaf(fsnode.NewRegInfo("zip"), zipBytes.Bytes())) + require.NotNil(t, parent) + require.NoError(t, err) + + walker := Walker{Info: true} + diff := cmp.Diff( + InfoT{parentInfo, Parent{ + "a": InfoT{ + fsnode.NewDirInfo("a"), + Parent{ + a0Info.Name(): InfoT{a0Info, Leaf([]byte(a0Content))}, + "0": InfoT{ + fsnode.NewDirInfo("0"), + Parent{ + a00Info.Name(): InfoT{a00Info, Leaf([]byte(a00Content))}, + }, + }, + }, + }, + "b": InfoT{ + fsnode.NewDirInfo("b"), + Parent{ + b0Info.Name(): InfoT{b0Info, Leaf([]byte(b0Content))}, + }, + }, + topInfo.Name(): InfoT{topInfo, Leaf([]byte(topContent))}, + }}, + walker.WalkContents(ctx, t, parent), + cmp.Comparer(func(a, b fsnode.FileInfo) bool { + a, b = a.WithSys(nil), b.WithSys(nil) + return a.Equal(b) + }), + ) + assert.Empty(t, diff) +} + +func addFile(t *testing.T, zipW *zip.Writer, prefix string, info *fsnode.FileInfo, content string, flate bool) { + *info = info.WithSize(int64(len(content))) + hdr, err := zip.FileInfoHeader(*info) + hdr.Name = prefix + info.Name() + if flate { + hdr.Method = zip.Deflate + } + require.NoError(t, err) + fw, err := zipW.CreateHeader(hdr) + require.NoError(t, err) + _, err = io.Copy(fw, strings.NewReader(content)) + require.NoError(t, err) +} + +func TestNonZip(t *testing.T) { + ctx := context.Background() + parent, err := parentFromLeaf(ctx, + fsnode.NewDirInfo("unzip"), + fsnode.ConstLeaf(fsnode.NewRegInfo("zip"), []byte("not zip"))) + require.NoError(t, err) + require.Nil(t, parent) +} + +func TestReadCancel(t *testing.T) { + ctx := context.Background() + + var zipBytes bytes.Buffer + zipW := zip.NewWriter(&zipBytes) + + fInfo := fsnode.NewRegInfo("f.txt") + // We need to make sure our reads below will exceed internal buffer sizes so we can control + // underlying blocking. 
Empirically this seems big enough but it may need to increase if + // there are internal changes (in flate, etc.) in the future. + fContent := strings.Repeat("a", 50*1024*1024) + addFile(t, zipW, "", &fInfo, fContent, true) + + require.NoError(t, zipW.Close()) + + // First we allow unblocked reads for zip headers. + zipLeaf := pausingLeaf{Leaf: fsnode.ConstLeaf(fsnode.NewRegInfo("zip"), zipBytes.Bytes())} + parent, err := parentFromLeaf(ctx, fsnode.NewDirInfo("unzip"), &zipLeaf) + require.NoError(t, err) + require.NotNil(t, parent) + children, err := fsnode.IterateAll(ctx, parent.Children()) + require.NoError(t, err) + require.Equal(t, 1, len(children)) + fLeaf := children[0].(fsnode.Leaf) + + f, err := fsnode.Open(ctx, fLeaf) + require.NoError(t, err) + + // Set up read blocking. + waitC := make(chan struct{}) + zipLeaf.mu.Lock() + zipLeaf.readAtWaitTwiceC = waitC + zipLeaf.mu.Unlock() + + var n int + b := make([]byte, 2) + readC := make(chan struct{}) + go func() { + defer close(readC) + n, err = f.Read(ctx, b) + }() + waitC <- struct{}{} // Let the read go through. + waitC <- struct{}{} + <-readC + require.NoError(t, err) + require.Equal(t, 2, n) + require.Equal(t, fContent[:2], string(b)) + + // Now start another read and let it reach ReadAt (first waitC send). + ctxCancel, cancel := context.WithCancel(ctx) + readC = make(chan struct{}) + go func() { + defer close(readC) + // Make sure this read will exhaust internal buffers (in flate, etc.) forcing a read from + // the pausingFile we control. + _, err = io.ReadAll(ioctx.ToStdReader(ctxCancel, f)) + }() + waitC <- struct{}{} + cancel() // Cancel the context instead of letting the read finish. + <-readC + // Make sure we get a cancellation error. + require.ErrorIs(t, err, context.Canceled) +} + +// pausingLeaf returns Files (from Open) that read one item C (which may block) at the start of +// each ReadAt operation. If C is nil, ReadAt's don't block. 
+type pausingLeaf struct { + fsnode.Leaf + mu sync.Mutex // mu guards readAtWaitTwiceC + // readAtWaitTwiceC controls ReadAt's blocking. If non-nil, ReadAt will read two values from + // this channel before returning. + readAtWaitTwiceC <-chan struct{} +} + +func (*pausingLeaf) FSNodeT() {} +func (p *pausingLeaf) OpenFile(ctx context.Context, flag int) (fsctx.File, error) { + f, err := fsnode.Open(ctx, p.Leaf) + return pausingFile{p, f}, err +} + +type pausingFile struct { + leaf *pausingLeaf + fsctx.File +} + +func (p pausingFile) ReadAt(ctx context.Context, dst []byte, off int64) (n int, err error) { + p.leaf.mu.Lock() + waitC := p.leaf.readAtWaitTwiceC + p.leaf.mu.Unlock() + if waitC != nil { + for i := 0; i < 2; i++ { + log.Printf("pausing: waiting %d", i) + select { + case <-waitC: + case <-ctx.Done(): + return 0, ctx.Err() + } + } + } else { + log.Printf("pausing: nil") + } + return p.File.(ioctx.ReaderAt).ReadAt(ctx, dst, off) +} diff --git a/file/doc.go b/file/doc.go index 9973f538..d84da52c 100644 --- a/file/doc.go +++ b/file/doc.go @@ -44,7 +44,7 @@ // // func init() { // file.RegisterImplementation("s3", s3file.NewImplementation( -// s3file.NewDefaultProvider(session.Options{}))) +// s3file.NewDefaultProvider())) // } // // // Caution: this code ignores all errors. diff --git a/file/file.go b/file/file.go index 062fff17..acb2550f 100644 --- a/file/file.go +++ b/file/file.go @@ -6,7 +6,11 @@ package file import ( "context" + "fmt" "io" + + "github.com/grailbio/base/errors" + "github.com/grailbio/base/ioctx" ) // File defines operations on a file. Implementations must be thread safe. @@ -26,48 +30,123 @@ type File interface { // Reader creates an io.ReadSeeker object that operates on the file. If // Reader() is called multiple times, they share the seek pointer. // + // For emphasis: these share state, which is different from OffsetReader! 
+ // // REQUIRES: Close has not been called Reader(ctx context.Context) io.ReadSeeker + // OffsetReader creates a new, independent ioctx.ReadCloser, starting at + // offset. Unlike Reader, its position in the file is only modified by Read + // on this object. The returned object is not thread-safe, and callers are + // responsible for serializing all of their calls, including calling Close + // after all Reads are done. Of course, callers can use separate + // OffsetReaders in parallel. + // + // Background: This API reflects S3's performance characteristics, where + // initiating a new read position is relatively expensive, but then + // streaming data is fast (including in parallel with multiple readers). + // + // REQUIRES: Close has not been called + OffsetReader(offset int64) ioctx.ReadCloser + // Writer creates a writes that to the file. If Writer() is called multiple // times, they share the seek pointer. // // REQUIRES: Close has not been called Writer(ctx context.Context) io.Writer - // Close commits the contents of a written file, invalidating the - // File and all Readers and Writers created from the file. Exactly - // one of Discard or Close should be called. No other File or - // io.ReadSeeker, io.Writer methods shall be called after Close. - Close(ctx context.Context) error + // TODO: Introduce WriterAt() ioctx.WriterAt, analogous to ReaderAt. // Discard discards a file before it is closed, relinquishing any // temporary resources implied by pending writes. This should be // used if the caller decides not to complete writing the file. // Discard is a best-effort operation. Discard is not defined for // files opened for reading. Exactly one of Discard or Close should - // be called. No other File, io.ReadSeeker, or io.Writer methods + // be called. No other File, io.ReadSeeker, or io.Writer methods // shall be called after Discard. 
- Discard(ctx context.Context) error + Discard(ctx context.Context) + + // Closer commits the contents of a written file, invalidating the + // File and all Readers and Writers created from the file. Exactly + // one of Discard or Close should be called. No other File or + // io.ReadSeeker, io.Writer methods shall be called after Close. + Closer } -// NewErrorReader returns a new io.ReadSeeker object that returns "err" on any -// operation. -func NewErrorReader(err error) io.ReadSeeker { return &errorReaderWriter{err: err} } +// TODO: Migrate callers to use new location. +type Closer = ioctx.Closer -// NewErrorWriter returns a new io.Writer object that returns "err" on any operation. -func NewErrorWriter(err error) io.Writer { return &errorReaderWriter{err: err} } +// ETagged defines a getter for a file with an ETag. +type ETagged interface { + // ETag is an identifier assigned to a specific version of the file. + ETag() string +} + +// CloseAndReport returns a defer-able helper that calls f.Close and reports errors, if any, +// to *err. Pass your function's named return error. Example usage: +// +// func processFile(filename string) (_ int, err error) { +// ctx := context.Background() +// f, err := file.Open(ctx, filename) +// if err != nil { ... } +// defer file.CloseAndReport(ctx, f, &err) +// ... +// } +// +// If your function returns with an error, any f.Close error will be chained appropriately. +// +// Deprecated: Use errors.CleanUpCtx directly. +func CloseAndReport(ctx context.Context, f Closer, err *error) { + errors.CleanUpCtx(ctx, f.Close, err) +} -type errorReaderWriter struct{ err error } +// MustClose is a defer-able function that calls f.Close and panics on error. +// +// Example: +// ctx := context.Background() +// f, err := file.Open(ctx, filename) +// if err != nil { panic(err) } +// defer file.MustClose(ctx, f) +// ... 
+func MustClose(ctx context.Context, f Closer) { + if err := f.Close(ctx); err != nil { + if n, ok := f.(named); ok { + panic(fmt.Sprintf("close %s: %v", n.Name(), err)) + } + panic(err) + } +} -func (r *errorReaderWriter) Read([]byte) (int, error) { +type named interface { + // Name returns the path name given to file.Open or file.Create when this + // object was created. + Name() string +} + +// Error implements io.{Reader,Writer,Seeker,Closer}. It returns the given error +// to any call. +type Error struct{ err error } + +// NewError returns a new Error object that returns the given error to any +// Read/Write/Seek/Close call. +func NewError(err error) *Error { return &Error{err: err} } + +// Read implements io.Reader +func (r *Error) Read([]byte) (int, error) { return -1, r.err } -func (r *errorReaderWriter) Seek(int64, int) (int64, error) { +// Seek implements io.Seeker. +func (r *Error) Seek(int64, int) (int64, error) { return -1, r.err } -func (r *errorReaderWriter) Write([]byte) (int, error) { +// Write implements io.Writer. +func (r *Error) Write([]byte) (int, error) { return -1, r.err } + +// Close implements io.Closer. 
+func (r *Error) Close() error { + return r.err +} diff --git a/file/file_test.go b/file/file_test.go index 6b1779f4..4cfc1f7b 100644 --- a/file/file_test.go +++ b/file/file_test.go @@ -6,45 +6,57 @@ package file_test import ( "context" + "errors" "flag" "fmt" "io" "math/rand" "sync" "testing" + "time" - "github.com/aws/aws-sdk-go/aws/session" "github.com/grailbio/base/file" "github.com/grailbio/base/file/s3file" "github.com/grailbio/testutil" "github.com/grailbio/testutil/assert" ) -type testImpl struct{} +type errFile struct { + err error +} + +func (f *errFile) String() string { return f.err.Error() } -func (impl *testImpl) String() string { return "test" } -func (impl *testImpl) Open(ctx context.Context, path string) (file.File, error) { - return nil, fmt.Errorf("%v: open", path) +func (f *errFile) Open(ctx context.Context, path string, opts ...file.Opts) (file.File, error) { + return nil, f.err } -func (impl *testImpl) Create(ctx context.Context, path string) (file.File, error) { - return nil, fmt.Errorf("%v: create", path) +func (f *errFile) Create(ctx context.Context, path string, opts ...file.Opts) (file.File, error) { + return nil, f.err } -func (impl *testImpl) List(ctx context.Context, dir string, recursive bool) file.Lister { +func (f *errFile) List(ctx context.Context, dir string, recursive bool) file.Lister { return nil } -func (impl *testImpl) Stat(ctx context.Context, path string) (file.Info, error) { - return nil, fmt.Errorf("%v: stat", path) +func (f *errFile) Stat(ctx context.Context, path string, opts ...file.Opts) (file.Info, error) { + return nil, f.err +} + +func (f *errFile) Remove(ctx context.Context, path string) error { + return f.err } -func (impl *testImpl) Remove(ctx context.Context, path string) error { - return fmt.Errorf("%v: remove", path) +func (f *errFile) Presign(ctx context.Context, path, method string, expiry time.Duration) (string, error) { + return "", f.err +} + +func (f *errFile) Close(ctx context.Context) error { + return 
f.err } func TestRegistration(t *testing.T) { - testImpl := &testImpl{} + testImpl := &errFile{errors.New("test")} file.RegisterImplementation("foo", func() file.Implementation { return testImpl }) assert.True(t, file.FindImplementation("") != nil) assert.True(t, file.FindImplementation("foo") == testImpl) @@ -176,19 +188,23 @@ func ExampleJoin() { fmt.Println(file.Join("foo", "bar")) fmt.Println(file.Join("foo", "")) fmt.Println(file.Join("foo", "/bar/")) + fmt.Println(file.Join(".", "foo:bar")) fmt.Println(file.Join("s3://foo")) fmt.Println(file.Join("s3://foo", "/bar/")) fmt.Println(file.Join("s3://foo", "", "bar")) - fmt.Println(file.Join("s3://foo/", "/", "/bar")) + fmt.Println(file.Join("s3://foo", "0")) + fmt.Println(file.Join("s3://foo", "abc")) fmt.Println(file.Join("s3://foo//bar", "/", "/baz")) // Output: // foo/bar // foo // foo/bar + // ./foo:bar // s3://foo // s3://foo/bar // s3://foo/bar - // s3://foo/bar + // s3://foo/0 + // s3://foo/abc // s3://foo//bar/baz } @@ -208,7 +224,7 @@ func initBenchmark() { once.Do(func() { file.RegisterImplementation("s3", func() file.Implementation { - return s3file.NewImplementation(s3file.NewDefaultProvider(session.Options{}), s3file.Options{}) + return s3file.NewImplementation(s3file.NewDefaultProvider(), s3file.Options{}) }) }) } diff --git a/file/filebench/bigmachine.go b/file/filebench/bigmachine.go new file mode 100644 index 00000000..f709d0b0 --- /dev/null +++ b/file/filebench/bigmachine.go @@ -0,0 +1,174 @@ +package filebench + +import ( + "bytes" + "context" + "encoding/gob" + "fmt" + "io" + "os" + "os/exec" + "path" + "time" + + "github.com/grailbio/base/errors" + "github.com/grailbio/base/must" + "github.com/grailbio/base/traverse" + "github.com/grailbio/bigmachine" +) + +// Bigmachine configures a cluster of remote machines to each execute benchmarks and then report +// their results. +type Bigmachine struct { + // Bs is the collection of bigmachines in which to run benchmarks. 
EC2 instance type is a
+	// property of *bigmachine.B, so this lets callers benchmark several EC2 instance types.
+	// (*B).Name() is printed to identify results. Caller remains responsible for shutdown.
+	Bs       []*bigmachine.B
+	Environ  bigmachine.Environ
+	Services bigmachine.Services
+}
+
+// NewBigmachine returns a new configuration, ready for callers to configure. Callers likely want
+// to add bigmachines for remote execution (otherwise it just falls back to local).
+// Or, they may add environment variables or services for AWS credentials.
+func NewBigmachine(rs ReadSizes) Bigmachine {
+	return Bigmachine{
+		Services: bigmachine.Services{
+			"FileBench": benchService{rs},
+		},
+	}
+}
+
+// RunAndPrint starts a machine in each d.Bs and then executes ReadSizes.RunAndPrint on it.
+// It writes all the machine results to out, identifying each section by d.Bs's keys.
+func (d Bigmachine) RunAndPrint(
+	ctx context.Context,
+	out io.Writer,
+	pathPrefixes []Prefix,
+	pathSuffix0 string,
+	pathSuffixes ...string,
+) error {
+	var results = make([]string, len(d.Bs))
+
+	err := traverse.Each(len(d.Bs), func(bIdx int) error {
+		b := d.Bs[bIdx]
+		machines, err := b.Start(ctx, 1, d.Environ, d.Services)
+		if err != nil {
+			return err
+		}
+		machine := machines[0]
+
+		// Benchmark runs have encountered some throttling from S3 (503 SlowDown). Sleep a bit
+		// to separate out the benchmark runs, so that each benchmarking machine is likely to
+		// locate some distinct S3 remote IPs. (Due to VPC DNS caching, if all the machines start
+		// simultaneously, they're likely to use the same S3 peers.) Of course, this introduces
+		// systematic bias in comparing results between machines, but we accept that for now.
+ time.Sleep(time.Minute * time.Duration(bIdx)) + + return machine.Call(ctx, "FileBench.Run", + benchRequest{pathPrefixes, pathSuffix0, pathSuffixes}, + &results[bIdx]) + }) + + for bIdx, result := range results { + if result == "" { + continue + } + if bIdx > 0 { + _, err := fmt.Fprintln(out) + must.Nil(err) + } + _, err := fmt.Fprintf(out, "[%d] %s\n%s", bIdx, d.Bs[bIdx].Name(), result) + must.Nil(err) + } + return err +} + +type ( + benchService struct{ ReadSizes } + benchRequest struct { + PathPrefixes []Prefix + PathSuffix0 string + PathSuffixes []string + } + + fuseService struct{} +) + +func init() { + gob.Register(benchService{}) + gob.Register(fuseService{}) +} + +func (s benchService) Run(ctx context.Context, req benchRequest, out *string) error { + var buf bytes.Buffer + s.ReadSizes.RunAndPrint(ctx, &buf, req.PathPrefixes, req.PathSuffix0, req.PathSuffixes...) + *out = buf.String() + return nil +} + +// AddS3FUSE configures d so that each machine running benchmarks can access S3 objects through +// the local filesystem, at mountPath. For example, object s3://b/my/key will appear at +// $mountPath/b/my/key. Callers can use this to construct paths for RunAndPrint. 
+func (d Bigmachine) AddS3FUSE() (mountPath string) { + must.True(len(s3FUSEBinary) > 0) + d.Services["s3FUSE"] = fuseService{} + return s3FUSEPath +} + +const s3FUSEPath = "/tmp/s3" + +func (fuseService) Init(*bigmachine.B) (err error) { + defer func() { + if err != nil { + err = errors.E(err, errors.Fatal) + } + }() + if err := os.MkdirAll(s3FUSEPath, 0700); err != nil { + return err + } + ents, err := os.ReadDir(s3FUSEPath) + if err != nil { + return err + } + if len(ents) > 0 { + return errors.New("s3 fuse mount is non-empty") + } + tmpDir, err := os.MkdirTemp("", "s3fuse-*") + if err != nil { + return err + } + exe := path.Join(tmpDir, "s3fuse") + if err := os.WriteFile(exe, s3FUSEBinary, 0700); err != nil { + return err + } + cmdErrC := make(chan error) + go func() { + out, err := exec.Command(exe, s3FUSEPath).CombinedOutput() + if err == nil { + err = errors.E("s3fuse exited unexpectedly") + } + cmdErrC <- errors.E(err, fmt.Sprintf("s3fuse output:\n%s", out)) + }() + readDirC := make(chan error) + go func() { + for { + ents, err := os.ReadDir(s3FUSEPath) + if err != nil { + readDirC <- err + return + } + if len(ents) > 0 { + readDirC <- nil + } + time.Sleep(time.Second) + } + }() + select { + case err = <-cmdErrC: + case err = <-readDirC: + case <-time.After(10 * time.Second): + err = errors.New("ran out of time waiting for FUSE mount") + } + return err +} diff --git a/file/filebench/embed_bazel.go b/file/filebench/embed_bazel.go new file mode 100644 index 00000000..b5b671fb --- /dev/null +++ b/file/filebench/embed_bazel.go @@ -0,0 +1,8 @@ +//go:build bazel + +package filebench + +import _ "embed" + +//go:embed s3fuse_binary +var s3FUSEBinary []byte diff --git a/file/filebench/embed_nobazel.go b/file/filebench/embed_nobazel.go new file mode 100644 index 00000000..ba07db4f --- /dev/null +++ b/file/filebench/embed_nobazel.go @@ -0,0 +1,5 @@ +//go:build !bazel + +package filebench + +var s3FUSEBinary []byte diff --git a/file/filebench/filebench.go 
b/file/filebench/filebench.go new file mode 100644 index 00000000..580612b1 --- /dev/null +++ b/file/filebench/filebench.go @@ -0,0 +1,307 @@ +package filebench + +import ( + "context" + "fmt" + "io" + "log" + "math/rand" + "sort" + "strings" + "sync/atomic" + "text/tabwriter" + "time" + + "github.com/grailbio/base/file" + "github.com/grailbio/base/ioctx" + "github.com/grailbio/base/must" + "github.com/grailbio/base/traverse" +) + +// ReadSizes are the parameters for a benchmark run. +type ReadSizes struct { + ChunkBytes []int + ContiguousChunks []int + MaxReadBytes int + MaxReplicates int +} + +// ReplicateTargetBytes limits the number of replicates of a single benchmark condition. +const ReplicateTargetBytes int = 1e9 + +// DefaultReadSizes constructs ReadSizes with the default range of parameters. +func DefaultReadSizes() ReadSizes { + return ReadSizes{ + ChunkBytes: []int{ + 1 << 10, + 1 << 20, + 1 << 23, + 1 << 24, + 1 << 25, + 1 << 27, + 1 << 29, + 1 << 30, + 1 << 32, + }, + ContiguousChunks: []int{ + 1, + 1 << 3, + 1 << 6, + 1 << 9, + }, + MaxReadBytes: 1 << 32, + MaxReplicates: 10, + } +} + +func (r ReadSizes) MinFileSize() int { + size := maxInts(r.ChunkBytes) * maxInts(r.ContiguousChunks) + if size < r.MaxReadBytes { + return size + } + return r.MaxReadBytes +} + +func (r ReadSizes) sort() { + must.True(len(r.ChunkBytes) > 0) + must.True(len(r.ContiguousChunks) > 0) + sort.Ints(r.ChunkBytes) + sort.Ints(r.ContiguousChunks) +} + +type Prefix struct { + Path string + // MaxReadBytes optionally overrides ReadSizes.MaxReadBytes (only to become smaller). + // Useful if one prefix (like FUSE) is slower than others. + MaxReadBytes int +} + +// RunAndPrint executes the benchmark cases and prints a human-readable summary to out. +// pathPrefixes is typically s3:// or a FUSE mount point. Results are reported for each one. +// pathSuffix* are at least one S3-relative path (like bucket/some/file.txt) to a large file to read +// during benchmarking. 
If there are multiple, reads are spread across them (not multiplied for each +// suffix). Caller may want to pass multiple to try to reduce throttling when several benchmark +// tasks are running in parallel (see Bigmachine.RunAndPrint). +func (r ReadSizes) RunAndPrint( + ctx context.Context, + out io.Writer, + pathPrefixes []Prefix, + pathSuffix0 string, + pathSuffixes ...string, +) { + minFileSize := r.MinFileSize() + r.sort() // Make sure table is easy to read. + + pathSuffixes = append([]string{pathSuffix0}, pathSuffixes...) + type fileOption struct { + file.File + Info file.Info + } + files := make([][]fileOption, len(pathPrefixes)) + for prefixIdx, prefix := range pathPrefixes { + files[prefixIdx] = make([]fileOption, len(pathSuffixes)) + for suffixIdx, suffix := range pathSuffixes { + f, err := file.Open(ctx, file.Join(prefix.Path, suffix)) + must.Nil(err) + defer func() { must.Nil(f.Close(ctx)) }() + o := &files[prefixIdx][suffixIdx] + o.File = f + + o.Info, err = f.Stat(ctx) + must.Nil(err) + must.True(o.Info.Size() >= int64(minFileSize), "file too small", f.Name()) + } + } + + type ( + condition struct { + prefixIdx, chunkBytesIdx, contiguousChunksIdx int + parallel bool + } + result struct { + totalBytes int + totalTime time.Duration + } + ) + var ( + tasks []condition + results = make([][][][]result, len(pathPrefixes)) + ) + for prefixIdx, prefix := range pathPrefixes { + results[prefixIdx] = make([][][]result, len(r.ChunkBytes)) + for chunkBytesIdx, chunkBytes := range r.ChunkBytes { + results[prefixIdx][chunkBytesIdx] = make([][]result, len(r.ContiguousChunks)) + for contiguousChunksIdx, contiguousChunks := range r.ContiguousChunks { + results[prefixIdx][chunkBytesIdx][contiguousChunksIdx] = make([]result, 2) + totalReadBytes := chunkBytes * contiguousChunks + maxReadBytes := r.MaxReadBytes + if 0 < prefix.MaxReadBytes && prefix.MaxReadBytes < maxReadBytes { + maxReadBytes = prefix.MaxReadBytes + } + if totalReadBytes > maxReadBytes { + continue + } + 
replicates := 1 + if totalReadBytes < ReplicateTargetBytes { + replicates = (ReplicateTargetBytes - 1 + totalReadBytes) / totalReadBytes + if replicates > r.MaxReplicates { + replicates = r.MaxReplicates + } + } + for _, parallel := range []bool{false, true} { + for ri := 0; ri < replicates; ri++ { + tasks = append(tasks, condition{ + prefixIdx: prefixIdx, + chunkBytesIdx: chunkBytesIdx, + contiguousChunksIdx: contiguousChunksIdx, + parallel: parallel, + }) + } + } + } + } + } + + var ( + reproducibleRandom = rand.New(rand.NewSource(1)) + ephemeralRandom = rand.New(rand.NewSource(time.Now().UnixNano())) + ) + // While benchmarking is running, it's easy to compare the current task index from different + // benchmarking machines to judge their relative progress. + reproducibleRandom.Shuffle(len(tasks), func(i, j int) { + tasks[i], tasks[j] = tasks[j], tasks[i] + }) + + var ( + currentTaskIdx int32 + cancelled = make(chan struct{}) + ) + go func() { + ticker := time.NewTicker(10 * time.Second) + defer ticker.Stop() + for { + select { + case <-ticker.C: + taskIdx := atomic.LoadInt32(¤tTaskIdx) + c := tasks[taskIdx] + prefix := pathPrefixes[c.prefixIdx] + chunkBytes := r.ChunkBytes[c.chunkBytesIdx] + contiguousChunks := r.ContiguousChunks[c.contiguousChunksIdx] + log.Printf("done %d of %d tasks, current: %dB * %d on %s", + taskIdx, len(tasks), chunkBytes, contiguousChunks, prefix.Path) + case <-cancelled: + break + } + } + }() + defer close(cancelled) + + dst := make([]byte, r.MaxReadBytes) + for taskIdx, c := range tasks { + atomic.StoreInt32(¤tTaskIdx, int32(taskIdx)) + + chunkBytes := r.ChunkBytes[c.chunkBytesIdx] + contiguousChunks := r.ContiguousChunks[c.contiguousChunksIdx] + + // Vary read locations non-reproducibly to try to spread load and avoid S3 throttling. + // There's a tradeoff here: we're also likely introducing variance in benchmark results + // if S3 read performance varies between objects and over time, which it probably does [1]. 
+ // For now, empirically, it seems like throttling is the bigger problem, especially because + // our benchmark runs are relatively brief (compared to large batch workloads) and thus + // significantly affected by some throttling. We may revisit this in the future if a + // different choice helps make the benchmark a better guide for optimization. + // + // [1] https://web.archive.org/web/20221220192142/https://docs.aws.amazon.com/AmazonS3/latest/userguide/optimizing-performance.html + f := files[c.prefixIdx][ephemeralRandom.Intn(len(pathSuffixes))] + offset := ephemeralRandom.Int63n(f.Info.Size() - int64(chunkBytes*contiguousChunks) + 1) + + parIdx := 0 + start := time.Now() + func() { + var ( + traverser traverse.T + chunks = make([]struct { + r io.Reader + dst []byte + }, contiguousChunks) + ) + if c.parallel { + parIdx = 1 + for i := range chunks { + chunkOffset := i * chunkBytes + rc := f.OffsetReader(offset + int64(chunkOffset)) + defer func() { must.Nil(rc.Close(ctx)) }() + chunks[i].r = ioctx.ToStdReader(ctx, rc) + chunks[i].dst = dst[chunkOffset : chunkOffset+chunkBytes] + } + } else { + traverser.Limit = 1 + rc := ioctx.ToStdReadCloser(ctx, f.OffsetReader(offset)) + defer func() { must.Nil(rc.Close()) }() + for i := range chunks { + chunks[i].r = rc + chunks[i].dst = dst[:chunkBytes] + } + } + _ = traverser.Each(contiguousChunks, func(i int) error { + n, err := io.ReadFull(chunks[i].r, chunks[i].dst) + must.Nil(err) + must.True(n == chunkBytes) + return nil + }) + }() + elapsed := time.Since(start) + + results[c.prefixIdx][c.chunkBytesIdx][c.contiguousChunksIdx][parIdx].totalBytes += chunkBytes * contiguousChunks + results[c.prefixIdx][c.chunkBytesIdx][c.contiguousChunksIdx][parIdx].totalTime += elapsed + } + + tw := tabwriter.NewWriter(out, 0, 4, 4, ' ', 0) + mustPrintf := func(format string, args ...interface{}) { + _, err := fmt.Fprintf(tw, format, args...) 
+ must.Nil(err) + } + mustPrintf("\t") + for _, prefix := range pathPrefixes { + mustPrintf("%s%s", prefix.Path, strings.Repeat("\t", 2*len(r.ContiguousChunks))) + } + mustPrintf("\n") + for range files { + for _, parLabel := range []string{"", "p"} { + for _, contiguousChunks := range r.ContiguousChunks { + mustPrintf("\t%s%d", parLabel, contiguousChunks) + } + } + } + mustPrintf("\n") + for chunkBytesIdx, chunkBytes := range r.ChunkBytes { + mustPrintf("%d", chunkBytes/(1<<20)) + for prefixIdx := range files { + for _, parIdx := range []int{0, 1} { + for contiguousChunksIdx := range r.ContiguousChunks { + s := results[prefixIdx][chunkBytesIdx][contiguousChunksIdx][parIdx] + mustPrintf("\t") + if s.totalTime > 0 { + mibs := float64(s.totalBytes) / s.totalTime.Seconds() / float64(1<<20) + mustPrintf("%.f", mibs) + } + } + } + } + mustPrintf("\n") + } + must.Nil(tw.Flush()) +} + +func maxInts(ints []int) int { + if len(ints) == 0 { + return 0 // OK for our purposes. + } + max := ints[0] + for _, i := range ints[1:] { + if i > max { + max = i + } + } + return max +} diff --git a/file/filebench/s3fuse/main.go b/file/filebench/s3fuse/main.go new file mode 100644 index 00000000..b23668c0 --- /dev/null +++ b/file/filebench/s3fuse/main.go @@ -0,0 +1,27 @@ +package main + +import ( + "os" + + "github.com/grailbio/base/file" + "github.com/grailbio/base/file/fsnodefuse" + "github.com/grailbio/base/file/gfilefs" + "github.com/grailbio/base/file/s3file" + "github.com/grailbio/base/must" + "github.com/hanwen/go-fuse/v2/fs" + "github.com/hanwen/go-fuse/v2/fuse" +) + +func main() { + mount := os.Args[1] + file.RegisterImplementation("s3", func() file.Implementation { + return s3file.NewImplementation(s3file.NewDefaultProvider(), s3file.Options{}) + }) + root := fsnodefuse.NewRoot(gfilefs.New("s3://", "s3")) + mountOpts := fuse.MountOptions{FsName: "s3"} + fsnodefuse.ConfigureDefaultMountOptions(&mountOpts) + fsnodefuse.ConfigureRequiredMountOptions(&mountOpts) + server, err := 
fs.Mount(mount, root, &fs.Options{MountOptions: mountOpts}) + must.Nil(err) + server.Wait() +} diff --git a/file/filebench/snapshot.txt b/file/filebench/snapshot.txt new file mode 100644 index 00000000..df60fc06 --- /dev/null +++ b/file/filebench/snapshot.txt @@ -0,0 +1,108 @@ +For reference, whoever's changing base/file code may occasionally run this benchmark and +update the snapshot below. It can be useful for showing code reviewers the result of a change, +or just so readers can get a sense of performance without running the benchmarks themselves. + +Of course, since we're not totally controlling the environment or the data, be careful to +set a baseline before evaluating your change. + +Some context for the numbers below: + * S3 performance guidelines [1] suggest that each request should result in around 85–90 MB/s + of read throughput. Our numbers are MiB/s (not sure if they mean M = 1000^2 or Mi = 1024^2) and + it looks like our sequential reads ramp up to the right vicinity. + * EC2 documentation offers network performance expectations (Gbps): + * m5.x: 1.25 (base) - 10 (burst) + * m5.4x: 5 - 10 + * m5.12x: 12 + * m5.24x: 25 + * m5n.24x: 100 + Note that a 1000 in the table below is MiB/s which is 8*1.024^2 ~= 8.4 Gbps. + +[1] https://docs.aws.amazon.com/AmazonS3/latest/userguide/optimizing-performance-design-patterns.html#optimizing-performance-parallelization +[2] https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/general-purpose-instances.html#general-purpose-network-performance + +Very brief summary and speculations about our current S3 performance for non-parallel(!) clients: + * Could be improved for small sequential reads (few MiB at a time) in the current io.Reader + model by adding (read-ahead) buffering (basically speculating about future reads). + * Can't get much better for large reads (GiBs) with large buffers (100s MiBs) on small instance + types because we're getting close to documented network performance. 
+ * Can probably get better for large reads with large buffers on large, network-optimized (*n) + instances where we're not especially close to network limits yet. Maybe this requires more + careful CPU and allocation usage. +Parallel clients might matter more: + * Clients reading many small files in parallel (like machine learning training reading V3 + fragments file columns) may already do chunked reads pretty effectively. We should measure to + verify, though; maybe the chunks are not right-sized. + * Clients reading very large files may benefit even more from a parallel-friendly API (ReaderAt): + * Copies to disk could be accelerated without huge memory buffers. + Examples: `grail-file cp`, biofs/gfilefs, reflow interning, bigslice shuffling. + * Copies to memory could be accelerated without duplicated memory buffers: + Examples: bio/align/mapper's index (if it was Go), FASTA. + +Results: +Rows are read chunk sizes (MiB), columns are number of sequential chunks. +Data cells are average MiB/s for the entire sequential read. + +Note: During development, the results seemed worse on some days, peaking around 75 MiB/s rather +than 95 MiB/s. It was consistent for several runs within that day. Unclear if the cause was +something local in the benchmarking machines (or in our builds) or the S3 service. + +The numbers below were generated using the bigmachine runner in us-west-2 reading a GRAIL-internal +113 GiB file residing in an S3 bucket in us-west-2. The first columns are reading directly from S3; +the next ones are through FUSE. + +A couple of the FUSE numbers seem surprisingly high. I suspect this is due to parts of reads being +served by the page cache, due to randomly overlapping with earlier benchmark tasks. We could try +to confirm this and clear the page cache in the future, though for now it's just useful to alert +about issues causing widespread slowness. 
+ +[0] m5.4xlarge + s3:// /tmp/s3 + 1 8 64 512 p1 p8 p64 p512 1 8 64 512 p1 p8 p64 p512 +0 0 0 0 3 0 0 0 1 0 0 0 2 0 0 0 2 +1 4 24 41 46 5 34 198 142 5 26 41 74 4 30 167 417 +8 22 37 41 44 21 140 620 942 23 44 65 22 112 756 +16 35 46 49 35 203 867 34 48 28 182 +32 55 50 65 51 317 999 40 51 43 267 +128 177 245 217 960 51 53 +512 728 415 447 1075 48 38 +1024 855 839 +4096 1077 1025 + +[1] m5.12xlarge + s3:// /tmp/s3 + 1 8 64 512 p1 p8 p64 p512 1 8 64 512 p1 p8 p64 p512 +0 0 0 0 2 0 0 0 1 0 0 0 2 0 0 0 2 +1 5 27 45 59 4 34 213 648 5 26 50 71 4 34 213 537 +8 29 47 41 50 27 142 823 1209 29 45 83 29 152 732 +16 37 53 64 32 165 822 31 48 28 230 +32 74 64 84 65 346 1258 45 57 50 236 +128 231 181 202 854 55 46 +512 360 615 541 1297 52 57 +1024 1000 1076 +4096 1297 1280 + +[2] m5.24xlarge + s3:// /tmp/s3 + 1 8 64 512 p1 p8 p64 p512 1 8 64 512 p1 p8 p64 p512 +0 0 0 0 2 0 0 0 1 0 0 0 2 0 0 0 2 +1 5 26 46 52 5 37 188 492 3 30 50 69 4 30 170 661 +8 31 46 52 50 28 169 897 2119 25 50 62 27 158 811 +16 41 54 54 37 166 1365 36 50 39 208 +32 66 83 29 55 279 1873 42 69 44 282 +128 168 199 182 1224 54 52 +512 555 643 495 2448 59 55 +1024 789 907 +4096 2395 2410 + +[3] m5n.24xlarge + s3:// /tmp/s3 + 1 8 64 512 p1 p8 p64 p512 1 8 64 512 p1 p8 p64 p512 +0 0 0 0 2 0 0 0 1 0 0 0 2 0 0 0 1 +1 4 28 53 55 5 32 214 954 4 28 50 52 4 31 188 849 +8 24 44 60 55 24 165 865 2811 25 42 43 26 144 788 +16 38 55 62 43 181 992 38 52 39 202 +32 55 80 64 60 314 2407 42 59 48 283 +128 171 179 190 1005 56 51 +512 462 549 469 4068 56 70 +1024 1343 821 +4096 2921 3010 diff --git a/file/fsnode/fileinfo.go b/file/fsnode/fileinfo.go new file mode 100644 index 00000000..c9aba2fe --- /dev/null +++ b/file/fsnode/fileinfo.go @@ -0,0 +1,113 @@ +package fsnode + +import ( + "os" + "time" +) + +// FileInfo implements os.FileInfo. Instances are immutable but convenient +// copy-and-set methods are provided for some fields. 
FileInfo implements +// (T).Info, so implementations of T can conveniently embed a FileInfo for +// simple cases, e.g. if the information is immutable. +type FileInfo struct { + name string + size int64 + mode os.FileMode + mod time.Time + sys interface{} + cacheableFor time.Duration +} + +// Info implements (T).Info. +func (fi FileInfo) Info() os.FileInfo { + return fi +} + +// NewDirInfo constructs FileInfo for a directory. +// Default ModePerm is 0555 (r-xr-xr-x). Other defaults are zero. +func NewDirInfo(name string) FileInfo { return FileInfo{name: name, mode: os.ModeDir | 0555} } + +// NewRegInfo constructs FileInfo for a regular file. +// Default ModePerm is 0444 (r--r--r--). Other defaults are zero. +func NewRegInfo(name string) FileInfo { return FileInfo{name: name, mode: 0444} } + +// NewSymlinkInfo constructs FileInfo for a symlink. +// +// Create a symlink by using this FileInfo with a Leaf whose contents are the target path. +// The path may be relative or absolute. +func NewSymlinkInfo(name string) FileInfo { + return FileInfo{ + name: name, + // Note: Symlinks don't need permissions. From `man 7 symlink`: + // On Linux, the permissions of a symbolic link are not used in any operations; ... + // And on macOS: + // Of these, only the flags are used by the system; the access permissions and + // ownership are ignored. + mode: os.ModeSymlink, + } +} + +// CopyFileInfo constructs FileInfo with the same public fields as info. +// It copies cacheability if available. 
+func CopyFileInfo(info os.FileInfo) FileInfo { + return FileInfo{ + name: info.Name(), + size: info.Size(), + mode: info.Mode(), + mod: info.ModTime(), + sys: info.Sys(), + cacheableFor: CacheableFor(info), + } +} + +func (f FileInfo) Name() string { return f.name } +func (f FileInfo) Size() int64 { return f.size } +func (f FileInfo) Mode() os.FileMode { return f.mode } +func (f FileInfo) ModTime() time.Time { return f.mod } +func (f FileInfo) IsDir() bool { return f.mode&os.ModeDir != 0 } +func (f FileInfo) Sys() interface{} { return f.sys } +func (f FileInfo) CacheableFor() time.Duration { return f.cacheableFor } + +func (f FileInfo) WithName(name string) FileInfo { + cp := f + cp.name = name + return cp +} +func (f FileInfo) WithSize(size int64) FileInfo { + cp := f + cp.size = size + return cp +} +func (f FileInfo) WithModePerm(perm os.FileMode) FileInfo { + cp := f + cp.mode = (perm & os.ModePerm) | (cp.mode &^ os.ModePerm) + return cp +} +func (f FileInfo) WithModeType(modeType os.FileMode) FileInfo { + cp := f + cp.mode = (modeType & os.ModeType) | (cp.mode &^ os.ModeType) + return cp +} +func (f FileInfo) WithModTime(mod time.Time) FileInfo { + cp := f + cp.mod = mod + return cp +} +func (f FileInfo) WithSys(sys interface{}) FileInfo { + cp := f + cp.sys = sys + return cp +} +func (f FileInfo) WithCacheableFor(d time.Duration) FileInfo { + cp := f + cp.cacheableFor = d + return cp +} + +func (f FileInfo) Equal(g FileInfo) bool { + if !f.mod.Equal(g.mod) { + return false + } + f.mod = g.mod + return f == g +} diff --git a/file/fsnode/fsnode.go b/file/fsnode/fsnode.go new file mode 100644 index 00000000..fca74b13 --- /dev/null +++ b/file/fsnode/fsnode.go @@ -0,0 +1,199 @@ +// fsnode represents a filesystem as a directed graph (probably a tree for many implementations). +// Directories are nodes with out edges (children). Files are nodes without. +// +// fsnode.T is designed for incremental iteration. 
Callers can step through the graph one link +// at a time (Parent.Child) or one level at a time (Parent.Children). In general, implementations +// should do incremental work for each step. See also: Cacheable. +// +// Compared to fs.FS: +// * Leaf explicitly models an unopened file. fs users have to choose their own representation, +// like the pair (fs.FS, name string) or func Open(...). +// * Graph traversal (that is, directory listing) uses the one node type, rather than a separate +// one (like fs.DirEntry). Callers can access "all cheaply available FileInfo" during listing +// or can Open nodes if they want completeness at higher cost. +// * Parent offers one, explicit way of traversing the graph. fs.FS has optional ReadDirFS or +// callers can Open(".") and see if ReadDirFile is returned. (fs.ReadDir unifies these but also +// disallows pagination). +// * Only supports directories and files. fs.FS supports much more. TODO: Add symlinks? +// * fs.FS.Open naturally allows "jumping" several levels deep without step-by-step traversal. +// (See Parent.Child) for performance note. +package fsnode + +import ( + "context" + "fmt" + "io" + "os" + "time" + + "github.com/grailbio/base/errors" + "github.com/grailbio/base/ioctx/fsctx" +) + +type ( + // T is a Parent or Leaf. A T that is not either of those is invalid. + T interface { + // Info provides immediately-available information. A subset of fields must be accurate: + // Name + // Mode&os.ModeType + // IsDir + // The rest can be zero values if they're not immediately available. + // Implementations may find FileInfo (in this package) convenient for embedding or use + // in public API. + // + // Leaf.Open().Stat() gets complete information. That returned FileInfo must have the same + // values for the fields listed above. The others can change if better information is + // available. + // TODO: Specify something about Info()'s return changing after a Stat call? 
+ Info() os.FileInfo + // FSNodeT distinguishes T from os.FileInfo. It does nothing. + FSNodeT() + } + // Parent is a T that has zero or more child Ts. + Parent interface { + // T.Info must be consistent with directory (mode and IsDir). + T + // Child returns the named child. Returns nil, os.ErrNotExist if no such child exists. + // name is not a path and must not contain '/'. It must satisfy fs.ValidPath, too. + // + // In some implementations, Child lookup may be relatively expensive and implementations + // may want to reduce the cost of accessing a deeply-nested node. They may make all Child() + // requests succeed immediately and then return path errors from listing or Leaf.Open + // operations for the earlier path segment. + Child(_ context.Context, name string) (T, error) + // Children returns an iterator that can list all children. + // Children takes no Context; it's expected to simply construct a view and return errors to + // callers when they choose an Iterator operation. + Children() Iterator + // AddChildLeaf adds a child leaf to this parent, returning the new + // leaf and an open file for the leaf's contents. The behavior of name + // collisions may vary by implementation. It may be convenient to embed + // ParentReadOnly if your Parent implementation is read-only. + // TODO: Include mode? + AddChildLeaf(_ context.Context, name string, flags uint32) (Leaf, fsctx.File, error) + // AddChildParent adds a child parent to this parent, returning the new + // parent. The behavior of name collisions may vary by implementation. + // It may be convenient to embed ParentReadOnly if your Parent + // implementation is read-only. + AddChildParent(_ context.Context, name string) (Parent, error) + // RemoveChild removes a child. It may be convenient to embed + // ParentReadOnly if your Parent implementation is read-only. + RemoveChild(_ context.Context, name string) error + } + // Iterator yields child nodes iteratively. 
+	//
+	// Users must serialize their own method calls. No calls can be made after Close().
+	// TODO: Do we need Stat here, maybe to update directory mode?
+	Iterator interface {
+		// Next gets the next node. Must return (nil, io.EOF) at the end, not (non-nil, io.EOF).
+		Next(context.Context) (T, error)
+		// Close frees resources.
+		Close(context.Context) error
+	}
+	// Leaf is a T corresponding to a fsctx.File. It can be opened any number of times and must
+	// allow concurrent calls (it may lock internally if necessary).
+	Leaf interface {
+		// T is implementation of common node operations. The FileInfo returned
+		// by T.Info must be consistent with a regular file (mode and !IsDir).
+		T
+		// OpenFile opens the file. File.Stat()'s result must be consistent
+		// with T.Info. flag holds the flag bits, specified the same as those
+		// passed to os.OpenFile. See os.O_*.
+		OpenFile(ctx context.Context, flag int) (fsctx.File, error)
+	}
+	// Cacheable optionally lets users make use of caching. The cacheable data depends on
+	// which type Cacheable is defined on:
+	// * On any T, FileInfo.
+	// * On an fsctx.File, the FileInfo and contents.
+	//
+	// Common T implementations are expected to be "views" of remote data sources not under
+	// our exclusive control (like local filesystem or S3). As such, callers should generally
+	// expect best-effort consistency, regardless of caching.
+	Cacheable interface {
+		// CacheableFor is the maximum allowed cache time.
+		// Zero means don't cache. Negative means cache forever.
+		// TODO: Make this a non-Duration type to avoid confusion with negatives?
+		CacheableFor() time.Duration
+	}
+	cacheableFor struct{ time.Duration }
+)
+
+const CacheForever = time.Duration(-1)
+
+// CacheableFor returns the configured cache time if obj is Cacheable, otherwise returns default 0.
+func CacheableFor(obj interface{}) time.Duration {
+	cacheable, ok := obj.(Cacheable)
+	if !ok {
+		return 0
+	}
+	return cacheable.CacheableFor()
+}
+func NewCacheable(d time.Duration) Cacheable { return cacheableFor{d} }
+func (c cacheableFor) CacheableFor() time.Duration { return c.Duration }
+
+// Open opens the file of a leaf in (the commonly desired) read-only mode.
+func Open(ctx context.Context, n Leaf) (fsctx.File, error) {
+	return n.OpenFile(ctx, os.O_RDONLY)
+}
+
+// IterateFull reads the full len(dst) nodes from Iterator. If actual number read is less than
+// len(dst), error is non-nil. Error is io.EOF for EOF. Unlike io.ReadFull, this doesn't return
+// io.ErrUnexpectedEOF (unless iter does).
+func IterateFull(ctx context.Context, iter Iterator, dst []T) (int, error) {
+	for i := range dst {
+		var err error
+		dst[i], err = iter.Next(ctx)
+		if err != nil {
+			if err == io.EOF && dst[i] != nil {
+				return i, iteratorEOFError(iter)
+			}
+			return i, err
+		}
+	}
+	return len(dst), nil
+}
+
+// IterateAll reads iter until EOF. Returns nil error on success, not io.EOF (like io.ReadAll).
+func IterateAll(ctx context.Context, iter Iterator) ([]T, error) {
+	var dst []T
+	for {
+		node, err := iter.Next(ctx)
+		if err != nil {
+			if err == io.EOF {
+				if node != nil {
+					return dst, iteratorEOFError(iter)
+				}
+				return dst, nil
+			}
+			return dst, err
+		}
+		dst = append(dst, node)
+	}
+}
+
+func iteratorEOFError(iter Iterator) error {
+	return errors.E(errors.Precondition, fmt.Sprintf("BUG: iterator.Next (%T) returned element+EOF", iter))
+}
+
+// ParentReadOnly is a partial implementation of Parent interface functions
+// that returns NotSupported errors for all write operations. It may be
+// convenient to embed if your Parent implementation is read-only.
+//
+// 	type MyParent struct {
+// 		fsnode.ParentReadOnly
+// 	}
+//
+// 	func (MyParent) Child(context.Context, string) (T, error) { ... }
+// 	func (MyParent) Children() Iterator { ...
} +// // No need to implement write functions. +type ParentReadOnly struct{} + +func (ParentReadOnly) AddChildLeaf(context.Context, string, uint32) (Leaf, fsctx.File, error) { + return nil, nil, errors.E(errors.NotSupported) +} +func (ParentReadOnly) AddChildParent(context.Context, string) (Parent, error) { + return nil, errors.E(errors.NotSupported) +} +func (ParentReadOnly) RemoveChild(context.Context, string) error { + return errors.E(errors.NotSupported) +} diff --git a/file/fsnode/fsnodetesting/make.go b/file/fsnode/fsnodetesting/make.go new file mode 100644 index 00000000..f596a83f --- /dev/null +++ b/file/fsnode/fsnodetesting/make.go @@ -0,0 +1,25 @@ +package fsnodetesting + +import ( + "testing" + + "github.com/grailbio/base/file/fsnode" + "github.com/stretchr/testify/require" +) + +// MakeT is the inverse of (Walker).WalkContents. +// It keeps no references to node and its children after it returns. +func MakeT(t testing.TB, name string, node T) fsnode.T { + switch n := node.(type) { + case Parent: + var children []fsnode.T + for childName, child := range n { + children = append(children, MakeT(t, childName, child)) + } + return fsnode.NewParent(fsnode.NewDirInfo(name), fsnode.ConstChildren(children...)) + case Leaf: + return fsnode.ConstLeaf(fsnode.NewRegInfo(name), append([]byte{}, n...)) + } + require.Failf(t, "invalid node type", "node: %T", node) + panic("unreachable") +} diff --git a/file/fsnode/fsnodetesting/walk.go b/file/fsnode/fsnodetesting/walk.go new file mode 100644 index 00000000..fed02f85 --- /dev/null +++ b/file/fsnode/fsnodetesting/walk.go @@ -0,0 +1,101 @@ +package fsnodetesting + +import ( + "context" + "io/ioutil" + "testing" + + "github.com/grailbio/base/file/fsnode" + "github.com/grailbio/base/ioctx" + "github.com/grailbio/testutil/assert" + "github.com/stretchr/testify/require" +) + +// Walker is a collection of settings. +// TODO: Add (Walker).Walk* variant that inspects FileInfo, too, not just content. 
+type Walker struct { + IgnoredNames map[string]struct{} + // Info makes WalkContents return InfoT recursively. See that function. + Info bool +} + +// T, Parent, and Leaf are aliases to improve readability of fixture definitions. +// InfoT augments T with its FileInfo for tests that want to check mode, size, etc. +type ( + T = interface{} + Parent = map[string]T + Leaf = []byte + + InfoT = struct { + fsnode.FileInfo + T + } +) + +// WalkContents traverses all of node and returns map and []byte objects representing +// parents/directories and leaves/files, respectively. +// +// For example, if node is a Parent with children named a and b that are regular files, and an empty +// subdirectory subdir, returns: +// Parent{ +// "a": Leaf("a's content"), +// "b": Leaf("b's content"), +// "subdir": Parent{}, +// } +// +// If w.Info, the returned contents will include fsnode.FileInfo, for example: +// InfoT{ +// fsnode.NewDirInfo("parent"), +// Parent{ +// "a": InfoT{ +// fsnode.NewRegInfo("a").WithSize(11), +// Leaf("a's content"), +// }, +// "b": InfoT{ +// fsnode.NewRegInfo("b").WithModePerm(0755), +// Leaf("b's content"), +// }, +// "subdir": InfoT{ +// fsnode.NewDirInfo("subdir") +// Parent{}, +// }, +// }, +// } +func (w Walker) WalkContents(ctx context.Context, t testing.TB, node fsnode.T) T { + switch n := node.(type) { + case fsnode.Parent: + dir := make(Parent) + children, err := fsnode.IterateAll(ctx, n.Children()) + require.NoError(t, err) + for _, child := range children { + name := child.Info().Name() + if _, ok := w.IgnoredNames[name]; ok { + continue + } + _, collision := dir[name] + require.Falsef(t, collision, "name %q is repeated", name) + dir[name] = w.WalkContents(ctx, t, child) + } + if w.Info { + return InfoT{fsnode.CopyFileInfo(n.Info()), dir} + } + return dir + case fsnode.Leaf: + leaf := LeafReadAll(ctx, t, n) + if w.Info { + return InfoT{fsnode.CopyFileInfo(n.Info()), leaf} + } + return leaf + } + require.Failf(t, "invalid node type", "node: %T", 
node) + panic("unreachable") +} + +func LeafReadAll(ctx context.Context, t testing.TB, n fsnode.Leaf) []byte { + file, err := fsnode.Open(ctx, n) + require.NoError(t, err) + defer func() { assert.NoError(t, file.Close(ctx)) }() + content, err := ioutil.ReadAll(ioctx.ToStdReader(ctx, file)) + require.NoError(t, err) + return content +} diff --git a/file/fsnode/iterators.go b/file/fsnode/iterators.go new file mode 100644 index 00000000..5dd02155 --- /dev/null +++ b/file/fsnode/iterators.go @@ -0,0 +1,165 @@ +package fsnode + +import ( + "context" + "io" + "os" +) + +type sliceIterator struct { + remaining []T + closed bool +} + +var _ Iterator = (*sliceIterator)(nil) + +// NewIterator returns an iterator that yields the given nodes. +func NewIterator(nodes ...T) Iterator { + // Copy the slice because we'll mutate to nil later. + nodes = append([]T{}, nodes...) + return &sliceIterator{remaining: nodes} +} + +func (it *sliceIterator) Next(ctx context.Context) (T, error) { + if it.closed { + return nil, os.ErrClosed + } + if len(it.remaining) == 0 { + return nil, io.EOF + } + next := it.remaining[0] + it.remaining[0] = nil // TODO: Is this necessary to allow GC? + it.remaining = it.remaining[1:] + return next, nil +} + +func (it *sliceIterator) Close(context.Context) error { + it.closed = true + it.remaining = nil + return nil +} + +type lazyIterator struct { + make func(context.Context) ([]T, error) + fetched bool + delegate Iterator +} + +var _ Iterator = (*lazyIterator)(nil) + +// NewLazyIterator uses the given make function upon the first call to Next to +// make the nodes that it yields. 
+func NewLazyIterator(make func(context.Context) ([]T, error)) Iterator { + return &lazyIterator{make: make} +} + +func (it *lazyIterator) Next(ctx context.Context) (T, error) { + if err := it.ensureFetched(ctx); err != nil { + return nil, err + } + return it.delegate.Next(ctx) +} + +func (it *lazyIterator) Close(ctx context.Context) error { + if it.delegate == nil { + return nil + } + err := it.delegate.Close(ctx) + it.delegate = nil + return err +} + +func (it *lazyIterator) ensureFetched(ctx context.Context) error { + if it.fetched { + if it.delegate == nil { + return os.ErrClosed + } + return nil + } + nodes, err := it.make(ctx) + if err != nil { + return err + } + it.delegate = NewIterator(nodes...) + it.fetched = true + return nil +} + +type concatIterator struct { + iters []Iterator + closed bool +} + +var _ Iterator = (*concatIterator)(nil) + +// NewConcatIterator returns the elements of the given iterators in order, reading each until EOF. +// Manages calling Close on constituents (as it goes along and upon its own Close). +func NewConcatIterator(iterators ...Iterator) Iterator { + return &concatIterator{iters: append([]Iterator{}, iterators...)} +} + +func (it *concatIterator) Next(ctx context.Context) (T, error) { + if it.closed { + return nil, os.ErrClosed + } + for { + if len(it.iters) == 0 { + return nil, io.EOF + } + next, err := it.iters[0].Next(ctx) + if err == io.EOF { + err = nil + if next != nil { + err = iteratorEOFError(it.iters[0]) + } + if closeErr := it.iters[0].Close(ctx); closeErr != nil && err == nil { + err = closeErr + } + it.iters[0] = nil // TODO: Is this actually necessary to allow GC? + it.iters = it.iters[1:] + if err != nil { + return nil, err + } + continue + } + return next, err + } +} + +// Close attempts to Close remaining constituent iterators. Returns the first constituent Close +// error (but attempts to close the rest anyway). 
+func (it *concatIterator) Close(ctx context.Context) error { + it.closed = true + var err error + for _, iter := range it.iters { + if err2 := iter.Close(ctx); err2 != nil && err == nil { + err = err2 + } + } + it.iters = nil + return err +} + +type mapIterator struct { + iter Iterator + fn func(context.Context, T) (T, error) +} + +// MapIterator returns an Iterator that applies fn to each T yielded by iter. +func MapIterator(iter Iterator, fn func(context.Context, T) (T, error)) Iterator { + return mapIterator{iter, fn} +} +func (it mapIterator) Next(ctx context.Context) (T, error) { + if it.fn == nil { + return nil, os.ErrClosed + } + node, err := it.iter.Next(ctx) + if err == nil && it.fn != nil { + node, err = it.fn(ctx, node) + } + return node, err +} +func (it mapIterator) Close(ctx context.Context) error { + it.fn = nil + return it.iter.Close(ctx) +} diff --git a/file/fsnode/iterators_test.go b/file/fsnode/iterators_test.go new file mode 100644 index 00000000..46016415 --- /dev/null +++ b/file/fsnode/iterators_test.go @@ -0,0 +1,112 @@ +package fsnode + +import ( + "context" + "errors" + "io" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSliceAll(t *testing.T) { + ctx := context.Background() + nodes := []mockLeaf{{id: 0}, {id: 1}, {id: 2}} + got, err := IterateAll(ctx, NewIterator(nodes[0], nodes[1], nodes[2])) + require.NoError(t, err) + assert.Equal(t, []T{nodes[0], nodes[1], nodes[2]}, got) +} + +func TestSliceFull(t *testing.T) { + ctx := context.Background() + nodes := []mockLeaf{{id: 0}, {id: 1}, {id: 2}} + dst := make([]T, 2) + + iter := NewIterator(nodes[0], nodes[1], nodes[2]) + got, err := IterateFull(ctx, iter, dst) + require.NoError(t, err) + assert.Equal(t, 2, got) + assert.Equal(t, []T{nodes[0], nodes[1]}, dst) + + got, err = IterateFull(ctx, iter, dst) + assert.ErrorIs(t, err, io.EOF) + assert.Equal(t, 1, got) + assert.Equal(t, []T{nodes[2]}, dst[:1]) +} + +func TestLazy(t 
*testing.T) { + ctx := context.Background() + nodes := []mockLeaf{{id: 0}, {id: 1}, {id: 2}} + makeNodes := func(context.Context) ([]T, error) { + return []T{nodes[0], nodes[1], nodes[2]}, nil + } + got, err := IterateAll(ctx, NewLazyIterator(makeNodes)) + require.NoError(t, err) + assert.Equal(t, []T{nodes[0], nodes[1], nodes[2]}, got) +} + +func TestLazyErr(t *testing.T) { + ctx := context.Background() + makeNodes := func(context.Context) ([]T, error) { + return nil, errors.New("test error") + } + _, err := IterateAll(ctx, NewLazyIterator(makeNodes)) + require.Error(t, err) + require.Contains(t, err.Error(), "test error") +} + +func TestConcatAll(t *testing.T) { + ctx := context.Background() + nodes := []mockLeaf{{id: 0}, {id: 1}, {id: 2}, {id: 3}} + iter := NewConcatIterator( + NewIterator(), + NewIterator(nodes[0]), + NewIterator(), + NewIterator(), + NewIterator(nodes[1], nodes[2], nodes[3]), + NewIterator(), + ) + got, err := IterateAll(ctx, iter) + require.NoError(t, err) + assert.Equal(t, []T{nodes[0], nodes[1], nodes[2], nodes[3]}, got) +} + +func TestConcatFull(t *testing.T) { + ctx := context.Background() + nodes := []mockLeaf{{id: 0}, {id: 1}, {id: 2}, {id: 3}} + iter := NewConcatIterator( + NewIterator(), + NewIterator(nodes[0]), + NewIterator(), + NewIterator(), + NewIterator(nodes[1], nodes[2], nodes[3]), + NewIterator(), + ) + + var dst []T + got, err := IterateFull(ctx, iter, dst) + require.NoError(t, err) + assert.Equal(t, 0, got) + + dst = make([]T, 3) + got, err = IterateFull(ctx, iter, dst[:2]) + require.NoError(t, err) + assert.Equal(t, 2, got) + assert.Equal(t, []T{nodes[0], nodes[1]}, dst[:2]) + + got, err = IterateFull(ctx, iter, dst[:1]) + require.NoError(t, err) + assert.Equal(t, 1, got) + assert.Equal(t, []T{nodes[2]}, dst[:1]) + + got, err = IterateFull(ctx, iter, dst) + assert.ErrorIs(t, err, io.EOF) + assert.Equal(t, 1, got) + assert.Equal(t, []T{nodes[3]}, dst[:1]) +} + +type mockLeaf struct { + Leaf + id int +} diff --git 
a/file/fsnode/leaf.go b/file/fsnode/leaf.go new file mode 100644 index 00000000..13fff91d --- /dev/null +++ b/file/fsnode/leaf.go @@ -0,0 +1,81 @@ +package fsnode + +import ( + "bytes" + "context" + "os" + + "github.com/grailbio/base/errors" + "github.com/grailbio/base/ioctx" + "github.com/grailbio/base/ioctx/fsctx" +) + +type funcLeaf struct { + FileInfo + // open is called when OpenFile is called. See Leaf.OpenFile. + open func(context.Context, int) (fsctx.File, error) +} + +// FuncLeaf constructs a Leaf from an open function. +// It's invoked every time; implementations should do their own caching if desired. +func FuncLeaf(info FileInfo, open func(ctx context.Context, flag int) (fsctx.File, error)) Leaf { + return funcLeaf{info, open} +} +func (l funcLeaf) OpenFile(ctx context.Context, flag int) (fsctx.File, error) { + return l.open(ctx, flag) +} +func (l funcLeaf) FSNodeT() {} + +type ( + readerAtLeaf struct { + FileInfo + ioctx.ReaderAt + } + readerAtFile struct { + readerAtLeaf + off int64 + } +) + +// ReaderAtLeaf constructs a Leaf whose file reads from r. +// Cacheability of both metadata and content is governed by info. 
+func ReaderAtLeaf(info FileInfo, r ioctx.ReaderAt) Leaf { return readerAtLeaf{info, r} } + +func (r readerAtLeaf) OpenFile(context.Context, int) (fsctx.File, error) { + return &readerAtFile{r, 0}, nil +} +func (l readerAtLeaf) FSNodeT() {} + +func (f *readerAtFile) Stat(context.Context) (os.FileInfo, error) { + if f.ReaderAt == nil { + return nil, os.ErrClosed + } + return f.FileInfo, nil +} +func (f *readerAtFile) Read(ctx context.Context, dst []byte) (int, error) { + if f.ReaderAt == nil { + return 0, os.ErrClosed + } + n, err := f.ReadAt(ctx, dst, f.off) + f.off += int64(n) + return n, err +} +func (f *readerAtFile) Write(context.Context, []byte) (int, error) { + return 0, errors.E(errors.NotSupported, f.Name(), "is read-only") +} +func (f *readerAtFile) Close(context.Context) error { + if f.ReaderAt == nil { + return os.ErrClosed + } + f.ReaderAt = nil + return nil +} + +// ConstLeaf constructs a leaf with constant contents. Caller must not modify content after call. +// Uses content's size (ignoring existing info.Size). +func ConstLeaf(info FileInfo, content []byte) Leaf { + info = info.WithSize(int64(len(content))) + return ReaderAtLeaf(info, ioctx.FromStdReaderAt(bytes.NewReader(content))) +} + +// TODO: From *os.File? diff --git a/file/fsnode/parent.go b/file/fsnode/parent.go new file mode 100644 index 00000000..88a79428 --- /dev/null +++ b/file/fsnode/parent.go @@ -0,0 +1,67 @@ +package fsnode + +import ( + "context" + "os" + + "github.com/grailbio/base/log" +) + +// NewParent returns a Parent whose children are defined by gen. gen.GenerateChildren is called on +// every Parent query (including Child, which returns one result). Implementers should cache +// internally if necessary. +func NewParent(info FileInfo, gen ChildrenGenerator) Parent { + if !info.IsDir() { + log.Panicf("FileInfo has file mode, require directory: %#v", info) + } + return parentImpl{FileInfo: info, gen: gen} +} + +type ( + // ChildrenGenerator generates child nodes. 
+ ChildrenGenerator interface { + GenerateChildren(context.Context) ([]T, error) + } + childrenGenFunc func(context.Context) ([]T, error) + childrenGenConst []T +) + +// FuncChildren constructs a ChildrenGenerator that simply invokes fn, for convenience. +func FuncChildren(fn func(context.Context) ([]T, error)) ChildrenGenerator { + return childrenGenFunc(fn) +} + +func (fn childrenGenFunc) GenerateChildren(ctx context.Context) ([]T, error) { return fn(ctx) } + +// ConstChildren constructs a ChildrenGenerator that always returns the given children. +func ConstChildren(children ...T) ChildrenGenerator { + children = append([]T{}, children...) + return childrenGenConst(children) +} + +func (c childrenGenConst) GenerateChildren(ctx context.Context) ([]T, error) { return c, nil } + +type parentImpl struct { + ParentReadOnly + FileInfo + gen ChildrenGenerator +} + +func (n parentImpl) Child(ctx context.Context, name string) (T, error) { + children, err := n.gen.GenerateChildren(ctx) + if err != nil { + return nil, err + } + for _, child := range children { + if child.Info().Name() == name { + return child, nil + } + } + return nil, os.ErrNotExist +} + +func (n parentImpl) Children() Iterator { + return NewLazyIterator(n.gen.GenerateChildren) +} + +func (n parentImpl) FSNodeT() {} diff --git a/file/fsnodefuse/attr_darwin.go b/file/fsnodefuse/attr_darwin.go new file mode 100644 index 00000000..06f3d2a8 --- /dev/null +++ b/file/fsnodefuse/attr_darwin.go @@ -0,0 +1,12 @@ +package fsnodefuse + +import "github.com/hanwen/go-fuse/v2/fuse" + +// blockSize is defined in the stat(2) man page: +// st_blocks The actual number of blocks allocated for the file in 512-byte units. As short symbolic links are stored in the inode, this number may be zero. +const blockSize = 512 + +func setBlockSize(*fuse.Attr, uint32) { + // a.Blksize not present on darwin. + // TODO: Implement statfs for darwin to pass iosize. 
+} diff --git a/file/fsnodefuse/attr_linux.go b/file/fsnodefuse/attr_linux.go new file mode 100644 index 00000000..d99952db --- /dev/null +++ b/file/fsnodefuse/attr_linux.go @@ -0,0 +1,11 @@ +package fsnodefuse + +import "github.com/hanwen/go-fuse/v2/fuse" + +// blockSize is defined in the stat(2) man page: +// The st_blocks field indicates the number of blocks allocated to the file, 512-byte units. (This may be smaller than st_size/512 when the file has holes.) +const blockSize = 512 + +func setBlockSize(a *fuse.Attr, size uint32) { + a.Blksize = size +} diff --git a/file/fsnodefuse/dir.go b/file/fsnodefuse/dir.go new file mode 100644 index 00000000..b37bd063 --- /dev/null +++ b/file/fsnodefuse/dir.go @@ -0,0 +1,314 @@ +package fsnodefuse + +import ( + "context" + "crypto/sha512" + "encoding/binary" + "os" + "sync" + "syscall" + "time" + + "github.com/grailbio/base/file/fsnode" + "github.com/grailbio/base/log" + "github.com/grailbio/base/sync/loadingcache" + "github.com/grailbio/base/sync/loadingcache/ctxloadingcache" + "github.com/grailbio/base/writehash" + "github.com/hanwen/go-fuse/v2/fs" + "github.com/hanwen/go-fuse/v2/fuse" +) + +type dirInode struct { + fs.Inode + cache loadingcache.Map + readdirplusCache readdirplusCache + + mu sync.Mutex + n fsnode.Parent +} + +var ( + _ fs.InodeEmbedder = (*dirInode)(nil) + + _ fs.NodeCreater = (*dirInode)(nil) + _ fs.NodeGetattrer = (*dirInode)(nil) + _ fs.NodeLookuper = (*dirInode)(nil) + _ fs.NodeReaddirer = (*dirInode)(nil) + _ fs.NodeSetattrer = (*dirInode)(nil) + _ fs.NodeUnlinker = (*dirInode)(nil) +) + +func (n *dirInode) Readdir(ctx context.Context) (_ fs.DirStream, errno syscall.Errno) { + defer handlePanicErrno(&errno) + ctx = ctxloadingcache.With(ctx, &n.cache) + s, err := newDirStream(ctx, n) + if err != nil { + return nil, errToErrno(err) + } + return s, fs.OK +} + +func (n *dirInode) Lookup(ctx context.Context, name string, out *fuse.EntryOut) (_ *fs.Inode, errno syscall.Errno) { + defer 
handlePanicErrno(&errno) + ctx = ctxloadingcache.With(ctx, &n.cache) + childFSNode := n.readdirplusCache.Get(name) + if childFSNode == nil { + var err error + childFSNode, err = n.n.Child(ctx, name) + if err != nil { + return nil, errToErrno(err) + } + } + childInode := n.GetChild(name) + if childInode == nil || stableAttr(n, childFSNode) != childInode.StableAttr() { + childInode = n.newInode(ctx, childFSNode) + } + setFSNode(childInode, childFSNode) + setEntryOut(out, childInode.StableAttr().Ino, childFSNode) + return childInode, fs.OK +} + +func (n *dirInode) Getattr(ctx context.Context, _ fs.FileHandle, a *fuse.AttrOut) (errno syscall.Errno) { + defer handlePanicErrno(&errno) + setAttrFromFileInfo(&a.Attr, n.n.Info()) + a.SetTimeout(getCacheTimeout(n.n)) + return fs.OK +} + +func (n *dirInode) Setattr(ctx context.Context, _ fs.FileHandle, _ *fuse.SetAttrIn, a *fuse.AttrOut) (errno syscall.Errno) { + defer handlePanicErrno(&errno) + n.cache.DeleteAll() + + // To avoid deadlock we must notify invalidations while not holding certain inode locks. + // See: https://github.com/libfuse/libfuse/blob/d709c24cbd9e1041264c551c2a4445e654eaf429/include/fuse_lowlevel.h#L1654-L1661 + // We're ok with best-effort execution of the invalidation so a goroutine conveniently avoids locks. + children := n.Children() + go func() { + for childName, child := range children { + // TODO: Consider merely NotifyEntry instead of NotifyDelete. + // Both force a Lookup on the next access, as desired. However, NotifyDelete also + // deletes the child inode immediately which has UX consequences. For example, if a + // user's shell is currently working in that directory, after NotifyDelete they may + // see shell operations fail (similar to what they might see if they `git checkout` a + // branch that doesn't include the current working directory). 
NotifyEntry avoids those + // errors but may introduce inconsistency (that shell will remain using the old inode + // and its stale contents), which may be confusing. + // TODO: josh@ is not sure about this inconsistency thing. + if errno := n.NotifyDelete(childName, child); errno != fs.OK { + log.Error.Printf("dirInode.Setattr %s: error from NotifyDelete %s: %v", n.Path(nil), childName, errno) + } + } + }() + + setAttrFromFileInfo(&a.Attr, n.n.Info()) + a.SetTimeout(getCacheTimeout(n.n)) + return fs.OK +} + +func (n *dirInode) Create( + ctx context.Context, + name string, + flags uint32, + mode uint32, + out *fuse.EntryOut, +) (_ *fs.Inode, _ fs.FileHandle, _ uint32, errno syscall.Errno) { + defer handlePanicErrno(&errno) + if (mode & syscall.S_IFREG) == 0 { + return nil, nil, 0, syscall.EINVAL + } + leaf, f, err := n.n.AddChildLeaf(ctx, name, flags) + if err != nil { + return nil, nil, 0, errToErrno(err) + } + ino := hashIno(n, leaf.Info().Name()) + embed := ®Inode{n: leaf} + inode := n.NewInode(ctx, embed, fs.StableAttr{Mode: mode, Ino: ino}) + h, err := makeHandle(embed, flags, f) + return inode, h, 0, errToErrno(err) +} + +func (n *dirInode) Unlink(ctx context.Context, name string) syscall.Errno { + return errToErrno(n.n.RemoveChild(ctx, name)) +} + +func (n *dirInode) Mkdir( + ctx context.Context, + name string, + mode uint32, + out *fuse.EntryOut, +) (_ *fs.Inode, errno syscall.Errno) { + defer handlePanicErrno(&errno) + p, err := n.n.AddChildParent(ctx, name) + if err != nil { + return nil, errToErrno(err) + } + embed := &dirInode{n: p} + mode |= syscall.S_IFDIR + ino := hashIno(n, name) + inode := n.NewInode(ctx, embed, fs.StableAttr{Mode: mode, Ino: ino}) + setEntryOut(out, ino, p) + return inode, fs.OK +} + +// newInode returns an inode that wraps fsNode. The type of inode (embedder) +// to create is inferred from the type of fsNode. 
+func (n *dirInode) newInode(ctx context.Context, fsNode fsnode.T) *fs.Inode { + var embed fs.InodeEmbedder + // TODO: Set owner/UID? + switch fsNode.(type) { + case fsnode.Parent: + embed = &dirInode{} + case fsnode.Leaf: + embed = ®Inode{} + default: + log.Panicf("invalid node type: %T", fsNode) + } + inode := n.NewInode(ctx, embed, stableAttr(n, fsNode)) + // inode may be an existing inode with an existing embedder. Regardless, + // update the underlying fsnode.T. + setFSNode(inode, fsNode) + return inode +} + +func setEntryOut(out *fuse.EntryOut, ino uint64, n fsnode.T) { + out.Ino = ino + setAttrFromFileInfo(&out.Attr, n.Info()) + cacheTimeout := getCacheTimeout(n) + out.SetEntryTimeout(cacheTimeout) + out.SetAttrTimeout(cacheTimeout) +} + +func setAttrFromFileInfo(a *fuse.Attr, info os.FileInfo) { + if info.IsDir() { + a.Mode |= syscall.S_IFDIR + } else { + a.Mode |= syscall.S_IFREG + } + a.Mode |= uint32(info.Mode() & os.ModePerm) + a.Size = uint64(info.Size()) + a.Blocks = a.Size / blockSize + // We want to encourage large reads to reduce syscall overhead. FUSE has a 128 KiB read + // size limit anyway. + // TODO: Is there a better way to set this, in case size limits ever change? + setBlockSize(a, 128*1024) + if mod := info.ModTime(); !mod.IsZero() { + a.SetTimes(nil, &mod, nil) + } +} + +func getCacheTimeout(any interface{}) time.Duration { + cacheableFor := fsnode.CacheableFor(any) + if cacheableFor < 0 { + return 365 * 24 * time.Hour + } + return cacheableFor +} + +func mode(n fsnode.T) uint32 { + switch n.(type) { + case fsnode.Parent: + return syscall.S_IFDIR + case fsnode.Leaf: + return syscall.S_IFREG + default: + log.Panicf("invalid node type: %T", n) + panic("unreachable") + } +} + +// readdirplusCache caches nodes for calls to Lookup that go-fuse issues when +// servicing READDIRPLUS. To handle READDIRPLUS, go-fuse interleaves LOOKUP +// calls for each directory entry. 
dirStream populates this cache with the +// last returned entry so that it can be used in Lookup, saving a possibly +// costly (fsnode.Parent).Child call. +type readdirplusCache struct { + // mu is used to provide exclusive access to the fields below. + mu sync.Mutex + // m maps child node names to the set of cached nodes for each name. The + // calls to Lookup do not indicate whether they are for a READDIRPLUS, so + // if there are two dirStream instances which each cached a node for a + // given name, Lookup will use an arbitrary node in the cache, as we don't + // know which Lookup is associated with which dirStream. This might cause + // transiently stale information but keeps the implementation simple. + m map[string][]fsnode.T +} + +// Put puts a node n in the cache. +func (c *readdirplusCache) Put(n fsnode.T) { + c.mu.Lock() + defer c.mu.Unlock() + if c.m == nil { + c.m = make(map[string][]fsnode.T) + } + name := n.Info().Name() + c.m[name] = append(c.m[name], n) +} + +// Get gets a node in the cache for the given name. If no node is cached, +// returns nil. +func (c *readdirplusCache) Get(name string) fsnode.T { + c.mu.Lock() + defer c.mu.Unlock() + if c.m == nil { + return nil + } + ns, ok := c.m[name] + if !ok { + return nil + } + return ns[0] +} + +// Drop drops the node n from the cache. n must have been previously added as +// an entry for name using Put. 
+func (c *readdirplusCache) Drop(n fsnode.T) { + c.mu.Lock() + defer c.mu.Unlock() + name := n.Info().Name() + ns, _ := c.m[name] + if len(ns) == 1 { + delete(c.m, name) + return + } + var dropIndex int + for i := range ns { + if n == ns[i] { + dropIndex = i + break + } + } + last := len(ns) - 1 + ns[dropIndex] = ns[last] + ns[last] = nil + ns = ns[:last] + c.m[name] = ns +} + +func stableAttr(parent fs.InodeEmbedder, n fsnode.T) fs.StableAttr { + var mode uint32 + switch modeType := n.Info().Mode().Type(); modeType { + case 0: + mode |= syscall.S_IFREG + case os.ModeDir: + mode |= syscall.S_IFDIR + case os.ModeSymlink: + mode |= syscall.S_IFLNK + default: + log.Panicf("invalid node mode type: %v", modeType) + } + return fs.StableAttr{ + Mode: mode, + Ino: hashIno(parent, n.Info().Name()), + } +} + +func hashParentInoAndName(parentIno uint64, name string) uint64 { + h := sha512.New() + writehash.Uint64(h, parentIno) + writehash.String(h, name) + return binary.LittleEndian.Uint64(h.Sum(nil)[:8]) +} + +func hashIno(parent fs.InodeEmbedder, name string) uint64 { + return hashParentInoAndName(parent.EmbeddedInode().StableAttr().Ino, name) +} diff --git a/file/fsnodefuse/dir_test.go b/file/fsnodefuse/dir_test.go new file mode 100644 index 00000000..95e302bb --- /dev/null +++ b/file/fsnodefuse/dir_test.go @@ -0,0 +1,67 @@ +package fsnodefuse + +import ( + "context" + "io/ioutil" + "os" + "path" + "strconv" + "testing" + "time" + + "github.com/grailbio/base/file/fsnode" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestLookupCaching checks that Lookup doesn't return stale nodes (past their cache time). +// It's a regression test. +func TestLookupCaching(t *testing.T) { + const ( + childName = "time" + cacheFor = time.Millisecond + waitSlack = 500 * time.Millisecond + ) + // root is a directory with one child. Each time directory listing is initiated, the content + // of the child is fixed as the current time. 
+ root := fsnode.NewParent( + fsnode.NewDirInfo(""), + fsnode.FuncChildren(func(ctx context.Context) ([]fsnode.T, error) { + nowUnixNanos := time.Now().UnixNano() + return []fsnode.T{ + fsnode.ConstLeaf( + fsnode.NewRegInfo(childName).WithCacheableFor(cacheFor), + []byte(strconv.FormatInt(nowUnixNanos, 10)), + ), + }, nil + }), + ) + withMounted(t, root, func(rootPath string) { + childPath := path.Join(rootPath, childName) + // Trigger a directory listing and read the event time from the file. + listingTime := readUnixNanosFile(t, childPath) + // Wait until that time has passed. + // Note: We have to use wall clock time here, not mock, because we're interested in kernel + // inode caching interactions. + // TODO: Is there a way to guarantee that entry cache time has elapsed, for robustness? + waitTime := listingTime.Add(waitSlack) + sleep := waitTime.Sub(time.Now()) + time.Sleep(sleep) + secondListingTime := readUnixNanosFile(t, childPath) + assert.NotEqual(t, + listingTime.UnixNano(), secondListingTime.UnixNano(), + "second listing should have different timestamp", + ) + }) +} + +func readUnixNanosFile(t *testing.T, filePath string) time.Time { + child, err := os.Open(filePath) + require.NoError(t, err) + defer func() { assert.NoError(t, child.Close()) }() + content, err := ioutil.ReadAll(child) + require.NoError(t, err) + listingUnixNano, err := strconv.ParseInt(string(content), 10, 64) + require.NoError(t, err) + return time.Unix(0, listingUnixNano) +} diff --git a/file/fsnodefuse/dirstream.go b/file/fsnodefuse/dirstream.go new file mode 100644 index 00000000..2f95037c --- /dev/null +++ b/file/fsnodefuse/dirstream.go @@ -0,0 +1,91 @@ +package fsnodefuse + +import ( + "context" + "io" + "syscall" + + "github.com/grailbio/base/errors" + "github.com/grailbio/base/file/fsnode" + "github.com/grailbio/base/log" + "github.com/hanwen/go-fuse/v2/fs" + "github.com/hanwen/go-fuse/v2/fuse" +) + +type dirStream struct { + ctx context.Context + dir *dirInode + entries 
[]fsnode.T + // prev is the node of the previous entry, i.e. the node of the most recent + // entry returned by Next. We cache this node to service LOOKUP operations + // that go-fuse issues when servicing READDIRPLUS. See lookupCache. + prev fsnode.T +} + +var _ fs.DirStream = (*dirStream)(nil) + +// newDirStream returns a dirStream whose entries are the children of dir. +// Children are loaded eagerly so that any errors are reported before any +// entries are returned by the stream. If Next returns a non-OK Errno after +// a call that returned an OK Errno, the READDIR operation returns an EIO, +// regardless of the returned Errno. See +// https://github.com/hanwen/go-fuse/issues/436 . +func newDirStream( + ctx context.Context, + dir *dirInode, +) (_ *dirStream, err error) { + var ( + entries []fsnode.T + iter = dir.n.Children() + ) + defer errors.CleanUpCtx(ctx, iter.Close, &err) + for { + n, err := iter.Next(ctx) + if err == io.EOF { + return &dirStream{ + ctx: ctx, + dir: dir, + entries: entries, + }, nil + } + if err != nil { + return nil, err + } + entries = append(entries, n) + } +} + +func (d *dirStream) HasNext() bool { + return len(d.entries) != 0 +} + +func (d *dirStream) Next() (_ fuse.DirEntry, errno syscall.Errno) { + defer handlePanicErrno(&errno) + var next fsnode.T + next, d.entries = d.entries[0], d.entries[1:] + if d.prev != nil { + d.dir.readdirplusCache.Drop(d.prev) + } + d.prev = next + d.dir.readdirplusCache.Put(next) + name := next.Info().Name() + return fuse.DirEntry{ + Name: name, + Mode: mode(next), + Ino: hashIno(d.dir, name), + }, fs.OK +} + +func (d *dirStream) Close() { + var err error + defer handlePanicErr(&err) + defer func() { + if err != nil { + log.Error.Printf("fsnodefuse.dirStream: error on close: %v", err) + } + }() + if d.prev != nil { + d.dir.readdirplusCache.Drop(d.prev) + d.prev = nil + } +} diff --git a/file/fsnodefuse/dirstream_test.go b/file/fsnodefuse/dirstream_test.go new file mode 100644 index 00000000..3900f39a --- 
/dev/null +++ b/file/fsnodefuse/dirstream_test.go @@ -0,0 +1,41 @@ +package fsnodefuse + +import ( + "os" + "sync/atomic" + "testing" + + "github.com/grailbio/base/file/fsnode" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestNameCollision verifies that we can list a directory in which there are +// entries with duplicate names without panicking. +func TestNameCollision(t *testing.T) { + makeTest := func(children ...fsnode.T) func(*testing.T) { + return func(t *testing.T) { + root := fsnode.NewParent( + fsnode.NewDirInfo("test"), + fsnode.ConstChildren(children...), + ) + withMounted(t, root, func(mountDir string) { + f, err := os.Open(mountDir) + require.NoError(t, err, "opening mounted directory") + defer func() { assert.NoError(t, f.Close()) }() + _, err = f.Readdir(0) + assert.NoError(t, err, "reading mounted directory") + assert.Zero(t, atomic.LoadUint32(&numHandledPanics), "checking number of panics") + }) + } + } + var ( + aReg = fsnode.ConstLeaf(fsnode.NewRegInfo("a"), []byte{}) + aDir = fsnode.NewParent(fsnode.NewDirInfo("a"), fsnode.ConstChildren()) + bReg = fsnode.ConstLeaf(fsnode.NewRegInfo("b"), []byte{}) + bDir = fsnode.NewParent(fsnode.NewDirInfo("b"), fsnode.ConstChildren()) + ) + t.Run("reg_first", makeTest(aReg, aDir)) + t.Run("dir_first", makeTest(aDir, aReg)) + t.Run("mixed", makeTest(aReg, aDir, aDir, bReg, aReg, bReg, aReg, bDir, aReg, aDir)) +} diff --git a/file/fsnodefuse/err.go b/file/fsnodefuse/err.go new file mode 100644 index 00000000..678a9156 --- /dev/null +++ b/file/fsnodefuse/err.go @@ -0,0 +1,60 @@ +package fsnodefuse + +import ( + "fmt" + "runtime/debug" + "sync/atomic" + "syscall" + + "github.com/grailbio/base/errors" + "github.com/grailbio/base/log" + "github.com/hanwen/go-fuse/v2/fs" +) + +// numHandledPanics is the total number of panics handled by handlePanicErrno. +// It can be used in testing to verify whether we triggered panics when +// handling operations. 
It must be accessed atomically, i.e. using atomic.* +// functions. +var numHandledPanics uint32 + +// handlePanicErrno is a last resort to prevent panics from reaching go-fuse and breaking the FUSE mount. +// All go-fuse-facing APIs that return Errno should defer it. +func handlePanicErrno(errno *syscall.Errno) { + r := recover() + if r == nil { + return + } + atomic.AddUint32(&numHandledPanics, 1) + *errno = errToErrno(makePanicErr(r)) +} + +// handlePanicErr is like handlePanicErrno but for APIs that don't return Errno. +func handlePanicErr(dst *error) { + r := recover() + if r == nil { + return + } + *dst = makePanicErr(r) +} + +func makePanicErr(recovered interface{}) error { + if err, ok := recovered.(error); ok { + return errors.E(err, fmt.Sprintf("recovered panic, stack:\n%v", string(debug.Stack()))) + } + return errors.E(fmt.Sprintf("recovered panic: %v, stack:\n%v", recovered, string(debug.Stack()))) +} + +func errToErrno(err error) syscall.Errno { + if err == nil { + return fs.OK + } + e := errors.Recover(err) + kind := e.Kind + errno, ok := kind.Errno() + if ok { + log.Error.Printf("returning errno: %v for error: %v", errno, e) + return errno + } + log.Error.Printf("error with no good errno match: kind: %v, err: %v", kind, err) + return syscall.EIO +} diff --git a/file/fsnodefuse/err_test.go b/file/fsnodefuse/err_test.go new file mode 100644 index 00000000..86c79c63 --- /dev/null +++ b/file/fsnodefuse/err_test.go @@ -0,0 +1,80 @@ +package fsnodefuse + +import ( + "context" + "io/ioutil" + "os" + "path" + "sync/atomic" + "testing" + + "github.com/grailbio/base/file/fsnode" + "github.com/grailbio/base/ioctx/fsctx" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestPanic* panic while handling a FUSE operation, then repeat the operation to make +// sure the mount isn't broken. +// TODO: Test operations exhaustively, maybe with something like fuzzing. 
+ +func TestPanicOpen(t *testing.T) { + const success = "success" + var ( + info = fsnode.NewRegInfo("panicOnce") + panicked int32 + ) + root := fsnode.NewParent(fsnode.NewDirInfo("root"), + fsnode.ConstChildren( + fsnode.FuncLeaf( + info, + func(ctx context.Context, flag int) (fsctx.File, error) { + if atomic.AddInt32(&panicked, 1) == 1 { + panic("it's a panic!") + } + return fsnode.Open(ctx, fsnode.ConstLeaf(info, []byte(success))) + }, + ), + ), + ) + withMounted(t, root, func(mountDir string) { + _, err := os.Open(path.Join(mountDir, "panicOnce")) + require.Error(t, err) + f, err := os.Open(path.Join(mountDir, "panicOnce")) + require.NoError(t, err) + defer func() { require.NoError(t, f.Close()) }() + content, err := ioutil.ReadAll(f) + require.NoError(t, err) + assert.Equal(t, success, string(content)) + }) +} + +func TestPanicList(t *testing.T) { + const success = "success" + var panicked int32 + root := fsnode.NewParent(fsnode.NewDirInfo("root"), + fsnode.FuncChildren(func(context.Context) ([]fsnode.T, error) { + if atomic.AddInt32(&panicked, 1) == 1 { + panic("it's a panic!") + } + return []fsnode.T{ + fsnode.ConstLeaf(fsnode.NewRegInfo(success), nil), + }, nil + }), + ) + withMounted(t, root, func(mountDir string) { + dir, err := os.Open(mountDir) + require.NoError(t, err) + ents, _ := dir.Readdirnames(0) + // It seems like Readdirnames returns nil error despite the panic. + // TODO: Confirm this is expected. + assert.Empty(t, ents) + require.NoError(t, dir.Close()) + dir, err = os.Open(mountDir) + require.NoError(t, err) + defer func() { require.NoError(t, dir.Close()) }() + ents, err = dir.Readdirnames(0) + assert.NoError(t, err) + assert.Equal(t, []string{success}, ents) + }) +} diff --git a/file/fsnodefuse/fsnodefuse.go b/file/fsnodefuse/fsnodefuse.go new file mode 100644 index 00000000..8b78881a --- /dev/null +++ b/file/fsnodefuse/fsnodefuse.go @@ -0,0 +1,81 @@ +// fsnodefuse implements github.com/hanwen/go-fuse/v2/fs for fsnode.T. 
+// It's a work-in-progress. No correctness or stability is guaranteed. Or even suggested. +// +// fsnode.Parent naturally becomes a directory. fsnode.Leaf becomes a file. Support for FUSE +// operations on that file depends on what Leaf.Open returns. If that fsctx.File is also +// spliceio.ReaderAt: +// FUSE file supports concurrent, random reads and uses splices to reduce +// userspace <-> kernelspace memory copying. +// ioctx.ReaderAt: +// FUSE file supports concurrent, random reads. +// Otherwise: +// FUSE file supports in-order, contiguous reads only. That is, each read must +// start where the previous one ended. At fsctx.File EOF, file size is recorded +// and then overrides what fsctx.File.Stat() reports for future getattr calls, +// so users can see they're done reading. +// TODO: Decide if there's a better place for this feature. +package fsnodefuse + +import ( + "fmt" + "runtime" + + "github.com/grailbio/base/file/fsnode" + "github.com/grailbio/base/file/internal/kernel" + "github.com/hanwen/go-fuse/v2/fs" + "github.com/hanwen/go-fuse/v2/fuse" +) + +// NewRoot creates a FUSE inode whose contents are the given fsnode.T. +// Note that this inode must be mounted with options from ConfigureRequiredMountOptions. +func NewRoot(node fsnode.T) fs.InodeEmbedder { + switch n := node.(type) { + case fsnode.Parent: + return &dirInode{n: n} + case fsnode.Leaf: + // TODO(josh): Test this path. + return ®Inode{n: n} + } + panic(fmt.Sprintf("unrecognized fsnode type: %T, %[1]v", node)) +} + +// ConfigureRequiredMountOptions sets values in opts to be compatible with fsnodefuse's +// implementation. Users of NewRoot must use these options, and they should call this last, +// to make sure the required options take effect. +func ConfigureRequiredMountOptions(opts *fuse.MountOptions) { + opts.MaxReadAhead = kernel.MaxReadAhead +} + +// ConfigureDefaultMountOptions provides defaults that callers may want to start with, for performance. 
+func ConfigureDefaultMountOptions(opts *fuse.MountOptions) {
+	// Increase MaxBackground from its default value (12) to improve S3 read performance.
+	//
+	// Empirically, while reading a 30 GiB file in chunks in parallel, the number of concurrent
+	// reads processed by our FUSE server [1] was ~12 with the default, corresponding to poor
+	// network utilization (only 500 Mb/s on m5d.4x in EC2); it rises to ~120 after, and network
+	// read bandwidth rises to >7 Gb/s, close to the speed of reading directly from S3 with
+	// this machine (~9 Gb/s).
+	//
+	// libfuse documentation [2] suggests that this limits the number of kernel readahead
+	// requests, so raising the limit may allow kernel readahead for every chunk, which could
+	// plausibly explain the performance benefit. (There's also mention of large direct I/O
+	// requests from userspace; josh@ did not think his Go test program was using direct I/O for
+	// this benchmark, but maybe he just didn't know).
+	//
+	// This particular value is a somewhat-informed guess. We'd like it to be high enough to
+	// admit all the parallelism that applications may profitably want. EC2 instances generally
+	// have <1 Gb/s network bandwidth per CPU (m5n.24x is around that, and non-'n' types have
+	// several times less), and S3 connections are limited to ~700 Mb/s [3], so just a couple of
+	// read chunks per CPU are sufficient to be I/O-bound for large objects. Many smaller object
+	// reads tend to not reach maximum bandwidth, so applications may increase parallelism,
+	// so we set our limit several times higher.
+	// TODO: Run more benchmarks (like github.com/grailbio/base/file/filebench) and tune.
+ // + // [1] As measured by simple logging: https://gitlab.com/grailbio/grail/-/merge_requests/8292/diffs?commit_id=7681acfcac836b92eaca60eb567245b32b81ec50 + // [2] https://web.archive.org/web/20220815053939/https://libfuse.github.io/doxygen/structfuse__conn__info.html#a5f9e695735727343448ae1e1a86dfa03 + // [3] 85-90 MB/s: https://web.archive.org/web/20220325121400/https://docs.aws.amazon.com/AmazonS3/latest/userguide/optimizing-performance-design-patterns.html#optimizing-performance-parallelization + opts.MaxBackground = 16 * runtime.NumCPU() + + // We don't use extended attributes so we can skip these requests to improve performance. + opts.DisableXAttrs = true +} diff --git a/file/fsnodefuse/handle.go b/file/fsnodefuse/handle.go new file mode 100644 index 00000000..c66e6f6e --- /dev/null +++ b/file/fsnodefuse/handle.go @@ -0,0 +1,349 @@ +package fsnodefuse + +import ( + "context" + "fmt" + "io" + "syscall" + + "github.com/grailbio/base/errors" + "github.com/grailbio/base/file/fsnode" + "github.com/grailbio/base/file/fsnodefuse/trailingbuf" + "github.com/grailbio/base/file/internal/kernel" + "github.com/grailbio/base/ioctx" + "github.com/grailbio/base/ioctx/fsctx" + "github.com/grailbio/base/ioctx/spliceio" + "github.com/grailbio/base/log" + "github.com/grailbio/base/sync/loadingcache" + "github.com/grailbio/base/sync/loadingcache/ctxloadingcache" + "github.com/hanwen/go-fuse/v2/fs" + "github.com/hanwen/go-fuse/v2/fuse" +) + +// makeHandle makes a fs.FileHandle for the given file, constructing an +// appropriate implementation given the flags and file implementation. 
+func makeHandle(n *regInode, flags uint32, file fsctx.File) (fs.FileHandle, error) {
+	var (
+		spliceioReaderAt, isSpliceioReaderAt = file.(spliceio.ReaderAt)
+		ioctxReaderAt, isIoctxReaderAt       = file.(ioctx.ReaderAt)
+	)
+	if (flags&fuse.O_ANYWRITE) == 0 && !isSpliceioReaderAt && !isIoctxReaderAt {
+		tbReaderAt := trailingbuf.New(file, 0, kernel.MaxReadAhead)
+		return sizingHandle{
+			n:     n,
+			f:     file,
+			r:     tbReaderAt,
+			cache: &n.cache,
+		}, nil
+	}
+	var r fs.FileReader
+	switch {
+	case isSpliceioReaderAt:
+		r = fileReaderSpliceio{spliceioReaderAt}
+	case isIoctxReaderAt:
+		r = fileReaderIoctx{ioctxReaderAt}
+	case (flags & syscall.O_WRONLY) != syscall.O_WRONLY:
+		return nil, errors.E(
+			errors.NotSupported,
+			fmt.Sprintf("%T must implement spliceio.SpliceReaderAt or ioctx.ReaderAt", file),
+		)
+	}
+	w, _ := file.(Writable)
+	return &handle{
+		f:     file,
+		r:     r,
+		w:     w,
+		cache: &n.cache,
+	}, nil
+}
+
+// sizingHandle infers the size of the underlying fsctx.File stream based on
+// EOF.
+type sizingHandle struct {
+	n     *regInode
+	f     fsctx.File
+	r     *trailingbuf.ReaderAt
+	cache *loadingcache.Map
+}
+
+var (
+	_ fs.FileGetattrer = (*sizingHandle)(nil)
+	_ fs.FileReader    = (*sizingHandle)(nil)
+	_ fs.FileReleaser  = (*sizingHandle)(nil)
+)
+
+func (h sizingHandle) Getattr(ctx context.Context, a *fuse.AttrOut) (errno syscall.Errno) {
+	defer handlePanicErrno(&errno)
+	ctx = ctxloadingcache.With(ctx, h.cache)
+
+	// Note: Implementations that don't know the exact data size in advance may use some fixed
+	// overestimate for size.
+	statInfo, err := h.f.Stat(ctx)
+	if err != nil {
+		return errToErrno(err)
+	}
+	info := fsnode.CopyFileInfo(statInfo)
+
+	localSize, localKnown, err := h.r.Size(ctx)
+	if err != nil {
+		return errToErrno(err)
+	}
+
+	h.n.defaultSizeMu.RLock()
+	sharedKnown := h.n.defaultSizeKnown
+	sharedSize := h.n.defaultSize
+	h.n.defaultSizeMu.RUnlock()
+
+	if localKnown && !sharedKnown {
+		// This may be the first handle to reach EOF. 
Update the shared data. + h.n.defaultSizeMu.Lock() + if !h.n.defaultSizeKnown { + h.n.defaultSizeKnown = true + h.n.defaultSize = localSize + sharedSize = localSize + } else { + sharedSize = h.n.defaultSize + } + h.n.defaultSizeMu.Unlock() + sharedKnown = true + } + if sharedKnown { + if localKnown && localSize != sharedSize { + log.Error.Printf( + "fsnodefuse.sizingHandle.Getattr: size-at-EOF mismatch: this handle: %d, earlier: %d", + localSize, sharedSize) + return syscall.EIO + } + info = info.WithSize(sharedSize) + } + setAttrFromFileInfo(&a.Attr, info) + return fs.OK +} + +func (h sizingHandle) Read(ctx context.Context, dst []byte, off int64) (_ fuse.ReadResult, errno syscall.Errno) { + defer handlePanicErrno(&errno) + ctx = ctxloadingcache.With(ctx, h.cache) + + n, err := h.r.ReadAt(ctx, dst, off) + if err == io.EOF { + err = nil + } + return fuse.ReadResultData(dst[:n]), errToErrno(err) +} + +func (h *sizingHandle) Release(ctx context.Context) (errno syscall.Errno) { + defer handlePanicErrno(&errno) + if h.f == nil { + return syscall.EBADF + } + ctx = ctxloadingcache.With(ctx, h.cache) + + err := h.f.Close(ctx) + h.f = nil + h.r = nil + h.cache = nil + return errToErrno(err) +} + +type fileReaderSpliceio struct{ spliceio.ReaderAt } + +func (r fileReaderSpliceio) Read( + ctx context.Context, + dest []byte, + off int64, +) (_ fuse.ReadResult, errno syscall.Errno) { + fd, fdSize, fdOff, err := r.SpliceReadAt(ctx, len(dest), off) + if err != nil { + return nil, errToErrno(err) + } + return fuse.ReadResultFd(fd, fdOff, fdSize), fs.OK +} + +type fileReaderIoctx struct{ ioctx.ReaderAt } + +func (r fileReaderIoctx) Read( + ctx context.Context, + dest []byte, + off int64, +) (_ fuse.ReadResult, errno syscall.Errno) { + n, err := r.ReadAt(ctx, dest, off) + if err == io.EOF { + err = nil + } + return fuse.ReadResultData(dest[:n]), errToErrno(err) +} + +type ( + // Writable is the interface that must be implemented by files returned by + // (fsnode.Leaf).OpenFile to 
support writing. + Writable interface { + WriteAt(ctx context.Context, p []byte, off int64) (n int, err error) + Truncate(ctx context.Context, n int64) error + // Flush is called on (FileFlusher).Flush, i.e. on the close(2) call on + // a file descriptor. Implementors can assume that no writes happen + // between Flush and (fsctx.File).Close. + Flush(ctx context.Context) error + Fsync(ctx context.Context) error + } + // handle is an implementation of fs.FileHandle that wraps an fsctx.File. + // The behavior of the handle depends on the functions implemented by the + // fsctx.File value. + handle struct { + f fsctx.File + r fs.FileReader + w Writable + cache *loadingcache.Map + } +) + +var ( + _ fs.FileFlusher = (*handle)(nil) + _ fs.FileFsyncer = (*handle)(nil) + _ fs.FileGetattrer = (*handle)(nil) + _ fs.FileLseeker = (*handle)(nil) + _ fs.FileReader = (*handle)(nil) + _ fs.FileReleaser = (*handle)(nil) + _ fs.FileSetattrer = (*handle)(nil) + _ fs.FileWriter = (*handle)(nil) +) + +func (h handle) Getattr(ctx context.Context, out *fuse.AttrOut) (errno syscall.Errno) { + defer handlePanicErrno(&errno) + if h.f == nil { + return syscall.EBADF + } + ctx = ctxloadingcache.With(ctx, h.cache) + info, err := h.f.Stat(ctx) + if err != nil { + return errToErrno(err) + } + if statT := fuse.ToStatT(info); statT != nil { + // Stat returned a *syscall.Stat_t, so just plumb that through. 
+ out.FromStat(statT) + } else { + setAttrFromFileInfo(&out.Attr, info) + } + out.SetTimeout(getCacheTimeout(h.f)) + return fs.OK +} + +func (h handle) Setattr( + ctx context.Context, + in *fuse.SetAttrIn, + out *fuse.AttrOut, +) (errno syscall.Errno) { + defer handlePanicErrno(&errno) + if h.f == nil { + return syscall.EBADF + } + if h.w == nil { + return syscall.ENOSYS + } + h.cache.DeleteAll() + if usize, ok := in.GetSize(); ok { + return errToErrno(h.w.Truncate(ctx, int64(usize))) + } + return fs.OK +} + +func (h handle) Read( + ctx context.Context, + dst []byte, + off int64, +) (_ fuse.ReadResult, errno syscall.Errno) { + defer handlePanicErrno(&errno) + if h.f == nil { + return nil, syscall.EBADF + } + if h.r == nil { + return nil, syscall.ENOSYS + } + ctx = ctxloadingcache.With(ctx, h.cache) + return h.r.Read(ctx, dst, off) +} + +func (h handle) Lseek( + ctx context.Context, + off uint64, + whence uint32, +) (_ uint64, errno syscall.Errno) { + defer handlePanicErrno(&errno) + if h.f == nil { + return 0, syscall.EBADF + } + // We expect this to only be called with {SEEK_DATA,SEEK_HOLE}. + // https://github.com/torvalds/linux/blob/v5.13/fs/fuse/file.c#L2619-L2648 + const ( + // Copied from https://github.com/torvalds/linux/blob/v5.13/include/uapi/linux/fs.h#L46-L47 + SEEK_DATA = 3 + SEEK_HOLE = 4 + ) + switch whence { + case SEEK_DATA: + return off, fs.OK // We don't support holes so current offset is correct. 
+ case SEEK_HOLE: + info, err := h.f.Stat(ctx) + if err != nil { + return 0, errToErrno(err) + } + return uint64(info.Size()), fs.OK + } + return 0, syscall.ENOTSUP +} + +func (h handle) Write( + ctx context.Context, + p []byte, + off int64, +) (_ uint32, errno syscall.Errno) { + defer handlePanicErrno(&errno) + if h.f == nil { + return 0, syscall.EBADF + } + if h.w == nil { + return 0, syscall.ENOSYS + } + ctx = ctxloadingcache.With(ctx, h.cache) + n, err := h.w.WriteAt(ctx, p, off) + return uint32(n), errToErrno(err) +} + +func (h handle) Flush(ctx context.Context) (errno syscall.Errno) { + defer handlePanicErrno(&errno) + if h.f == nil { + return syscall.EBADF + } + if h.w == nil { + return fs.OK + } + ctx = ctxloadingcache.With(ctx, h.cache) + err := h.w.Flush(ctx) + return errToErrno(err) +} + +func (h handle) Fsync(ctx context.Context, flags uint32) (errno syscall.Errno) { + defer handlePanicErrno(&errno) + if h.f == nil { + return syscall.EBADF + } + if h.w == nil { + return fs.OK + } + ctx = ctxloadingcache.With(ctx, h.cache) + err := h.w.Fsync(ctx) + return errToErrno(err) +} + +func (h *handle) Release(ctx context.Context) (errno syscall.Errno) { + defer handlePanicErrno(&errno) + if h.f == nil { + return syscall.EBADF + } + ctx = ctxloadingcache.With(ctx, h.cache) + err := h.f.Close(ctx) + h.f = nil + h.r = nil + h.w = nil + h.cache = nil + return errToErrno(err) +} diff --git a/file/fsnodefuse/inode.go b/file/fsnodefuse/inode.go new file mode 100644 index 00000000..bbf521e4 --- /dev/null +++ b/file/fsnodefuse/inode.go @@ -0,0 +1,26 @@ +package fsnodefuse + +import ( + "github.com/grailbio/base/file/fsnode" + "github.com/grailbio/base/log" + "github.com/hanwen/go-fuse/v2/fs" +) + +// setFSNode updates inode to be backed by fsNode. 
The caller must ensure that +// inode and fsNode are compatible: +// *dirInode <-> fsnode.Parent +// *regInode <-> fsnode.Leaf +func setFSNode(inode *fs.Inode, fsNode fsnode.T) { + switch embed := inode.Operations().(type) { + case *dirInode: + embed.mu.Lock() + embed.n = fsNode.(fsnode.Parent) + embed.mu.Unlock() + case *regInode: + embed.mu.Lock() + embed.n = fsNode.(fsnode.Leaf) + embed.mu.Unlock() + default: + log.Panicf("unexpected inodeEmbedder: %T", embed) + } +} diff --git a/file/fsnodefuse/readdirplus_test.go b/file/fsnodefuse/readdirplus_test.go new file mode 100644 index 00000000..f7489d56 --- /dev/null +++ b/file/fsnodefuse/readdirplus_test.go @@ -0,0 +1,148 @@ +package fsnodefuse + +import ( + "context" + "fmt" + "math/rand" + "os" + "sync/atomic" + "testing" + "time" + + "github.com/grailbio/base/file/fsnode" + "github.com/grailbio/testutil" + "github.com/hanwen/go-fuse/v2/fs" + "github.com/hanwen/go-fuse/v2/fuse" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" +) + +// TestReaddirplus verifies that servicing a READDIRPLUS request does not +// trigger calls to (fsnode.Parent).Child. Note that this test uses +// (*os.File).Readdirnames to trigger the READDIRPLUS request. +func TestReaddirplus(t *testing.T) { + const NumChildren = 1000 + children := makeTestChildren(NumChildren) + root := newParent("root", children) + withMounted(t, root, func(mountDir string) { + err := checkDir(t, children, mountDir) + require.NoError(t, err) + assert.Equal(t, int64(0), root.childCalls) + }) +} + +// TestReaddirplusConcurrent verifies that servicing many concurrent +// READDIRPLUS requests does not trigger any calls to (fsnode.Parent).Child. +// Note that this test uses (*os.File).Readdirnames to trigger the READDIRPLUS +// requests. 
+func TestReaddirplusConcurrent(t *testing.T) { + const ( + NumIter = 20 + MaxNumChildren = 1000 + MaxConcurrentReaddirs = 100 + ) + // Note that specifying a constant seed does not make this test + // deterministic, as the concurrent READDIRPLUS callers race + // non-deterministically. + r := rand.New(rand.NewSource(0)) + for i := 0; i < NumIter; i++ { + var ( + numChildren = r.Intn(MaxNumChildren-1) + 1 + concurrentReaddirs = r.Intn(MaxConcurrentReaddirs-2) + 2 + children = makeTestChildren(numChildren) + root = newParent("root", children) + ) + t.Run(fmt.Sprintf("iter%02d", i), func(t *testing.T) { + t.Logf( + "numChildren=%d concurrentReaddirs=%d", + numChildren, + concurrentReaddirs, + ) + withMounted(t, root, func(mountDir string) { + var grp errgroup.Group + for j := 0; j < concurrentReaddirs; j++ { + grp.Go(func() error { + return checkDir(t, children, mountDir) + }) + } + require.NoError(t, grp.Wait()) + assert.Equal(t, int64(0), root.childCalls) + }) + }) + } +} + +func makeTestChildren(n int) []fsnode.T { + children := make([]fsnode.T, n) + for i := range children { + children[i] = fsnode.ConstLeaf( + fsnode.NewRegInfo(fmt.Sprintf("%04d", i)), + []byte{}, + ) + } + return children +} + +// withMounted sets up and tears down a FUSE mount for root. +// f is called with the path where root is mounted. 
+func withMounted(t *testing.T, root fsnode.T, f func(rootPath string)) { + mountDir, cleanUp := testutil.TempDir(t, "", "fsnodefuse-testreaddirplus") + defer cleanUp() + server, err := fs.Mount(mountDir, NewRoot(root), &fs.Options{ + MountOptions: fuse.MountOptions{ + FsName: "test", + DisableXAttrs: true, + }, + }) + require.NoError(t, err, "mounting %q", mountDir) + defer func() { + assert.NoError(t, server.Unmount(), + "unmount of FUSE mounted at %q failed; may need manual cleanup", + mountDir, + ) + }() + f(mountDir) +} + +func checkDir(t *testing.T, children []fsnode.T, path string) (err error) { + var want []string + for _, c := range children { + want = append(want, c.Info().Name()) + } + f, err := os.Open(path) + if err != nil { + return err + } + defer func() { assert.NoError(t, f.Close()) }() + // Use Readdirnames instead of Readdir because Readdir adds an extra call + // lstat outside of the READDIRPLUS operation. + got, err := f.Readdirnames(0) + // Sanity check that the names of the entries match the children. + assert.ElementsMatch(t, want, got) + return err +} + +type parent struct { + fsnode.Parent + childCalls int64 +} + +func (p *parent) Child(ctx context.Context, name string) (fsnode.T, error) { + atomic.AddInt64(&p.childCalls, 1) + return p.Parent.Child(ctx, name) +} + +// CacheableFor implements fsnode.Cacheable. 
+func (p parent) CacheableFor() time.Duration { + return fsnode.CacheableFor(p.Parent) +} + +func newParent(name string, children []fsnode.T) *parent { + return &parent{ + Parent: fsnode.NewParent( + fsnode.NewDirInfo("root"), + fsnode.ConstChildren(children...), + ), + } +} diff --git a/file/fsnodefuse/reg.go b/file/fsnodefuse/reg.go new file mode 100644 index 00000000..c2bdd227 --- /dev/null +++ b/file/fsnodefuse/reg.go @@ -0,0 +1,144 @@ +package fsnodefuse + +import ( + "context" + "io" + "os" + "sync" + "syscall" + + "github.com/grailbio/base/errors" + "github.com/grailbio/base/file/fsnode" + "github.com/grailbio/base/ioctx" + "github.com/grailbio/base/log" + "github.com/grailbio/base/sync/loadingcache" + "github.com/grailbio/base/sync/loadingcache/ctxloadingcache" + "github.com/hanwen/go-fuse/v2/fs" + "github.com/hanwen/go-fuse/v2/fuse" +) + +// TODO: Fix BXDS-1029. Without this, readers of non-constant files may see staleness and +// concurrent readers of such files may see corruption. +type regInode struct { + fs.Inode + cache loadingcache.Map + + mu sync.Mutex + n fsnode.Leaf + + // defaultSize is a shared record of file size for all file handles of type sizingHandle + // created for this inode. The first sizingHandle to reach EOF (for its io.Reader) sets + // defaultSizeKnown and defaultSize and after that all other handles will return the same size + // from Getattr calls. + // + // sizingHandle returns incorrect size information until the underlying Reader reaches EOF. The + // kernel issues concurrent reads to prepopulate the page cache, for performance, and also + // interleaves Getattr calls to confirm where EOF really is. 
Complicating matters, multiple open + // handles share the page cache, allowing situations where one handle has populated the page + // cache, reached EOF, and knows the right size, whereas another handle's Reader is not there + // yet so it continues to use the fake size (which we may choose to be some giant number so + // users keep going until the end). This seems to cause bugs where user programs think they got + // real data past EOF (which is probably just padded/zeros). + // + // To avoid this problem, all open sizingHandles share a size value, after first EOF. + // TODO: Document more loudly the requirement that fsnode.Leaf.Open's files must return + // identical data (same size, same bytes) to avoid corrupt page cache interactions. + // + // TODO: Investigate more thoroughly or at least with a newer kernel (this was observed on + // 4.15.0-1099-aws). + defaultSizeMu sync.RWMutex + defaultSizeKnown bool + defaultSize int64 +} + +var ( + _ fs.InodeEmbedder = (*regInode)(nil) + + _ fs.NodeReadlinker = (*regInode)(nil) + _ fs.NodeOpener = (*regInode)(nil) + _ fs.NodeGetattrer = (*regInode)(nil) + _ fs.NodeSetattrer = (*regInode)(nil) +) + +func (n *regInode) Open(ctx context.Context, inFlags uint32) (_ fs.FileHandle, outFlags uint32, errno syscall.Errno) { + defer handlePanicErrno(&errno) + ctx = ctxloadingcache.With(ctx, &n.cache) + file, err := n.n.OpenFile(ctx, int(inFlags)) + if err != nil { + return nil, 0, errToErrno(err) + } + h, err := makeHandle(n, inFlags, file) + return h, 0, errToErrno(err) +} + +func (n *regInode) Readlink(ctx context.Context) (_ []byte, errno syscall.Errno) { + defer handlePanicErrno(&errno) + ctx = ctxloadingcache.With(ctx, &n.cache) + file, err := n.n.OpenFile(ctx, 0) + if err != nil { + return nil, errToErrno(err) + } + defer func() { + if errClose := file.Close(ctx); errClose != nil && errno == fs.OK { + errno = errToErrno(errClose) + } + }() + content, err := io.ReadAll(ioctx.ToStdReader(ctx, file)) + if err != nil { + 
return nil, errToErrno(err) + } + return content, fs.OK +} + +func (n *regInode) Getattr(ctx context.Context, h fs.FileHandle, a *fuse.AttrOut) (errno syscall.Errno) { + defer handlePanicErrno(&errno) + ctx = ctxloadingcache.With(ctx, &n.cache) + + if h != nil { + if hg, ok := h.(fs.FileGetattrer); ok { + return hg.Getattr(ctx, a) + } + } + + setAttrFromFileInfo(&a.Attr, n.n.Info()) + a.SetTimeout(getCacheTimeout(n.n)) + return fs.OK +} + +func (n *regInode) Setattr(ctx context.Context, h fs.FileHandle, in *fuse.SetAttrIn, a *fuse.AttrOut) (errno syscall.Errno) { + defer handlePanicErrno(&errno) + if h, ok := h.(fs.FileSetattrer); ok { + return h.Setattr(ctx, in, a) + } + if usize, ok := in.GetSize(); ok { + if usize != 0 { + // We only support setting the size to 0. + return syscall.ENOTSUP + } + err := func() (err error) { + f, err := n.n.OpenFile(ctx, os.O_WRONLY|os.O_TRUNC) + if err != nil { + return errToErrno(err) + } + defer errors.CleanUpCtx(ctx, f.Close, &err) + w, ok := f.(Writable) + if !ok { + return syscall.ENOTSUP + } + return w.Flush(ctx) + }() + if err != nil { + return errToErrno(err) + } + } + n.cache.DeleteAll() + if errno := n.NotifyContent(0 /* offset */, 0 /* len, zero means all */); errno != fs.OK { + log.Error.Printf("regInode.Setattr %s: error from NotifyContent: %v", n.Path(nil), errno) + return errToErrno(errno) + } + // TODO(josh): Is this the right invalidation, and does it work? Maybe page cache only matters + // if we set some other flags in open or read to enable it? 
+ setAttrFromFileInfo(&a.Attr, n.n.Info()) + a.SetTimeout(getCacheTimeout(n.n)) + return fs.OK +} diff --git a/file/fsnodefuse/trailingbuf/trailbuf_test.go b/file/fsnodefuse/trailingbuf/trailbuf_test.go new file mode 100644 index 00000000..2e00288a --- /dev/null +++ b/file/fsnodefuse/trailingbuf/trailbuf_test.go @@ -0,0 +1,71 @@ +package trailingbuf + +import ( + "context" + "io" + "strings" + "testing" + + "github.com/grailbio/base/ioctx" + "github.com/stretchr/testify/require" +) + +func TestBasic(t *testing.T) { + const src = "0123456789" + r := New(ioctx.FromStdReader(strings.NewReader(src)), 0, 2) + ctx := context.Background() + + // Initial read, larger than trail buf. + b := make([]byte, 3) + n, err := r.ReadAt(ctx, b, 0) + require.NoError(t, err) + require.Equal(t, len(b), n) + require.Equal(t, "012", string(b)) + + // Backwards by a little. + b = make([]byte, 2) + n, err = r.ReadAt(ctx, b, 1) + require.NoError(t, err) + require.Equal(t, len(b), n) + require.Equal(t, "12", string(b)) + + // Backwards by too much. + b = make([]byte, 1) + _, err = r.ReadAt(ctx, b, 0) + require.Error(t, err) + require.ErrorIs(t, err, ErrTooFarBehind) + + // Jump forward. Discards (not visible; exercising internal paths) because we skip off=3. + b = make([]byte, 2) + n, err = r.ReadAt(ctx, b, 4) + require.NoError(t, err) + require.Equal(t, len(b), n) + require.Equal(t, "45", string(b)) + + // Forwards, overlapping. + b = make([]byte, 4) + n, err = r.ReadAt(ctx, b, 4) + require.NoError(t, err) + require.Equal(t, len(b), n) + require.Equal(t, "4567", string(b)) + + // Jump again. + b = make([]byte, 5) + n, err = r.ReadAt(ctx, b, 9) + if err != io.EOF { + require.NoError(t, err) + } + require.Equal(t, 1, n) + require.Equal(t, "9", string(b[:1])) + + // Make sure we can still read backwards after EOF. 
+ b = make([]byte, 1) + n, err = r.ReadAt(ctx, b, 8) + if err != io.EOF { + require.NoError(t, err) + } + require.Equal(t, len(b), n) + require.Equal(t, "8", string(b)) +} + +// TODO: Randomized, concurrent tests. diff --git a/file/fsnodefuse/trailingbuf/trailingbuf.go b/file/fsnodefuse/trailingbuf/trailingbuf.go new file mode 100644 index 00000000..13c558ad --- /dev/null +++ b/file/fsnodefuse/trailingbuf/trailingbuf.go @@ -0,0 +1,141 @@ +package trailingbuf + +import ( + "context" + "fmt" + "io" + + "github.com/grailbio/base/errors" + "github.com/grailbio/base/file/internal/s3bufpool" + "github.com/grailbio/base/ioctx" + "github.com/grailbio/base/morebufio" + "github.com/grailbio/base/must" +) + +// ErrTooFarBehind is returned if a read goes too far behind the current position. +// It's set as a cause (which callers can unwrap) on some errors returned by ReadAt. +var ErrTooFarBehind = errors.New("trailbuf: read too far behind") + +type ReaderAt struct { + // semaphore guards all subsequent fields. It's used to serialize operations. + semaphore chan struct{} + // r is the data source. + pr morebufio.PeekBackReader + // off is the number of bytes we've read from r. + off int64 + // eof is true after r returns io.EOF. + eof bool +} + +// New creates a ReaderAt that can respond to arbitrary reads as long as they're close +// to the current position. trailSize controls the max distance (controlling buffer space usage). +// Reads too far behind the current position return an error with cause ErrTooFarBehind. +// off is the current position of r (for example, zero for the start of a file, or non-zero for +// reading somewhere in the middle). +// Note: Alternatively, callers could manipulate the offsets in their ReadAt calls to be relative +// to r's initial position. However, since we put offsets in our error message strings, users may +// find debugging easier if they don't need to de-relativize the errors. 
+func New(r ioctx.Reader, off int64, trailSize int) *ReaderAt { + return &ReaderAt{ + semaphore: make(chan struct{}, 1), + pr: morebufio.NewPeekBackReader(r, trailSize), + off: off, + } +} + +// ReadAt implements io.ReaderAt. +func (r *ReaderAt) ReadAt(ctx context.Context, dst []byte, off int64) (int, error) { + if len(dst) == 0 { + return 0, nil + } + if off < 0 { + return 0, errors.E(errors.Invalid, "trailbuf: negative offset") + } + + select { + case r.semaphore <- struct{}{}: + defer func() { <-r.semaphore }() + case <-ctx.Done(): + return 0, ctx.Err() + } + + var nDst int + // Try to peek backwards from r.off, if requested. + if back := r.off - off; back > 0 { + peekBack := r.pr.PeekBack() + if back > int64(len(peekBack)) { + return nDst, errors.E(errors.Invalid, ErrTooFarBehind, + fmt.Sprintf("trailbuf: read would seek backwards: request %d(%d), current pos %d(-%d)", + off, len(dst), r.off, len(peekBack))) + } + peekUsed := copy(dst, peekBack[len(peekBack)-int(back):]) + dst = dst[peekUsed:] + nDst += int(peekUsed) + off += int64(peekUsed) + } + // If we're already at EOF (so there's not enough data to reach off), or len(dst) + // is small enough (off + len(dst) < r.off), we exit early. + // Otherwise, we've advanced the request offset up to the current cursor and need to + // read more of the underlying stream. + if r.eof { + return nDst, io.EOF + } + if len(dst) == 0 { + return nDst, nil + } + must.Truef(off >= r.off, "%d, %d", off, r.off) + + // Skip forward in r.pr, if necessary. + if skip := off - r.off; skip > 0 { + // Copying to io.Discard ends up using small chunks from an internal pool. This is a fairly + // pessimal S3 read size, so since we sometimes read from S3 streams here, we use larger + // buffers. + // + // Note that we may eventually want to use some internal read buffer for all S3 reads, so + // clients don't accidentally experience bad performance because their application happens + // to use a pattern of small reads. 
In that case, this special skip buffer would just add + // copies, and not help, and we may want to remove it. + discardBuf := s3bufpool.Get() + n, err := io.CopyBuffer( + // Hide io.Discard's io.ReadFrom implementation because CopyBuffer detects that and + // ignores our buffer. + struct{ io.Writer }{io.Discard}, + io.LimitReader(ioctx.ToStdReader(ctx, r.pr), skip), + *discardBuf) + s3bufpool.Put(discardBuf) + r.off += n + if n < skip { + r.eof = true + err = io.EOF + } + if err != nil { + return nDst, err + } + } + + // Complete the read. + n, err := io.ReadFull(ioctx.ToStdReader(ctx, r.pr), dst) + r.off += int64(n) + nDst += n + if err == io.EOF || err == io.ErrUnexpectedEOF { + err = io.EOF + r.eof = true + } + return nDst, err +} + +// Size returns the final number of bytes obtained from the underlying stream, if we've already +// found EOF, else _, false. +func (r *ReaderAt) Size(ctx context.Context) (size int64, known bool, err error) { + select { + case r.semaphore <- struct{}{}: + defer func() { <-r.semaphore }() + case <-ctx.Done(): + return 0, false, ctx.Err() + } + + if r.eof { + return r.off, true, nil + } + return 0, false, nil +} diff --git a/file/gfilefs/gfile.go b/file/gfilefs/gfile.go new file mode 100644 index 00000000..6f8d32ab --- /dev/null +++ b/file/gfilefs/gfile.go @@ -0,0 +1,527 @@ +// Copyright 2022 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache-2.0 +// license that can be found in the LICENSE file. 
+ +package gfilefs + +import ( + "bytes" + "context" + "fmt" + "io" + "io/ioutil" + "os" + "time" + + "github.com/grailbio/base/errors" + "github.com/grailbio/base/file" + "github.com/grailbio/base/file/fsnodefuse" + "github.com/grailbio/base/file/internal/readmatcher" + "github.com/grailbio/base/ioctx" + "github.com/grailbio/base/ioctx/fsctx" + "github.com/grailbio/base/sync/ctxsync" + "github.com/hanwen/go-fuse/v2/fuse" +) + +// gfile implements fsctx.File and fsnodefuse.Writable to represent open +// gfilefs files. +type gfile struct { + // n is the node for which this instance is an open file. + n *fileNode + // flag holds the flag bits specified when this file was opened. + flag int + + // readerAt is an optional ReaderAt implementation. It may only be set upon + // construction, and must not be modified later. Thus, it can be read by + // multiple goroutines without holding the lock, without a data race. + // When non-nil, it serves ReadAt requests concurrently, without ops. + // Otherwise, gfile.ReadAt uses ops.ReadAt. + readerAt ioctx.ReaderAt + + // mu provides mutually exclusive access to the fields below. + mu ctxsync.Mutex + // requestedSize is the size requested by Truncate. Note that we only + // really support truncation to 0 well, as it is mainly used by go-fuse for + // truncation when handling O_TRUNC. + requestedSize int64 + // flushed is true if there are no writes that need to be flushed. If + // flushed == true, Flush is a no-op. + flushed bool + // anyWritten tracks whether we have written any bytes to this file. We + // use this to decide whether we can use direct writing. + anyWritten bool + // ops handles underlying I/O operations. See ioOps. ops may be lazily + // populated, and it may be reassigned over the lifetime of the file, e.g. + // after truncation, we may switch to an ops that no longer uses a + // temporary file. 
+ ops ioOps +} + +var ( + _ fsctx.File = (*gfile)(nil) + _ ioctx.ReaderAt = (*gfile)(nil) + _ fsnodefuse.Writable = (*gfile)(nil) +) + +// OpenFile opens the file at n and returns a *gfile representing it for file +// operations. +func OpenFile(ctx context.Context, n *fileNode, flag int) (*gfile, error) { + gf := &gfile{ + n: n, + flag: flag, + requestedSize: -1, + // Creation and truncation require flushing. + flushed: (flag&os.O_CREATE) == 0 && (flag&os.O_TRUNC) == 0, + } + if (flag & int(fuse.O_ANYWRITE)) == 0 { + // Read-only files are initialized eagerly, as it is cheap, and we can + // immediately return any errors. Writable files are initialized + // lazily; see lockedInitOps. + f, err := file.Open(ctx, n.path) + if err != nil { + return nil, err + } + dr := directRead{ + f: f, + matcher: readmatcher.New(f.OffsetReader), + r: f.Reader(context.Background()), // TODO: Tie to gf lifetime? + } + gf.ops = &dr + gf.readerAt = dr.matcher + return gf, nil + } + return gf, nil +} + +// Stat implements fsctx.File. +func (gf *gfile) Stat(ctx context.Context) (os.FileInfo, error) { + if err := gf.mu.Lock(ctx); err != nil { + return nil, err + } + defer gf.mu.Unlock() + if err := gf.lockedInitOps(ctx); err != nil { + return nil, err + } + info, err := gf.ops.Stat(ctx) + if err != nil { + if errors.Recover(err).Kind == errors.NotSupported { + return gf.n.Info(), nil + } + return nil, errors.E(err, "getting stat info from underlying I/O") + } + newInfo := gf.n.fsnodeInfo(). + WithModTime(info.ModTime()). + WithSize(info.Size()) + gf.n.setFsnodeInfo(newInfo) + return newInfo, nil +} + +// Read implements fsctx.File. +func (gf *gfile) Read(ctx context.Context, p []byte) (int, error) { + if err := gf.mu.Lock(ctx); err != nil { + return 0, err + } + defer gf.mu.Unlock() + if err := gf.lockedInitOps(ctx); err != nil { + return 0, err + } + return gf.ops.Read(ctx, p) +} + +// ReadAt implements ioctx.ReaderAt. 
+func (gf *gfile) ReadAt(ctx context.Context, p []byte, off int64) (int, error) {
+	// readerAt is immutable after construction, so this check is safe without
+	// holding mu, and the readerAt path serves concurrent reads lock-free.
+	if gf.readerAt != nil {
+		return gf.readerAt.ReadAt(ctx, p, off)
+	}
+	if err := gf.mu.Lock(ctx); err != nil {
+		return 0, err
+	}
+	defer gf.mu.Unlock()
+	if err := gf.lockedInitOps(ctx); err != nil {
+		return 0, err
+	}
+	return gf.ops.ReadAt(ctx, p, off)
+}
+
+// WriteAt implements fsnodefuse.Writable.
+func (gf *gfile) WriteAt(ctx context.Context, p []byte, off int64) (int, error) {
+	if err := gf.mu.Lock(ctx); err != nil {
+		return 0, err
+	}
+	defer gf.mu.Unlock()
+	if err := gf.lockedInitOps(ctx); err != nil {
+		return 0, err
+	}
+	n, err := gf.ops.WriteAt(ctx, p, off)
+	if err != nil {
+		return n, err
+	}
+	gf.anyWritten = true
+	gf.flushed = false
+	return n, err
+}
+
+// Truncate implements fsnodefuse.Writable.
+func (gf *gfile) Truncate(ctx context.Context, size int64) error {
+	if err := gf.mu.Lock(ctx); err != nil {
+		return err
+	}
+	defer gf.mu.Unlock()
+	gf.flushed = false
+	if gf.ops == nil {
+		// Record the size actually requested. lockedInitOps only honors a
+		// truncate-to-zero request (see requestedSize doc); recording the real
+		// value avoids silently turning a non-zero truncate into truncation to
+		// zero.
+		gf.requestedSize = size
+		return nil
+	}
+	return gf.ops.Truncate(ctx, size)
+}
+
+// Flush implements fsnodefuse.Writable.
+func (gf *gfile) Flush(ctx context.Context) error {
+	if err := gf.mu.Lock(ctx); err != nil {
+		return err
+	}
+	defer gf.mu.Unlock()
+	return gf.lockedFlush()
+}
+
+// Fsync implements fsnodefuse.Writable.
+func (gf *gfile) Fsync(ctx context.Context) error {
+	// We treat Fsync as Flush, mostly because leaving it unimplemented
+	// (ENOSYS) breaks too many applications.
+	return gf.Flush(ctx)
+}
+
+// Close implements fsctx.File.
+func (gf *gfile) Close(ctx context.Context) error {
+	if err := gf.mu.Lock(ctx); err != nil {
+		return err
+	}
+	defer gf.mu.Unlock()
+	if gf.ops == nil {
+		return nil
+	}
+	return gf.ops.Close(ctx)
+}
+
+// lockedInitOps initializes the ops that handle the underlying I/O operations
+// of gf. This is done lazily in some cases, as it may be expensive, e.g.
+// downloading a remotely stored file locally.
Initialization may also depend +// on other operations, e.g. if the first manipulation is truncation, then we +// won't download existing data. gf.ops is non-nil iff lockedInitOps returns a +// nil error. The caller must have gf.mu locked. +func (gf *gfile) lockedInitOps(ctx context.Context) (err error) { + if gf.ops != nil { + return nil + } + // base/file does not expose an API to open a file for writing without + // creating it, so writing implies creation. + const tmpPattern = "gfilefs-" + var ( + rdwr = (gf.flag & os.O_RDWR) == os.O_RDWR + // Treat O_EXCL as O_TRUNC, as the file package does not support + // O_EXCL. + trunc = !gf.anyWritten && + (gf.requestedSize == 0 || + (gf.flag&os.O_TRUNC) == os.O_TRUNC || + (gf.flag&os.O_EXCL) == os.O_EXCL) + ) + switch { + case trunc && rdwr: + tmp, err := ioutil.TempFile("", tmpPattern) + if err != nil { + return errors.E(err, "making temp file") + } + gf.ops = &tmpIO{n: gf.n, f: tmp} + return nil + case trunc: + f, err := file.Create(ctx, gf.n.path) + if err != nil { + return errors.E(err, fmt.Sprintf("creating file at %q", gf.n.path)) + } + // This is a workaround for the fact that directWrite ops do not + // support Stat (as write-only s3files do not support Stat). Callers, + // e.g. fsnodefuse, may fall back to use the node's information, so we + // zero that to keep a sensible view. + gf.n.setFsnodeInfo(gf.n.fsnodeInfo().WithSize(0)) + gf.ops = &directWrite{ + n: gf.n, + f: f, + w: f.Writer(context.Background()), // TODO: Tie to gf lifetime? + off: 0, + } + return nil + default: + // existing reads out existing file contents. Contents may be empty if + // no file exists yet. + var existing io.Reader + f, err := file.Open(ctx, gf.n.path) + if err == nil { + existing = f.Reader(ctx) + } else { + if errors.Is(errors.NotExist, err) { + if !rdwr { + // Write-only and no existing file, so we can use direct + // I/O. 
+ f, err = file.Create(ctx, gf.n.path) + if err != nil { + return errors.E(err, fmt.Sprintf("creating file at %q", gf.n.path)) + } + gf.ops = &directWrite{ + n: gf.n, + f: f, + w: f.Writer(context.Background()), // TODO: Tie to gf lifetime? + off: 0, + } + return nil + } + // No existing file, so there are no existing contents. + err = nil + existing = &bytes.Buffer{} + } else { + return errors.E(err, fmt.Sprintf("opening file for %q", gf.n.path)) + } + } + tmp, err := ioutil.TempFile("", tmpPattern) + if err != nil { + // fp was opened for reading, so don't worry about the error on + // Close. + _ = f.Close(ctx) + return errors.E(err, "making temp file") + } + _, err = io.Copy(tmp, existing) + if err != nil { + // We're going to report the copy error, so we treat closing as + // best-effort. + _ = f.Close(ctx) + _ = tmp.Close() + return errors.E(err, fmt.Sprintf("copying current contents to temp file %q", tmp.Name())) + } + gf.ops = &tmpIO{n: gf.n, f: tmp} + return nil + } +} + +// lockedFlush flushes writes to the backing write I/O state. The caller must +// have gf.mu locked. +func (gf *gfile) lockedFlush() (err error) { + // We use a background context when flushing as a workaround for handling + // interrupted operations, particularly from Go clients. As of Go 1.14, + // slow system calls may see more EINTR errors[1]. While most file + // operations are automatically retried[2], closing (which results in + // flushing) is not[3]. Ultimately, clients may see spurious, confusing + // failures calling (*os.File).Close. Given that it is extremely uncommon + // for callers to retry, we ignore interrupts to avoid the confusion. The + // significant downside is that intentional interruption, e.g. CTRL-C on a + // program that is taking too long, is also ignored, so processes can + // appear hung. + // + // TODO: Consider a better way of handling this problem. 
+ // + // [1] https://go.dev/doc/go1.14#runtime + // [2] https://github.com/golang/go/commit/6b420169d798c7ebe733487b56ea5c3fa4aab5ce + // [3] https://github.com/golang/go/blob/go1.17.8/src/internal/poll/fd_unix.go#L79-L83 + ctx := context.Background() + if gf.flushed { + return nil + } + defer func() { + if err == nil { + gf.flushed = true + } + }() + if (gf.flag & int(fuse.O_ANYWRITE)) != 0 { + if err = gf.lockedInitOps(ctx); err != nil { + return err + } + } + reuseOps, err := gf.ops.Flush(ctx) + if err != nil { + return err + } + if !reuseOps { + gf.ops = nil + } + return nil +} + +// ioOps handles the underlying I/O operations for a *gfile. Implementations +// may directly call base/file or use a temporary file on local disk until +// flush. +type ioOps interface { + Stat(ctx context.Context) (file.Info, error) + Read(ctx context.Context, p []byte) (int, error) + ReadAt(ctx context.Context, p []byte, off int64) (int, error) + WriteAt(ctx context.Context, p []byte, off int64) (int, error) + Truncate(ctx context.Context, size int64) error + Flush(ctx context.Context) (reuseOps bool, _ error) + Close(ctx context.Context) error +} + +// directRead implements ioOps. It reads directly using base/file and does not +// support writes, e.g. to handle O_RDONLY. 
+type directRead struct { + f file.File + matcher interface { + ioctx.ReaderAt + ioctx.Closer + } + r io.ReadSeeker +} + +var _ ioOps = (*directRead)(nil) + +func (ops *directRead) Stat(ctx context.Context) (file.Info, error) { + return ops.f.Stat(ctx) +} + +func (ops *directRead) Read(ctx context.Context, p []byte) (int, error) { + return ops.r.Read(p) +} + +func (ops *directRead) ReadAt(ctx context.Context, p []byte, off int64) (_ int, err error) { + return ops.matcher.ReadAt(ctx, p, off) +} + +func (*directRead) WriteAt(ctx context.Context, p []byte, off int64) (int, error) { + return 0, errors.E(errors.Invalid, "writing read-only file") +} + +func (*directRead) Truncate(ctx context.Context, size int64) error { + return errors.E(errors.Invalid, "cannot truncate read-only file") +} + +func (*directRead) Flush(ctx context.Context) (reuseOps bool, _ error) { + return true, nil +} + +func (ops *directRead) Close(ctx context.Context) error { + err := ops.matcher.Close(ctx) + errors.CleanUpCtx(ctx, ops.f.Close, &err) + return err +} + +// directWrite implements ioOps. It writes directly using base/file and does +// not support reads, e.g. to handle O_WRONLY|O_TRUNC. 
+type directWrite struct {
+	// n is the node whose cached fsnode info is refreshed on Flush.
+	n *fileNode
+	f file.File
+	w io.Writer
+	// off is the write cursor; only contiguous writes at off are supported.
+	off int64
+}
+
+var _ ioOps = (*directWrite)(nil)
+
+func (ops directWrite) Stat(ctx context.Context) (file.Info, error) {
+	return ops.f.Stat(ctx)
+}
+
+func (directWrite) Read(ctx context.Context, p []byte) (int, error) {
+	return 0, errors.E(errors.Invalid, "reading write-only file")
+}
+
+func (directWrite) ReadAt(ctx context.Context, p []byte, off int64) (int, error) {
+	return 0, errors.E(errors.Invalid, "reading write-only file")
+}
+
+func (ops *directWrite) WriteAt(ctx context.Context, p []byte, off int64) (int, error) {
+	if off != ops.off {
+		return 0, errors.E(errors.NotSupported, "non-contiguous write")
+	}
+	n, err := ops.w.Write(p)
+	ops.off += int64(n)
+	return n, err
+}
+
+func (ops directWrite) Truncate(ctx context.Context, size int64) error {
+	if ops.off != size {
+		// Format the requested size into the message; the bare format string
+		// was previously returned verbatim (with a literal "%d").
+		return errors.E(errors.NotSupported,
+			fmt.Sprintf("truncating to %d not supported by direct I/O", size))
+	}
+	return nil
+}
+
+func (ops *directWrite) Flush(ctx context.Context) (reuseOps bool, _ error) {
+	err := ops.f.Close(ctx)
+	ops.n.setFsnodeInfo(
+		ops.n.fsnodeInfo().
+			WithModTime(time.Now()).
+			WithSize(ops.off),
+	)
+	// Clear to catch accidental reuse.
+	*ops = directWrite{}
+	return false, err
+}
+
+func (ops directWrite) Close(ctx context.Context) error {
+	return ops.f.Close(ctx)
+}
+
+// tmpIO implements ioOps. It is backed by a temporary local file, e.g. to
+// handle O_RDWR.
+type tmpIO struct {
+	n *fileNode
+	f *os.File // refers to a file in -tmp-dir.
+} + +var _ ioOps = (*tmpIO)(nil) + +func (ops tmpIO) Stat(_ context.Context) (file.Info, error) { + return ops.f.Stat() +} + +func (ops tmpIO) Read(_ context.Context, p []byte) (int, error) { + return ops.f.Read(p) +} + +func (ops tmpIO) ReadAt(_ context.Context, p []byte, off int64) (int, error) { + return ops.f.ReadAt(p, off) +} + +func (ops tmpIO) WriteAt(_ context.Context, p []byte, off int64) (int, error) { + return ops.f.WriteAt(p, off) +} + +func (ops tmpIO) Truncate(_ context.Context, size int64) error { + return ops.f.Truncate(size) +} + +func (ops *tmpIO) Flush(ctx context.Context) (reuseOps bool, err error) { + dst, err := file.Create(ctx, ops.n.path) + if err != nil { + return false, errors.E(err, fmt.Sprintf("creating file %q", ops.n.path)) + } + defer file.CloseAndReport(ctx, dst, &err) + n, err := io.Copy(dst.Writer(ctx), &readerAdapter{r: ops.f}) + if err != nil { + return false, errors.E( + err, + fmt.Sprintf("copying from %q to %q", ops.f.Name(), ops.n.path), + ) + } + ops.n.setFsnodeInfo( + ops.n.fsnodeInfo(). + WithModTime(time.Now()). + WithSize(n), + ) + return true, nil +} + +// readerAdapter adapts an io.ReaderAt to be an io.Reader, calling ReadAt and +// maintaining the offset for the next Read. +type readerAdapter struct { + r io.ReaderAt + off int64 +} + +func (a *readerAdapter) Read(p []byte) (int, error) { + n, err := a.r.ReadAt(p, a.off) + a.off += int64(n) + return n, err +} + +func (ops *tmpIO) Close(_ context.Context) error { + err := ops.f.Close() + if errRemove := os.Remove(ops.f.Name()); errRemove != nil && err == nil { + err = errors.E(errRemove, "removing tmpIO file") + } + return err +} diff --git a/file/gfilefs/gfilefs.go b/file/gfilefs/gfilefs.go new file mode 100644 index 00000000..fd778144 --- /dev/null +++ b/file/gfilefs/gfilefs.go @@ -0,0 +1,214 @@ +// Copyright 2022 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache-2.0 +// license that can be found in the LICENSE file. 
+ +package gfilefs + +import ( + "context" + "os" + "sync/atomic" + "time" + + "github.com/grailbio/base/errors" + "github.com/grailbio/base/file" + "github.com/grailbio/base/file/fsnode" + "github.com/grailbio/base/grail/biofs/biofseventlog" + "github.com/grailbio/base/ioctx/fsctx" + "github.com/grailbio/base/log" + "github.com/grailbio/base/vcontext" + v23context "v.io/v23/context" +) + +// New returns a new parent node rooted at root. root must be a directory path +// that can be handled by github.com/grailbio/base/file. name will become the +// name of the returned node. +func New(root, name string) fsnode.Parent { + return newDirNode(root, name) +} + +const fileInfoCacheFor = 1 * time.Hour + +func newDirNode(path, name string) fsnode.Parent { + return dirNode{ + FileInfo: fsnode.NewDirInfo(name). + WithModePerm(0777). + WithCacheableFor(fileInfoCacheFor). + // TODO: Remove after updating fragments to support fsnode.T directly. + WithSys(path), + path: path, + } +} + +// dirNode implements fsnode.Parent and represents a directory. +type dirNode struct { + fsnode.ParentReadOnly + fsnode.FileInfo + path string +} + +var ( + _ fsnode.Parent = dirNode{} + _ fsnode.Cacheable = dirNode{} +) + +// Child implements fsnode.Parent. +func (d dirNode) Child(ctx context.Context, name string) (fsnode.T, error) { + log.Debug.Printf("gfilefs child name=%s", name) + biofseventlog.UsedFeature("gfilefs.dir.child") + var ( + path = file.Join(d.path, name) + child fsnode.T + ) + vctx := v23context.FromGoContextWithValues(ctx, vcontext.Background()) + lister := file.List(vctx, path, true /* recursive */) + // Look for either a file or a directory at this path. If both exist, + // assume file is a directory marker. + // TODO: Consider making base/file API more ergonomic for file and + // directory name collisions, e.g. by making it easy to systematically + // shadow one. + for lister.Scan() { + if lister.IsDir() || // We've found an exact match, and it's a directory. 
+ lister.Path() != path { // We're seeing children, so path must be a directory. + child = newDirNode(path, name) + break + } + child = newFileNode(path, toRegInfo(name, lister.Info())) + } + if err := lister.Err(); err != nil { + return nil, errors.E(err, "scanning", path) + } + if child == nil { + return nil, errors.E(errors.NotExist, path, "not found") + } + return child, nil +} + +// Children implements fsnode.Parent. +func (d dirNode) Children() fsnode.Iterator { + biofseventlog.UsedFeature("gfilefs.dir.children") + return fsnode.NewLazyIterator(d.generateChildren) +} + +// AddChildLeaf implements fsnode.Parent. +func (d dirNode) AddChildLeaf( + ctx context.Context, + name string, + flags uint32, +) (fsnode.Leaf, fsctx.File, error) { + biofseventlog.UsedFeature("gfilefs.dir.addLeaf") + path := file.Join(d.path, name) + info := fsnode.NewRegInfo(name). + WithModePerm(0444). + WithCacheableFor(fileInfoCacheFor) + n := newFileNode(path, info) + f, err := n.OpenFile(ctx, int(flags)) + if err != nil { + return nil, nil, errors.E(err, "creating file") + } + return n, f, nil +} + +// AddChildParent implements fsnode.Parent. +func (d dirNode) AddChildParent(_ context.Context, name string) (fsnode.Parent, error) { + biofseventlog.UsedFeature("gfilefs.dir.addParent") + // TODO: Consider supporting directories better in base/file, maybe with + // some kind of directory marker. + path := file.Join(d.path, name) + return newDirNode(path, name), nil +} + +// RemoveChild implements fsnode.Parent. +func (d dirNode) RemoveChild(ctx context.Context, name string) error { + biofseventlog.UsedFeature("gfilefs.rmChild") + return file.Remove(ctx, file.Join(d.path, name)) +} + +func (d dirNode) FSNodeT() {} + +func (d dirNode) generateChildren(ctx context.Context) ([]fsnode.T, error) { + var ( + // byName is keyed by child name and is used to handle duplicate names + // we may get when scanning, i.e. if there is a directory and file with + // the same name (which is possible in S3). 
+ byName = make(map[string]fsnode.T) + vctx = v23context.FromGoContextWithValues(ctx, vcontext.Background()) + lister = file.List(vctx, d.path, false) + ) + for lister.Scan() { + var ( + childPath = lister.Path() + name = file.Base(childPath) + ) + // Resolve duplicates by preferring the directory and shadowing the + // file. This should be kept consistent with the behavior of Child. + // We do not expect multiple files or directories with the same name, + // so behavior of that case is undefined. + if lister.IsDir() { + byName[name] = newDirNode(childPath, name) + } else if _, ok := byName[name]; !ok { + byName[name] = newFileNode(childPath, toRegInfo(name, lister.Info())) + } + } + if err := lister.Err(); err != nil { + return nil, errors.E(err, "listing", d.path) + } + children := make([]fsnode.T, 0, len(byName)) + for _, child := range byName { + children = append(children, child) + } + return children, nil +} + +type fileNode struct { + // path of the file this node represents. + path string + + // TODO: Consider expiring this info to pick up external changes (or fix + // possible inconsistency due to races?). + info atomic.Value // fsnode.FileInfo +} + +var ( + _ (fsnode.Cacheable) = (*fileNode)(nil) + _ (fsnode.Leaf) = (*fileNode)(nil) +) + +func newFileNode(path string, info fsnode.FileInfo) *fileNode { + // TODO: Remove after updating fragments to support fsnode.T directly. 
+ info = info.WithSys(path) + n := fileNode{path: path} + n.info.Store(info) + return &n +} + +func (n *fileNode) Info() os.FileInfo { + return n.fsnodeInfo() +} + +func (fileNode) FSNodeT() {} + +func (n *fileNode) OpenFile(ctx context.Context, flag int) (fsctx.File, error) { + biofseventlog.UsedFeature("gfilefs.file.open") + return OpenFile(ctx, n, flag) +} + +func (n *fileNode) CacheableFor() time.Duration { + return fsnode.CacheableFor(n.fsnodeInfo()) +} + +func (n *fileNode) fsnodeInfo() fsnode.FileInfo { + return n.info.Load().(fsnode.FileInfo) +} + +func (n *fileNode) setFsnodeInfo(info fsnode.FileInfo) { + n.info.Store(info) +} + +func toRegInfo(name string, info file.Info) fsnode.FileInfo { + return fsnode.NewRegInfo(name). + WithModePerm(0666). + WithSize(info.Size()). + WithModTime(info.ModTime()). + WithCacheableFor(fileInfoCacheFor) +} diff --git a/file/gfilefs/write_test.go b/file/gfilefs/write_test.go new file mode 100644 index 00000000..f400696c --- /dev/null +++ b/file/gfilefs/write_test.go @@ -0,0 +1,425 @@ +// Copyright 2022 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache-2.0 +// license that can be found in the LICENSE file. 
+ +package gfilefs_test + +import ( + "context" + "flag" + gofs "io/fs" + "io/ioutil" + "log" + "math/rand" + "os" + "path/filepath" + "strings" + "sync" + "testing" + + "github.com/grailbio/base/errors" + "github.com/grailbio/base/file" + "github.com/grailbio/base/file/fsnodefuse" + "github.com/grailbio/base/file/gfilefs" + "github.com/grailbio/base/file/s3file" + "github.com/grailbio/testutil" + "github.com/hanwen/go-fuse/v2/fs" + "github.com/hanwen/go-fuse/v2/fuse" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func init() { + file.RegisterImplementation("s3", func() file.Implementation { + return s3file.NewImplementation( + s3file.NewDefaultProvider(), s3file.Options{}, + ) + }) +} + +// s3RootFlag sets an S3 root directory to use for test files. When set to a +// non-empty S3 path, e.g. "s3://some-bucket/some-writable/prefix", tests will +// run with a mount point with this root. These tests will run in addition to +// the normal local root testing. +var s3RootFlag = flag.String( + "s3-root", + "", + "optional S3 root directory to use for testing, in addition to the local root", +) + +// TestCreateEmpty verifies that we can create an new empty file using various +// flag parameters when opening, e.g. O_TRUNC. +func TestCreateEmpty(t *testing.T) { + flagElements := [][]int{ + {os.O_RDONLY, os.O_RDWR, os.O_WRONLY}, + {0, os.O_TRUNC}, + {0, os.O_EXCL}, + } + // combos produces the flag parameters to test (less O_CREATE, which is + // applied below). + var combos func(elems [][]int) []int + combos = func(elems [][]int) []int { + if len(elems) == 1 { + return elems[0] + } + var result []int + for _, elem := range elems[0] { + for _, flag := range combos(elems[1:]) { + flag |= elem + result = append(result, flag) + } + } + return result + } + // name generates a nice name for a subtest for a given flag. 
+ name := func(flags int) string { + var ( + parts []string + access string + ) + switch { + case flags&os.O_RDWR == os.O_RDWR: + access = "RDWR" + case flags&os.O_WRONLY == os.O_WRONLY: + access = "WRONLY" + default: + access = "RDONLY" + } + parts = append(parts, access) + if flags&os.O_TRUNC == os.O_TRUNC { + parts = append(parts, "TRUNC") + } + if flags&os.O_EXCL == os.O_EXCL { + parts = append(parts, "EXCL") + } + return strings.Join(parts, "_") + } + for _, flag := range combos(flagElements) { + withTestMounts(t, func(m testMount) { + t.Run(name(flag), func(t *testing.T) { + path := filepath.Join(m.mountPoint, "test") + flag |= os.O_CREATE + f, err := os.OpenFile(path, flag, 0666) + require.NoError(t, err, "creating file") + require.NoError(t, f.Close(), "closing file") + + info, err := os.Stat(path) + require.NoError(t, err, "stat of file") + assert.Equal(t, int64(0), info.Size(), "file should have zero size") + + bs, err := ioutil.ReadFile(path) + require.NoError(t, err, "reading file") + assert.Empty(t, bs, "file should be empty") + }) + }) + } +} + +// TestCreate verifies that we can create a new file, write content to it, and +// read the same content back. +func TestCreate(t *testing.T) { + withTestMounts(t, func(m testMount) { + var ( + r = rand.New(rand.NewSource(0)) + path = filepath.Join(m.mountPoint, "test") + rootPath = file.Join(m.root, "test") + ) + assertRoundTrip(t, path, rootPath, r, 10*(1<<20)) + assertRoundTrip(t, path, rootPath, r, 10*(1<<16)) + }) +} + +// TestOverwrite verifies that we can overwrite the same file repeatedly, and +// that the updated content is correct. +func TestOverwrite(t *testing.T) { + withTestMounts(t, func(m testMount) { + const NumOverwrites = 20 + var ( + r = rand.New(rand.NewSource(0)) + path = filepath.Join(m.mountPoint, "test") + rootPath = file.Join(m.root, "test") + ) + for i := 0; i < NumOverwrites+1; i++ { + // Each iteration uses a random size between 5 and 10 MiB. 
+ n := 5 + r.Intn(10) + n *= 1 << 20 + assertRoundTrip(t, path, rootPath, r, n) + } + }) +} + +// TestTruncFlag verifies that opening with O_TRUNC truncates the file. +func TestTruncFlag(t *testing.T) { + t.Run("WRONLY", func(t *testing.T) { + testTruncFlag(t, os.O_WRONLY) + }) + t.Run("RDWR", func(t *testing.T) { + testTruncFlag(t, os.O_RDWR) + }) +} + +func testTruncFlag(t *testing.T, flag int) { + withTestMounts(t, func(m testMount) { + path := filepath.Join(m.mountPoint, "test") + // Write the file we will truncate to test. + err := ioutil.WriteFile(path, []byte{0, 1, 2}, 0644) + require.NoError(t, err, "writing file") + + f, err := os.OpenFile(path, flag|os.O_TRUNC, 0666) + require.NoError(t, err, "opening for truncation") + func() { + defer func() { + require.NoError(t, f.Close()) + }() + var info gofs.FileInfo + info, err = f.Stat() + require.NoError(t, err, "getting file stats") + assert.Equal(t, int64(0), info.Size(), "truncated file should be zero bytes") + }() + + // Verify that reading the truncated file yields zero bytes. + bsRead, err := ioutil.ReadFile(path) + require.NoError(t, err, "reading truncated file") + assert.Empty(t, bsRead, "reading truncated file should yield no data") + }) +} + +// TestTruncateZero verifies that truncation to zero works. +func TestTruncateZero(t *testing.T) { + t.Run("WRONLY", func(t *testing.T) { + testTruncateZero(t, os.O_WRONLY) + }) + t.Run("RDWR", func(t *testing.T) { + testTruncateZero(t, os.O_RDWR) + }) +} + +func testTruncateZero(t *testing.T, flag int) { + withTestMounts(t, func(m testMount) { + path := filepath.Join(m.mountPoint, "test") + // Write the file we will truncate to test. 
+		err := ioutil.WriteFile(path, []byte{0, 1, 2}, 0644)
+		require.NoError(t, err, "writing file")
+
+		f, err := os.OpenFile(path, flag, 0666)
+		require.NoError(t, err, "opening for truncation")
+
+		func() {
+			defer func() {
+				require.NoError(t, f.Close(), "closing")
+			}()
+			// Sanity check that the initial file handle is the correct size.
+			var info gofs.FileInfo
+			info, err = f.Stat()
+			require.NoError(t, err, "getting file stats")
+			assert.Equal(t, int64(3), info.Size(), "file to truncate should be three bytes")
+
+			require.NoError(t, f.Truncate(0), "truncating")
+
+			// Verify that the file handle is actually truncated.
+			info, err = f.Stat()
+			require.NoError(t, err, "getting file stats")
+			assert.Equal(t, int64(0), info.Size(), "truncated file should be zero bytes")
+		}()
+
+		// Verify that an independent stat shows zero size.
+		info, err := os.Stat(path)
+		require.NoError(t, err, "getting file stats")
+		assert.Equal(t, int64(0), info.Size(), "truncated file should be zero bytes")
+
+		// Verify that reading the truncated file yields zero bytes.
+		bsRead, err := ioutil.ReadFile(path)
+		require.NoError(t, err, "reading truncated file")
+		assert.Empty(t, bsRead, "reading truncated file should yield no data")
+	})
+}
+
+// TestRemove verifies that we can remove a file.
+func TestRemove(t *testing.T) {
+	withTestMounts(t, func(m testMount) {
+		var (
+			r        = rand.New(rand.NewSource(0))
+			path     = filepath.Join(m.mountPoint, "test")
+			rootPath = file.Join(m.root, "test")
+		)
+		bs := make([]byte, 1*(1<<20))
+		_, err := r.Read(bs)
+		require.NoError(t, err, "making random data")
+		err = ioutil.WriteFile(path, bs, 0644)
+		require.NoError(t, err, "writing file")
+		err = os.Remove(path)
+		require.NoError(t, err, "removing file")
+		_, err = os.Stat(path)
+		require.True(t, os.IsNotExist(err), "file was not removed")
+		_, err = os.Stat(rootPath)
+		require.True(t, os.IsNotExist(err), "file was not removed in root")
+	})
+}
+
+// TestDirListing verifies that the directory listing of a file is updated when
+// the file is modified.
+func TestDirListing(t *testing.T) {
+	withTestMounts(t, func(m testMount) {
+		path := file.Join(m.mountPoint, "test")
+		// assertSize asserts that the listed FileInfo of the file at path reports
+		// the given size.
+		assertSize := func(size int64) {
+			infos, err := ioutil.ReadDir(m.mountPoint)
+			require.NoError(t, err, "listing directory")
+			require.Equal(t, 1, len(infos), "should only be one file in directory")
+			assert.Equal(t, size, infos[0].Size(), "listed size should match expected size")
+		}
+
+		// Write a 3-byte file, and verify that its listing has the correct size.
+		require.NoError(t, ioutil.WriteFile(path, make([]byte, 3), 0644), "writing file")
+		assertSize(3)
+
+		// Overwrite it to be 1 byte, and verify that the listing is updated.
+		require.NoError(t, ioutil.WriteFile(path, make([]byte, 1), 0644), "overwriting file")
+		assertSize(1)
+
+		// Append 3 bytes, and verify that the listing is updated.
+		f, err := os.OpenFile(path, os.O_WRONLY|os.O_APPEND, 0644)
+		require.NoError(t, err, "opening file for append")
+		_, err = f.Write(make([]byte, 3))
+		require.NoError(t, err, "appending to file")
+		require.NoError(t, f.Close(), "closing file")
+		assertSize(4)
+	})
+}
+
+// TestMkdir verifies that we can make a directory.
+func TestMkdir(t *testing.T) { + withTestMounts(t, func(m testMount) { + var ( + r = rand.New(rand.NewSource(0)) + path = filepath.Join(m.mountPoint, "test-dir") + ) + err := os.Mkdir(path, 0775) + require.NoError(t, err, "making directory") + + filePath := filepath.Join(path, "test") + rootFilePath := file.Join(m.root, "test-dir", "test") + assertRoundTrip(t, filePath, rootFilePath, r, 0) + + info, err := os.Stat(path) + require.NoError(t, err, "getting file info of created directory") + require.True(t, info.IsDir(), "created directory is not a directory") + }) +} + +func withTestMounts(t *testing.T, f func(m testMount)) { + type makeRootFunc func(*testing.T) (string, func()) + makeRoots := map[string]makeRootFunc{ + "local": func(t *testing.T) (string, func()) { + return testutil.TempDir(t, "", "gfilefs-mnt") + }, + } + if *s3RootFlag != "" { + makeRoots["s3"] = func(t *testing.T) (string, func()) { + ctx := context.Background() + lister := file.List(ctx, *s3RootFlag, true) + exists := lister.Scan() + if exists { + t.Logf("path exists: %s", lister.Path()) + } + require.NoErrorf(t, lister.Err(), "listing %s", *s3RootFlag) + require.False(t, exists) + return *s3RootFlag, func() { + err := forEachFile(ctx, *s3RootFlag, func(path string) error { + return file.Remove(ctx, path) + }) + require.NoError(t, err, "cleaning up test root") + } + } + } + for name, makeRoot := range makeRoots { + t.Run(name, func(t *testing.T) { + root, rootCleanUp := makeRoot(t) + defer rootCleanUp() + mountPoint, mountPointCleanUp := testutil.TempDir(t, "", "gfilefs-mnt") + defer mountPointCleanUp() + server, err := fs.Mount( + mountPoint, + fsnodefuse.NewRoot(gfilefs.New(root, "root")), + // TODO: Set fsnodefuse.ConfigureRequiredMountOptions. 
+ &fs.Options{ + MountOptions: fuse.MountOptions{ + FsName: "test", + DisableXAttrs: true, + Debug: true, + MaxBackground: 1024, + }, + }, + ) + require.NoError(t, err, "mounting %q", mountPoint) + defer func() { + log.Printf("unmounting %q", mountPoint) + assert.NoError(t, server.Unmount(), + "unmount of FUSE mounted at %q failed; may need manual cleanup", + mountPoint, + ) + log.Printf("unmounted %q", mountPoint) + }() + f(testMount{root: root, mountPoint: mountPoint}) + }) + } +} + +type testMount struct { + // root is the root path that is mounted at dir. + root string + // mountPoint is the FUSE mount point. + mountPoint string +} + +// forEachFile runs the callback for every file under the directory in +// parallel. It returns any of the errors returned by the callback. It is +// cribbed from github.com/grailbio/base/cmd/grail-file/cmd. +func forEachFile(ctx context.Context, dir string, callback func(path string) error) error { + const parallelism = 32 + err := errors.Once{} + wg := sync.WaitGroup{} + ch := make(chan string, parallelism*100) + for i := 0; i < parallelism; i++ { + wg.Add(1) + go func() { + for path := range ch { + err.Set(callback(path)) + } + wg.Done() + }() + } + + lister := file.List(ctx, dir, true /*recursive*/) + for lister.Scan() { + if !lister.IsDir() { + ch <- lister.Path() + } + } + close(ch) + err.Set(lister.Err()) + wg.Wait() + return err.Err() +} + +func assertRoundTrip(t *testing.T, path, rootPath string, r *rand.Rand, size int) { + bs := make([]byte, size) + _, err := r.Read(bs) + require.NoError(t, err, "making random data") + err = ioutil.WriteFile(path, bs, 0644) + require.NoError(t, err, "writing file") + + got, err := ioutil.ReadFile(path) + require.NoError(t, err, "reading file back") + assert.Equal(t, bs, got, "data read != data written") + + info, err := os.Stat(path) + require.NoError(t, err, "stat of file") + assert.Equal(t, int64(len(bs)), info.Size(), "len(data read) != len(data written)") + + // Verify that the file 
is written correctly to mounted root. + got, err = file.ReadFile(context.Background(), rootPath) + require.NoErrorf(t, err, "reading file in root %s back", rootPath) + assert.Equal(t, bs, got, "data read != data written") +} diff --git a/file/implementation.go b/file/implementation.go index 13f36798..267a1c45 100644 --- a/file/implementation.go +++ b/file/implementation.go @@ -8,6 +8,7 @@ import ( "context" "fmt" "sync" + "time" ) // Implementation implements operations for a file-system type. @@ -18,14 +19,29 @@ type Implementation interface { // Open opens a file for reading. The pathname given to file.Open() is passed // here unchanged. Thus, it contains the URL prefix such as "s3://". - Open(ctx context.Context, path string) (File, error) + // + // Open returns an error of kind errors.NotExist if there is + // no file at the provided path. + Open(ctx context.Context, path string, opts ...Opts) (File, error) // Create opens a file for writing. If "path" already exists, the old contents // will be destroyed. If "path" does not exist already, the file will be newly - // created. If the directory part of the path does not exist already, it will - // be created. The pathname given to file.Open() is passed here unchanged. + // created. The pathname given to file.Create() is passed here unchanged. // Thus, it contains the URL prefix such as "s3://". - Create(ctx context.Context, path string) (File, error) + // + // Creating a file with the same name as an existing directory is unspecified + // behavior and varies by implementation. Users are thus advised to avoid + // this if possible. + // + // For filesystem based storage engines (e.g. localfile), if the directory + // part of the path does not exist already, it will be created. If the path + // is a directory, an error will be returned. + // + // For key based storage engines (e.g. S3), it is OK to create a file that + // already exists as a common prefix for other objects, assuming a pseudo + // path separator. 
So both "foo" and "foo/bar" can be used as paths for + // creating regular files in the same storage. See List() for more context. + Create(ctx context.Context, path string, opts ...Opts) (File, error) // List finds files and directories. If "path" points to a regular file, the // lister will return information about the file itself and finishes. @@ -50,14 +66,28 @@ type Implementation interface { // Stat returns the file metadata. It returns nil if path is // a directory. (There is no direct test for existence of a // directory.) - Stat(ctx context.Context, path string) (Info, error) + // + // Stat returns an error of kind errors.NotExist if there is + // no file at the provided path. + Stat(ctx context.Context, path string, opts ...Opts) (Info, error) // Remove removes the file. The path passed to file.Remove() is passed here // unchanged. Remove(ctx context.Context, path string) error + + // Presign returns a URL that can be used to perform the given HTTP method, + // usually one of "GET", "PUT" or "DELETE", on the path for the duration + // specified in expiry. + // + // It returns an error of kind errors.NotSupported for implementations that + // do not support signed URLs, or that do not support the given HTTP method. + // + // Unlike Open and Stat, this method does not return an error of kind + // errors.NotExist if there is no file at the provided path. + Presign(ctx context.Context, path, method string, expiry time.Duration) (url string, err error) } -// Lister lists files in a directory tree. +// Lister lists files in a directory tree. Not thread safe. type Lister interface { // Scan advances the lister to the next entry. It returns // false either when the scan stops because we have reached the end of the input @@ -169,32 +199,38 @@ func findImpl(path string) (Implementation, error) { // Open opens the given file readonly. It is a shortcut for calling // ParsePath(), then FindImplementation, then Implementation.Open. 
-func Open(ctx context.Context, path string) (File, error) { +// +// Open returns an error of kind errors.NotExist if the file at the +// provided path does not exist. +func Open(ctx context.Context, path string, opts ...Opts) (File, error) { impl, err := findImpl(path) if err != nil { return nil, err } - return impl.Open(ctx, path) + return impl.Open(ctx, path, opts...) } // Create opens the given file writeonly. It is a shortcut for calling // ParsePath(), then FindImplementation, then Implementation.Create. -func Create(ctx context.Context, path string) (File, error) { +func Create(ctx context.Context, path string, opts ...Opts) (File, error) { impl, err := findImpl(path) if err != nil { return nil, err } - return impl.Create(ctx, path) + return impl.Create(ctx, path, opts...) } // Stat returns the give file's metadata. Is a shortcut for calling ParsePath(), // then FindImplementation, then Implementation.Stat. -func Stat(ctx context.Context, path string) (Info, error) { +// +// Stat returns an error of kind errors.NotExist if the file at the +// provided path does not exist. +func Stat(ctx context.Context, path string, opts ...Opts) (Info, error) { impl, err := findImpl(path) if err != nil { return nil, err } - return impl.Stat(ctx, path) + return impl.Stat(ctx, path, opts...) } type errorLister struct{ err error } @@ -236,3 +272,56 @@ func Remove(ctx context.Context, path string) error { } return impl.Remove(ctx, path) } + +// Presign is a shortcut for calling ParsePath(), then calling +// Implementation.Presign method. +func Presign(ctx context.Context, path, method string, expiry time.Duration) (string, error) { + impl, err := findImpl(path) + if err != nil { + return "", err + } + return impl.Presign(ctx, path, method, expiry) +} + +// Opts controls the file access requests, such as Open and Stat. +type Opts struct { + // When set, this flag causes the file package to keep retrying when the file + // is reported as not found. 
This flag should be set when: + // + // 1. you are accessing a file on S3, and + // + // 2. an application may have attempted to GET the same file in recent past + // (~5 minutes). The said application may be on a different machine. + // + // This flag is honored only by S3 to work around the problem where s3 may + // report spurious KeyNotFound error after a GET request to the same file. + // For more details, see + // https://docs.aws.amazon.com/AmazonS3/latest/dev/Introduction.html#CoreConcepts, + // section "S3 Data Consistency Model". In particular: + // + // The caveat is that if you make a HEAD or GET request to the key + // name (to find if the object exists) before creating the object, Amazon S3 + // provides eventual consistency for read-after-write. + RetryWhenNotFound bool + + // When set, Close will ignore NoSuchUpload error from S3 + // CompleteMultiPartUpload and silently returns OK. + // + // This is to work around a bug where concurrent uploads to one file sometimes + // causes an upload request to be lost on the server side. + // https://console.aws.amazon.com/support/cases?region=us-west-2#/6299905521/en + // https://github.com/yasushi-saito/s3uploaderror + // + // Set this flag only if: + // + // 1. you are writing to a file on S3, and + // + // 2. possible concurrent writes to the same file produce the same + // contents, so you are ok with taking any of them. + // + // If you don't set this flag, then concurrent writes to the same file may + // fail with a NoSuchUpload error, and it is up to you to retry. + // + // On non-S3 file systems, this flag is ignored. + IgnoreNoSuchUpload bool +} diff --git a/file/info.go b/file/info.go index 9ded0470..339a0b71 100644 --- a/file/info.go +++ b/file/info.go @@ -16,5 +16,4 @@ type Info interface { ModTime() time.Time // TODO: add attributes, in form map[string]interface{}. 
- } diff --git a/file/internal/kernel/BUILD.bazel b/file/internal/kernel/BUILD.bazel new file mode 100644 index 00000000..317dd099 --- /dev/null +++ b/file/internal/kernel/BUILD.bazel @@ -0,0 +1,8 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library") + +go_library( + name = "go_default_library", + srcs = ["kernel.go"], + importpath = "github.com/grailbio/base/file/internal/kernel", + visibility = ["//go/src/github.com/grailbio/base/file:__subpackages__"], +) diff --git a/file/internal/kernel/kernel.go b/file/internal/kernel/kernel.go new file mode 100644 index 00000000..f6b20f5a --- /dev/null +++ b/file/internal/kernel/kernel.go @@ -0,0 +1,13 @@ +package kernel + +// MaxReadAhead configures the kernel's maximum readahead for file handles on this FUSE mount +// (via ConfigureMount) and our corresponding "trailing" buffer. +// +// Our sizingHandle implements Read operations for read-only fsctx.File objects that don't support +// random access or seeking. Generally this requires that the user reading such a file does so +// in-order. However, the kernel attempts to optimize i/o speed by reading ahead into the page cache +// and to do so it can issue concurrent reads for a few blocks ahead of the user's current position. +// We respond to such requests from our trailing buffer. +// TODO: Choose a value more carefully. This value was chosen fairly roughly based on some +// articles/discussion that suggested this was a kernel default. 
+const MaxReadAhead = 512 * 1024 diff --git a/file/internal/readmatcher/readmatcher.go b/file/internal/readmatcher/readmatcher.go new file mode 100644 index 00000000..f470e05a --- /dev/null +++ b/file/internal/readmatcher/readmatcher.go @@ -0,0 +1,179 @@ +package readmatcher + +import ( + "context" + stderrors "errors" + "sync" + + "github.com/grailbio/base/errors" + "github.com/grailbio/base/file/fsnodefuse/trailingbuf" + "github.com/grailbio/base/file/internal/kernel" + "github.com/grailbio/base/ioctx" +) + +type ( + // TODO: Avoid somewhat hidden internal dependency on kernel.MaxReadAhead. + readerAt struct { + offsetReader func(int64) ioctx.ReadCloser + softMaxReaders int + + // mu guards the fields below. It's held while looking up a reader, but not during reads. + // TODO: Consider making this RWMutex. + mu sync.Mutex + // clock counts reader usages, increasing monotonically. Reader creation and usage is + // "timestamped" according to this clock, letting us prune least-recently-used. + clock int64 + // TODO: More efficient data structure. + readers readers + } + // readers is a collection of backend readers. It's ordered by createdAt: + // readers[i].createdAt < readers[j].createdAt iff i < j. + // Elements in the middle may be removed; then we just shift the tail forward by 1. + readers []*reader + reader struct { + // These fields are set at creation and never mutated. + ioctx.ReaderAt + ioctx.Closer + + // These fields are accessed only while holding the the parent readerAt's lock. + maxPos int64 + inUse int64 + lastUsedAt int64 + createdAt int64 + } +) + +const defaultMaxReaders = 1024 + +var ( + _ ioctx.ReaderAt = (*readerAt)(nil) + _ ioctx.Closer = (*readerAt)(nil) +) + +type Opt func(*readerAt) + +func SoftMaxReaders(n int) Opt { return func(r *readerAt) { r.softMaxReaders = n } } + +// New returns a ReaderAt that "multiplexes" incoming reads onto one of a collection of "backend" +// readers. 
It matches read to backend based on last read position; a reader is selected if its last +// request ended near where the new read starts. +// +// It is intended for use with biofs+S3. S3 readers have high initialization costs vs. +// subsequently reading bytes, because that is S3's performance characteristic. ReaderAt maps +// incoming reads to a backend S3 reader that may be able to efficiently serve it. Otherwise, it +// opens a new reader. Our intention is that this will adapt to non-S3-aware clients' read +// patterns (small reads). S3-aware clients can always choose to read big chunks to avoid +// performance worst-cases. But, the Linux kernel limits FUSE read requests to 128 KiB, and we +// can't feasibly change that, so we adapt. +// +// To performantly handle Linux kernel readahead requests, the matching algorithm allows +// out-of-order positions within a small window (see trailingbuf). +// +// offsetReader opens a reader into the underlying file. +func New(offsetReader func(int64) ioctx.ReadCloser, opts ...Opt) interface { + ioctx.ReaderAt + ioctx.Closer +} { + r := readerAt{offsetReader: offsetReader, softMaxReaders: defaultMaxReaders} + for _, opt := range opts { + opt(&r) + } + return &r +} + +func (m *readerAt) ReadAt(ctx context.Context, dst []byte, off int64) (int, error) { + var minCreatedAt int64 + for { + r := m.acquire(off, minCreatedAt) + n, err := r.ReadAt(ctx, dst, off) + m.release(r, off+int64(n)) + if err != nil && stderrors.Is(err, trailingbuf.ErrTooFarBehind) { + minCreatedAt = r.createdAt + 1 + continue + } + return n, err + } +} + +func (m *readerAt) acquire(off int64, minCreatedAt int64) *reader { + m.mu.Lock() + defer m.mu.Unlock() + for _, r := range m.readers { + if r.createdAt < minCreatedAt { + continue + } + if r.maxPos-kernel.MaxReadAhead <= off && off <= r.maxPos+kernel.MaxReadAhead { + r.inUse++ + r.lastUsedAt = m.clock + m.clock++ + return r + } + } + m.lockedGC() + rc := m.offsetReader(off) + r := &reader{ + ReaderAt: 
trailingbuf.New(rc, off, kernel.MaxReadAhead), + Closer: rc, + maxPos: off, + inUse: 1, + lastUsedAt: m.clock, + createdAt: m.clock, + } + m.clock++ + m.readers.add(r) + return r +} + +func (m *readerAt) release(r *reader, newPos int64) { + m.mu.Lock() + defer m.mu.Unlock() + if newPos > r.maxPos { + r.maxPos = newPos + } + r.inUse-- + m.lockedGC() +} + +func (m *readerAt) lockedGC() { + for len(m.readers) > m.softMaxReaders { + i, ok := m.readers.idleLeastRecentlyUsedIndex() + if !ok { + return + } + m.readers.remove(i) + } +} + +func (m *readerAt) Close(ctx context.Context) (err error) { + m.mu.Lock() + defer m.mu.Unlock() + for _, rc := range m.readers { + errors.CleanUpCtx(ctx, rc.Close, &err) + } + m.readers = nil + return +} + +func (rs *readers) add(r *reader) { + *rs = append(*rs, r) +} + +func (rs *readers) remove(i int) { + *rs = append((*rs)[:i], (*rs)[i+1:]...) +} + +func (rs *readers) idleLeastRecentlyUsedIndex() (int, bool) { + minIdx := -1 + for i, r := range *rs { + if r.inUse > 0 { + continue + } + if minIdx < 0 || r.lastUsedAt < (*rs)[minIdx].lastUsedAt { + minIdx = i + } + } + if minIdx < 0 { + return -1, false + } + return minIdx, true +} diff --git a/file/internal/readmatcher/readmatcher_test.go b/file/internal/readmatcher/readmatcher_test.go new file mode 100644 index 00000000..8e5fe038 --- /dev/null +++ b/file/internal/readmatcher/readmatcher_test.go @@ -0,0 +1,110 @@ +package readmatcher_test + +import ( + "bytes" + "flag" + "io" + "math/rand" + "os" + "path" + "runtime" + "testing" + + "github.com/grailbio/base/file/fsnode" + "github.com/grailbio/base/file/fsnodefuse" + "github.com/grailbio/base/file/internal/readmatcher" + "github.com/grailbio/base/file/internal/readmatcher/readmatchertest" + "github.com/grailbio/base/ioctx" + "github.com/grailbio/base/log" + "github.com/grailbio/base/must" + "github.com/grailbio/testutil" + "github.com/grailbio/testutil/assert" + "github.com/hanwen/go-fuse/v2/fs" + "github.com/hanwen/go-fuse/v2/fuse" + 
"github.com/stretchr/testify/require" +) + +var ( + dataBytes = flag.Int("data-bytes", 1<<27, "read corpus size") + stressParallelism = flag.Int("stress-parallelism", runtime.NumCPU(), + "number of parallel readers during stress test") + fuseFlag = flag.Bool("fuse", false, "create a temporary FUSE mount and test through that") +) + +func TestStress(t *testing.T) { + var data = make([]byte, *dataBytes) + _, _ = rand.New(rand.NewSource(1)).Read(data) + offsetReader := func(start int64) ioctx.ReadCloser { + return ioctx.FromStdReadCloser(io.NopCloser(bytes.NewReader(data[start:]))) + } + type fuseCase struct { + name string + test func(*testing.T, ioctx.ReaderAt) + } + fuseCases := []fuseCase{ + { + "nofuse", + func(t *testing.T, r ioctx.ReaderAt) { + readmatchertest.Stress(data, r, *stressParallelism) + }, + }, + } + if *fuseFlag { + fuseCases = append(fuseCases, fuseCase{ + "fuse", + func(t *testing.T, rAt ioctx.ReaderAt) { + mountPoint, cleanUpMountPoint := testutil.TempDir(t, "", "readmatcher_test") + defer cleanUpMountPoint() + const filename = "data" + server, err := fs.Mount( + mountPoint, + fsnodefuse.NewRoot(fsnode.NewParent( + fsnode.NewDirInfo("root"), + fsnode.ConstChildren( + fsnode.ReaderAtLeaf( + fsnode.NewRegInfo(filename).WithSize(int64(len(data))), + rAt, + ), + ), + )), + &fs.Options{ + MountOptions: func() fuse.MountOptions { + opts := fuse.MountOptions{FsName: "test", Debug: log.At(log.Debug)} + fsnodefuse.ConfigureRequiredMountOptions(&opts) + fsnodefuse.ConfigureDefaultMountOptions(&opts) + return opts + }(), + }, + ) + require.NoError(t, err, "mounting %q", mountPoint) + defer func() { + log.Printf("unmounting %q", mountPoint) + assert.NoError(t, server.Unmount(), + "unmount of FUSE mounted at %q failed; may need manual cleanup", + mountPoint, + ) + log.Printf("unmounted %q", mountPoint) + }() + f, err := os.Open(path.Join(mountPoint, filename)) + require.NoError(t, err) + defer func() { require.NoError(t, f.Close()) }() + 
readmatchertest.Stress(data, ioctx.FromStdReaderAt(f), *stressParallelism) + }, + }) + } + for _, c := range fuseCases { + t.Run(c.name, func(t *testing.T) { + t.Run("less parallelism", func(t *testing.T) { + readerParallelism := *stressParallelism / 2 + must.True(readerParallelism > 0) + m := readmatcher.New(offsetReader, readmatcher.SoftMaxReaders(readerParallelism)) + c.test(t, m) + }) + t.Run("more parallelism", func(t *testing.T) { + readerParallelism := 2 * *stressParallelism + m := readmatcher.New(offsetReader, readmatcher.SoftMaxReaders(readerParallelism)) + c.test(t, m) + }) + }) + } +} diff --git a/file/internal/readmatcher/readmatchertest/stress.go b/file/internal/readmatcher/readmatchertest/stress.go new file mode 100644 index 00000000..1cc20c70 --- /dev/null +++ b/file/internal/readmatcher/readmatchertest/stress.go @@ -0,0 +1,118 @@ +package readmatchertest + +import ( + "bytes" + "context" + "fmt" + "io" + "math/rand" + + "github.com/grailbio/base/file/internal/kernel" + "github.com/grailbio/base/ioctx" + "github.com/grailbio/base/must" + "github.com/grailbio/base/traverse" +) + +// Stress runs a stress test on rAt. +// If rAt is a readmatcher, consider setting parallelism below and above readmatcher.SoftMaxReaders +// to exercise those cases. +func Stress(want []byte, rAt ioctx.ReaderAt, parallelism int) { + ctx := context.Background() + + size := len(want) + // Only use rnd in the sequentially-executed task builders, not the parallelized actual reads. + rnd := rand.New(rand.NewSource(1)) + + const fuseReadSize = 1 << 17 // 128 KiB = FUSE max read size. + sequentialBuilder := func() task { + var t task + start, limit := randInterval(rnd, size) + for ; start < limit; start += fuseReadSize { + limit := start + fuseReadSize + if limit > size { + limit = size + } + t = append(t, read{start, limit}) + } + return t + } + taskBuilders := []func() task{ + // Read sequentially in FUSE-like chunks. 
+ sequentialBuilder, + // Read some subset of the file, mostly sequentially, occasionally jumping. + // The jumps reorder reads within the bounds of kernel.MaxReadAhead. + // This is not quite the kernel readahead pattern because our reads are inherently + // sequential, whereas kernel readahead parallelizes. But, assuming we know that there is + // internal serialization somewhere, this at least simulates the variable ordering. + func() task { + // For simplicity, we choose to swap (or skip) adjacent pairs. Each item can only move 1 + // position, so the largest interval we can generate is changing a 1 read gap (adjacent) + // into a 3 read gap. + // If we allowed further, or second, swaps, we'd have to be careful about introducing + // longer gaps. Maybe we'll do that later. + must.True(kernel.MaxReadAhead >= 3*fuseReadSize) + t := sequentialBuilder() + for i := 0; i+1 < len(t); i += 2 { + if rnd.Intn(2) == 0 { + t[i], t[i+1] = t[i+1], t[i] + } + } + return t + }, + // Random reads covering some part of the data. 
+ func() task { + t := sequentialBuilder() + rnd.Shuffle(len(t), func(i, j int) { t[i], t[j] = t[j], t[i] }) + return t[:rnd.Intn(len(t))] + }, + } + tasks := make([]task, parallelism*10) + for i := range tasks { + tasks[i] = taskBuilders[rnd.Intn(len(taskBuilders))]() + } + err := traverse.T{Limit: parallelism}.Each(len(tasks), func(i int) (err error) { + var dst []byte + for _, t := range tasks[i] { + readSize := t.limit - t.start + if cap(dst) < readSize { + dst = make([]byte, 2*readSize) + } + dst = dst[:readSize] + n, err := rAt.ReadAt(ctx, dst, int64(t.start)) + if err == io.EOF { + if n == readSize { + err = nil + } else { + err = fmt.Errorf("early EOF: %d, %v", n, t) + } + } + if err != nil { + return err + } + if !bytes.Equal(want[t.start:t.limit], dst) { + return fmt.Errorf("read mismatch: %v", t) + } + } + return nil + }) + must.Nil(err) +} + +type ( + read struct{ start, limit int } + task []read +) + +// randInterval returns a subset of [0, size). Interval selection is biased so that a substantial +// number of returned intervals will touch 0 and/or size. +func randInterval(rnd *rand.Rand, size int) (start, limit int) { + start = rnd.Intn(2*size) - size + if start < 0 { // Around half will start at 0. + start = 0 + } + limit = start + rnd.Intn(2*(size-start+1)) + if limit > size { // And around half read till the end. 
+ limit = size + } + return +} diff --git a/file/internal/readmatcher/readmatchertest/stress_test.go b/file/internal/readmatcher/readmatchertest/stress_test.go new file mode 100644 index 00000000..244165aa --- /dev/null +++ b/file/internal/readmatcher/readmatchertest/stress_test.go @@ -0,0 +1,34 @@ +package readmatchertest + +import ( + "math/rand" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestRandInterval(t *testing.T) { + const ( + size = 10000 + trials = 10000 + ) + var ( + rnd = rand.New(rand.NewSource(1)) + touchZero, touchSize int + ) + for i := 0; i < trials; i++ { + start, limit := randInterval(rnd, size) + require.GreaterOrEqual(t, start, 0) + require.LessOrEqual(t, limit, size) + if start == 0 { + touchZero++ + } + if limit == size { + touchSize++ + } + } + // 10% is a very loose constraint. We could be more precise, but we don't care too much. + assert.Greater(t, touchZero, trials/10) + assert.Greater(t, touchSize, trials/10) +} diff --git a/file/internal/readmatcher/rules.bzl b/file/internal/readmatcher/rules.bzl new file mode 100644 index 00000000..7e711782 --- /dev/null +++ b/file/internal/readmatcher/rules.bzl @@ -0,0 +1,12 @@ +load("@io_bazel_rules_go//go:def.bzl", _go_test = "go_test") + +def go_test(**kwargs): + _go_test(**kwargs) + kwargs.pop("name") + _go_test( + name = "race_test", + race = "on", + # Run a smaller test under the race detector because execution is slower. 
+ args = ["-data-bytes={}".format(1 << 24), "-stress-parallelism=2"], + **kwargs + ) diff --git a/file/internal/s3bufpool/s3bufpool.go b/file/internal/s3bufpool/s3bufpool.go new file mode 100644 index 00000000..560feabf --- /dev/null +++ b/file/internal/s3bufpool/s3bufpool.go @@ -0,0 +1,27 @@ +package s3bufpool + +import ( + "sync" +) + +var ( + BufBytes = 16 * 1024 * 1024 + pool = sync.Pool{ + New: func() any { + b := make([]byte, BufBytes) + // Note: Return *[]byte, not []byte, so there's one heap allocation now to create the + // interface value, rather than one per Put. + return &b + }, + } +) + +func Get() *[]byte { return pool.Get().(*[]byte) } +func Put(b *[]byte) { pool.Put(b) } + +// SetBufSize modifies the buffer size. It's for testing only, and callers are responsible for +// making sure there's no race with Get or Put. +func SetBufSize(bytes int) { + BufBytes = bytes + pool = sync.Pool{New: pool.New} // Empty the pool. +} diff --git a/file/internal/testutil/testutil.go b/file/internal/testutil/testutil.go index c495c118..8fd1fe05 100644 --- a/file/internal/testutil/testutil.go +++ b/file/internal/testutil/testutil.go @@ -6,13 +6,19 @@ package testutil import ( "context" + "fmt" "io" "io/ioutil" + "math/rand" + "runtime" "sort" "testing" "time" + "github.com/grailbio/base/errors" "github.com/grailbio/base/file" + "github.com/grailbio/base/ioctx" + "github.com/grailbio/base/traverse" "github.com/grailbio/testutil/assert" ) @@ -60,10 +66,11 @@ func doWriteFile(ctx context.Context, t *testing.T, impl file.Implementation, pa } func fileExists(ctx context.Context, impl file.Implementation, path string) bool { - if _, err := impl.Stat(ctx, path); err != nil { - return false + _, err := impl.Stat(ctx, path) + if err != nil && !errors.Is(errors.NotExist, err) { + panic(err) } - return true + return err == nil } // TestEmpty creates an empty file and tests its operations. 
@@ -93,6 +100,19 @@ func TestEmpty( assert.NoError(t, f.Close(ctx)) } +// TestNotExist tests that the implementation behaves correctly +// for paths that do not exist. +func TestNotExist( + ctx context.Context, + t *testing.T, + impl file.Implementation, + path string) { + _, err := impl.Open(ctx, path) + assert.True(t, errors.Is(errors.NotExist, err)) + _, err = impl.Stat(ctx, path) + assert.True(t, errors.Is(errors.NotExist, err)) +} + // TestErrors tests handling of errors. "path" shouldn't exist. func TestErrors( ctx context.Context, @@ -155,6 +175,11 @@ func TestReads( doSeek(t, r, 0, io.SeekStart) assert.EQ(t, expected, doReadAll(t, r)) + // Seek twice to the same offset + doSeek(t, r, 1, io.SeekStart) + doSeek(t, r, 1, io.SeekStart) + assert.EQ(t, expected[1:], doReadAll(t, r)) + doSeek(t, r, 8, io.SeekStart) doSeek(t, r, -6, io.SeekCurrent) assert.EQ(t, "purple", doRead(t, r, 6)) @@ -212,9 +237,7 @@ func TestDiscard(ctx context.Context, t *testing.T, impl file.Implementation, di assert.NoError(t, err) // Discard, and then make sure it doesn't exist. - err = f.Discard(ctx) - assert.NoError(t, err) - + f.Discard(ctx) if fileExists(ctx, impl, path) { t.Errorf("path %s exists after call to discard", path) } @@ -342,12 +365,12 @@ func TestListDir(ctx context.Context, t *testing.T, impl file.Implementation, di }, doList(dir+"/d0/")) } -// TestAll runs all the tests in this package. -func TestAll(ctx context.Context, t *testing.T, impl file.Implementation, dir string) { +// TestStandard runs tests for all of the standard file API functionality. 
+func TestStandard(ctx context.Context, t *testing.T, impl file.Implementation, dir string) { iName := impl.String() t.Run(iName+"_Empty", func(t *testing.T) { TestEmpty(ctx, t, impl, dir+"/empty.txt") }) - + t.Run(iName+"_NotExist", func(t *testing.T) { TestNotExist(ctx, t, impl, dir+"/notexist.txt") }) t.Run(iName+"_Errors", func(t *testing.T) { TestErrors(ctx, t, impl, dir+"/errors.txt") }) t.Run(iName+"_Reads", func(t *testing.T) { TestReads(ctx, t, impl, dir+"/reads.txt") }) t.Run(iName+"_Writes", func(t *testing.T) { TestWrites(ctx, t, impl, dir+"/writes") }) @@ -357,3 +380,47 @@ func TestAll(ctx context.Context, t *testing.T, impl file.Implementation, dir st t.Run(iName+"_List", func(t *testing.T) { TestList(ctx, t, impl, dir+"/match") }) t.Run(iName+"_ListDir", func(t *testing.T) { TestListDir(ctx, t, impl, dir+"/dirmatch") }) } + +// TestConcurrentOffsetReads tests arbitrarily-ordered, concurrent reads. +func TestConcurrentOffsetReads( + ctx context.Context, + t *testing.T, + impl file.Implementation, + path string, +) { + expected := "A purple fox jumped over a blue cat" + doWriteFile(ctx, t, impl, path, expected) + + parallelism := runtime.NumCPU() + const readsPerShard = 1024 + + f, err := impl.Open(ctx, path) + assert.NoError(t, err) + + rnds := make([]*rand.Rand, parallelism) + rnds[0] = rand.New(rand.NewSource(1)) + for i := 1; i < len(rnds); i++ { + rnds[i] = rand.New(rand.NewSource(rnds[0].Int63())) + } + + assert.NoError(t, traverse.Limit(parallelism).Each(parallelism, func(shard int) (err error) { + rnd := rnds[shard] + for i := 0; i < readsPerShard; i++ { + start := rnd.Intn(len(expected)) + limit := start + rnd.Intn(len(expected)+1-start) + got := make([]byte, limit-start) + rc := f.OffsetReader(int64(start)) + defer errors.CleanUpCtx(ctx, rc.Close, &err) + _, err = io.ReadFull(ioctx.ToStdReader(ctx, rc), got) + if err != nil { + return err + } + if got, want := string(got), expected[start:limit]; got != want { + return fmt.Errorf("got: %s, 
want: %s", got, want) + } + } + return nil + })) + + assert.NoError(t, f.Close(ctx)) +} diff --git a/file/localdev_test.go b/file/localdev_test.go index f2a4c29b..1ee1b6f0 100644 --- a/file/localdev_test.go +++ b/file/localdev_test.go @@ -1,4 +1,5 @@ -// +build arc-ignore phabricator-ignore +//go:build !unit +// +build !unit package file_test @@ -7,40 +8,27 @@ import ( "io/ioutil" "os/exec" "path/filepath" + "runtime" "sync" "testing" "github.com/grailbio/base/file" "github.com/grailbio/testutil" + "github.com/grailbio/testutil/assert" "github.com/stretchr/testify/require" ) // Write to /dev/stdout. This test only checks that the write succeeds. func TestStdout(t *testing.T) { - var err error - - if testing.Short() { - t.Skip("Cannot open /dev/tty or /dev/stdout from automated tests") - } - - for _, path := range []string{ - "/dev/tty", // works on darwin - "/dev/stdout", // works on linux - } { - ctx := context.Background() - var w file.File - w, err = file.Create(ctx, path) - if err != nil { - continue - } - _, err = w.Writer(ctx).Write([]byte("Hello")) - if err != nil { - continue - } - require.NoError(t, w.Close(ctx)) - break + if runtime.GOOS == "darwin" { + t.Skip("This test does not consistently work on macOS") } - require.NoError(t, err) + ctx := context.Background() + w, err := file.Create(ctx, "/dev/stdout") + assert.Nil(t, err) + _, err = w.Writer(ctx).Write([]byte("Hello\n")) + assert.Nil(t, err) + require.NoError(t, w.Close(ctx)) } // Read and write a FIFO. 
@@ -61,7 +49,8 @@ func TestDevice(t *testing.T) { go func() { w, err := file.Create(ctx, fifoPath) require.NoError(t, err) - w.Writer(ctx).Write([]byte("Hello\n")) + _, err = w.Writer(ctx).Write([]byte("Hello\n")) + require.NoError(t, err) require.NoError(t, w.Close(ctx)) wg.Done() }() diff --git a/file/localfile.go b/file/localfile.go index 997220dd..5da748a7 100644 --- a/file/localfile.go +++ b/file/localfile.go @@ -14,6 +14,8 @@ import ( "sort" "time" + "github.com/grailbio/base/errors" + "github.com/grailbio/base/ioctx" "github.com/grailbio/base/log" ) @@ -53,18 +55,25 @@ func (impl *localImpl) String() string { } // Open implements file.Implementation. -func (impl *localImpl) Open(ctx context.Context, path string) (File, error) { +func (impl *localImpl) Open(ctx context.Context, path string, _ ...Opts) (File, error) { f, err := os.Open(path) if err != nil { + if os.IsNotExist(err) { + err = errors.E(err, errors.NotExist) + } return nil, err } - return &localFile{f: f, mode: readonly, path: path}, nil + lf := localFile{f: f, mode: readonly, path: path} + return &lf, nil } // Create implements file.Implementation. To make writes appear linearizable, // it creates a temporary file with name .tmp, then renames the temp file // to on Close. -func (*localImpl) Create(ctx context.Context, path string) (File, error) { +func (*localImpl) Create(ctx context.Context, path string, _ ...Opts) (File, error) { + if path == "" { // Detect common errors quickly. + return nil, fmt.Errorf("file.Create: empty pathname") + } realPath, err := filepath.EvalSymlinks(path) if err != nil { // This happens when the file doesn't exist, including the case where path @@ -75,12 +84,17 @@ func (*localImpl) Create(ctx context.Context, path string) (File, error) { // symlink on close. 
realPath = path } - if stat, err := os.Stat(path); err == nil && ((stat.Mode()&os.ModeDevice != 0) || (stat.Mode()&os.ModeNamedPipe != 0) || (stat.Mode()&os.ModeSocket != 0)) { - f, err := os.Create(path) - if err != nil { - return nil, err + if stat, err := os.Stat(path); err == nil { + if (stat.Mode()&os.ModeDevice != 0) || (stat.Mode()&os.ModeNamedPipe != 0) || (stat.Mode()&os.ModeSocket != 0) { + f, err := os.Create(path) + if err != nil { + return nil, err + } + return &localFile{f: f, mode: writeonlyDev, path: path, realPath: realPath}, nil + } + if stat.IsDir() { + return nil, fmt.Errorf("file.Create %s: is a directory", path) } - return &localFile{f: f, mode: writeonlyDev, path: path, realPath: realPath}, nil } // filepath.Dir just strips the last "/" if path ends with "/". Else, it @@ -130,19 +144,17 @@ func (f *localFile) close(_ context.Context, doSync bool) error { } // Discard implements file.File. -func (f *localFile) Discard(ctx context.Context) error { +func (f *localFile) Discard(ctx context.Context) { switch f.mode { - case readonly: - return fmt.Errorf("discard %s: file is not opened in write mode", f.Name()) - case writeonlyDev: - return fmt.Errorf("discard %s: cannot discard writes to devices or sockets", f.Name()) + case readonly, writeonlyDev: + return + } + if err := f.f.Close(); err != nil { + log.Printf("discard %s: close: %v", f.Name(), err) } - err := f.f.Close() - e2 := os.Remove(f.f.Name()) - if e2 != nil && err == nil { - err = e2 + if err := os.Remove(f.f.Name()); err != nil { + log.Printf("discard %s: remove: %v", f.Name(), err) } - return err } // String implements file.File. 
@@ -158,15 +170,39 @@ func (f *localFile) Name() string { // Reader implements file.File func (f *localFile) Reader(context.Context) io.ReadSeeker { if f.mode != readonly { - return NewErrorReader(fmt.Errorf("reader %v: file is not opened in read mode", f.Name())) + return NewError(fmt.Errorf("reader %v: file is not opened in read mode", f.Name())) } return f.f } +type localReader struct { + f *os.File + pos int64 +} + +func (r *localReader) Read(_ context.Context, p []byte) (int, error) { + n, err := r.f.ReadAt(p, r.pos) + r.pos += int64(n) + return n, err +} + +func (r *localReader) Close(context.Context) error { + r.f = nil + return nil +} + +// OffsetReader implements file.File +func (f *localFile) OffsetReader(offset int64) ioctx.ReadCloser { + if f.mode != readonly { + return ioctx.FromStdReadCloser(NewError(fmt.Errorf("reader %v: file is not opened in read mode", f.Name()))) + } + return &localReader{f: f.f, pos: offset} +} + // Writer implements file.Writer func (f *localFile) Writer(context.Context) io.Writer { if f.mode == readonly { - return NewErrorWriter(fmt.Errorf("writer %v: file is not opened in write mode", f.Name())) + return NewError(fmt.Errorf("writer %v: file is not opened in write mode", f.Name())) } return f.f } @@ -181,10 +217,18 @@ func (*localImpl) Remove(ctx context.Context, path string) error { return os.Remove(path) } +func (*localImpl) Presign(_ context.Context, path, _ string, _ time.Duration) (string, error) { + return "", errors.E(errors.NotSupported, + fmt.Sprintf("presign %v: local files not supported", path)) +} + // Stat implements file.Implementation -func (impl *localImpl) Stat(ctx context.Context, path string) (Info, error) { +func (impl *localImpl) Stat(ctx context.Context, path string, _ ...Opts) (Info, error) { info, err := os.Stat(path) if err != nil { + if os.IsNotExist(err) { + err = errors.E(err, errors.NotExist) + } return nil, err } if info.IsDir() { @@ -205,6 +249,12 @@ func (f *localFile) Stat(context.Context) 
(Info, error) { return &localInfo{size: info.Size(), modTime: info.ModTime()}, nil } +var _ ioctx.WriterAt = (*localFile)(nil) + +func (f *localFile) WriteAt(_ context.Context, p []byte, off int64) (n int, err error) { + return f.f.WriteAt(p, off) +} + func (i *localInfo) Size() int64 { return i.size } func (i *localInfo) ModTime() time.Time { return i.modTime } diff --git a/file/localfile_test.go b/file/localfile_test.go index e1d192d4..8fb97feb 100644 --- a/file/localfile_test.go +++ b/file/localfile_test.go @@ -6,6 +6,7 @@ package file_test import ( "context" + "fmt" "io/ioutil" "os" "path/filepath" @@ -14,6 +15,7 @@ import ( "github.com/grailbio/base/file" filetestutil "github.com/grailbio/base/file/internal/testutil" "github.com/grailbio/testutil" + "github.com/grailbio/testutil/assert" "github.com/stretchr/testify/require" ) @@ -22,7 +24,12 @@ func TestAll(t *testing.T) { defer cleanup() impl := file.NewLocalImplementation() ctx := context.Background() - filetestutil.TestAll(ctx, t, impl, tempDir) + filetestutil.TestStandard(ctx, t, impl, tempDir) +} + +func TestEmptyPath(t *testing.T) { + _, err := file.Create(context.Background(), "") + require.Regexp(t, "empty pathname", err) } // Test that Create on a symlink will preserve it. 
@@ -40,7 +47,8 @@ func TestCreateSymlink(t *testing.T) { ctx := context.Background() w, err := file.Create(context.Background(), newPath) require.NoError(t, err) - w.Writer(ctx).Write([]byte("hello")) + _, err = w.Writer(ctx).Write([]byte("hello")) + require.NoError(t, err) require.NoError(t, w.Close(ctx)) data, err := ioutil.ReadFile(newPath) @@ -52,3 +60,16 @@ func TestCreateSymlink(t *testing.T) { require.NoError(t, err) require.Equal(t, "hello", string(data)) } + +func TestCreateDirectory(t *testing.T) { + tmp, cleanup0 := testutil.TempDir(t, "", "") + defer cleanup0() + + dirPath := file.Join(tmp, "dir") + err := os.Mkdir(dirPath, 0777) + assert.Nil(t, err) + + ctx := context.Background() + _, err = file.Create(ctx, dirPath) + require.EqualError(t, err, fmt.Sprintf("file.Create %s: is a directory", dirPath)) +} diff --git a/file/loopbackfs/loopbackfs.go b/file/loopbackfs/loopbackfs.go new file mode 100644 index 00000000..1c6b61fb --- /dev/null +++ b/file/loopbackfs/loopbackfs.go @@ -0,0 +1,122 @@ +package loopbackfs + +import ( + "context" + "io/ioutil" + "os" + "path" + "time" + + "github.com/grailbio/base/file/fsnode" + "github.com/grailbio/base/ioctx/fsctx" + "github.com/grailbio/base/ioctx/spliceio" +) + +// New returns an fsnode.T representing a path on the local filesystem. +// TODO: Replace this with a generic io/fs.FS wrapper around os.DirFS, after upgrading Go. +func New(name string, path string) (fsnode.T, error) { + info, err := os.Stat(path) + if err != nil { + return nil, err + } + node := newT(name, info, path) + if node == nil { + return nil, os.ErrInvalid + } + return node, nil +} + +func newT(name string, info os.FileInfo, path string) fsnode.T { + switch info.Mode() & os.ModeType { + case os.ModeDir: + // Temporary hack: record the original path so we can peek at it. + // TODO: Eventually, adapt our libraries to operate on FS's directly. + // TODO: Consider preserving executable bits. 
But, copying permissions without also checking + // owner UID/GID may not make sense. + info := fsnode.NewDirInfo(name).WithModTime(info.ModTime()).WithSys(path). + WithCacheableFor(time.Hour) + return parent{dir: path, FileInfo: info} + case 0: + info := fsnode.NewRegInfo(name).WithModTime(info.ModTime()).WithSys(path). + WithCacheableFor(time.Hour) + return leaf{path, info} + } + return nil +} + +type parent struct { + fsnode.ParentReadOnly + dir string + fsnode.FileInfo +} + +var _ fsnode.Parent = parent{} + +func (p parent) FSNodeT() {} + +func (p parent) Child(_ context.Context, name string) (fsnode.T, error) { + return New(name, path.Join(p.dir, name)) +} + +type iterator struct { + dir string + fetched bool + delegate fsnode.Iterator +} + +func (p parent) Children() fsnode.Iterator { return &iterator{dir: p.dir} } + +func (it *iterator) ensureFetched() error { + if it.fetched { + if it.delegate == nil { + return os.ErrClosed + } + return nil + } + entries, err := ioutil.ReadDir(it.dir) + if err != nil { + return err + } + nodes := make([]fsnode.T, 0, len(entries)) + for _, info := range entries { + fullPath := path.Join(it.dir, info.Name()) + node := newT(info.Name(), info, fullPath) + if node == nil { + continue + } + nodes = append(nodes, node) + } + it.fetched = true + it.delegate = fsnode.NewIterator(nodes...) 
+ return nil +} + +func (it *iterator) Next(ctx context.Context) (fsnode.T, error) { + if err := it.ensureFetched(); err != nil { + return nil, err + } + return it.delegate.Next(ctx) +} + +func (it *iterator) Close(context.Context) error { + it.fetched = true + it.delegate = nil + return nil +} + +type leaf struct { + path string + fsnode.FileInfo +} + +var _ fsnode.Leaf = leaf{} + +func (l leaf) FSNodeT() {} + +func (l leaf) OpenFile(context.Context, int) (fsctx.File, error) { + file, err := os.Open(l.path) + if err != nil { + return nil, err + } + return (*spliceio.OSFile)(file), nil +} diff --git a/file/path.go b/file/path.go index e7d54998..ba118d6d 100644 --- a/file/path.go +++ b/file/path.go @@ -108,22 +108,35 @@ func Dir(path string) string { return path[:len(scheme)+3] } -// Join joins any number of path elements into a single path, adding a separator -// if necessary. It is the same as filepath.Join if elems[0] is a local -// filesystem path. Else, it works like filepath.Join, with the following -// differences: (1) the path separator is always '/'. (3) The each element is -// not cleaned; for example if an element contains repeated "/"s in the middle, -// they are preserved. +// Join joins any number of path elements into a single path, adding a +// separator if necessary. It works like filepath.Join, with the following +// differences: +// 1. The path separator is always '/' (so this doesn't work on Windows). +// 2. The interior of each element is not cleaned; for example if an element +// contains repeated "/"s in the middle, they are preserved. +// 3. If elems[0] has a prefix of the form "://" or "//", that prefix +// is retained. (A prefix of "/" is also retained; that matches +// filepath.Join's behavior.) func Join(elems ...string) string { if len(elems) == 0 { return filepath.Join(elems...) } - scheme, suffix, err := ParsePath(elems[0]) - if scheme == "" || err != nil { - return filepath.Join(elems...) 
+ var prefix string
+ n, err := getURLScheme(elems[0])
+ if err == nil && n > 0 {
+ prefix = elems[0][:n+3]
+ elems[0] = elems[0][n+3:]
+ } else if len(elems[0]) > 0 && elems[0][0] == '/' {
+ if len(elems[0]) > 1 && elems[0][1] == '/' {
+ prefix = "//"
+ elems[0] = elems[0][2:]
+ } else {
+ prefix = "/"
+ elems[0] = elems[0][1:]
+ }
 }
- // Remove leading or trailing "/"s from the string.
+ // Remove leading (optional) or trailing "/"s from the string.
 clean := func(p string) string {
 var s, e int
 for s = 0; s < len(p); s++ {
@@ -136,21 +149,20 @@ func Join(elems ...string) string {
 break
 }
 }
- if e <= s {
+ if e < s {
 return ""
 }
 return p[s : e+1]
 }
 newElems := make([]string, 0, len(elems))
- newElems = append(newElems, scheme+"://"+clean(suffix))
- for i := 1; i < len(elems); i++ {
+ for i := 0; i < len(elems); i++ {
 e := clean(elems[i])
 if e != "" {
 newElems = append(newElems, e)
 }
 }
- return strings.Join(newElems, "/")
+ return prefix + strings.Join(newElems, "/")
 }

// IsAbs returns true if pathname is absolute local path. For non-local file, it
diff --git a/file/path_test.go b/file/path_test.go
new file mode 100644
index 00000000..efb3ad8e
--- /dev/null
+++ b/file/path_test.go
@@ -0,0 +1,70 @@
+package file_test
+
+import (
+ "fmt"
+ "testing"
+
+ "github.com/grailbio/base/file"
+ "github.com/grailbio/testutil/expect"
+)
+
+func TestJoin(t *testing.T) {
+ tests := []struct {
+ elems []string
+ want string
+ }{
+ {
+ []string{"foo/"}, // trailing separator removed from first element.
+ "foo",
+ },
+ {
+ []string{"foo", "bar"}, // join adds separator
+ "foo/bar",
+ },
+ {
+ []string{"foo", "bar/"}, // trailing separator removed from second element.
+ "foo/bar",
+ },
+ {
+ []string{"/foo", "bar"}, // leading separator is retained in first element.
+ "/foo/bar",
+ },
+ {
+ []string{"foo/", "bar"}, // trailing separator removed before join.
+ "foo/bar",
+ },
+ {
+ []string{"foo/", "/bar"}, // all separators removed before join. 
+ "foo/bar", + }, + { + []string{"foo/", "/bar", "baz"}, // all separators removed before join. + "foo/bar/baz", + }, + { + []string{"foo/", "bar", "/baz"}, // all separators removed before join. + "foo/bar/baz", + }, + { + []string{"http://foo/", "/bar"}, // separators inside the element are retained. + "http://foo/bar", + }, + { + []string{"s3://", "bar"}, + "s3://bar", + }, + { + []string{"s3://", "/bar"}, + "s3://bar", + }, + { + []string{"//go/src/grailbio/base/file", "path_test.go"}, + "//go/src/grailbio/base/file/path_test.go", + }, + } + for i, test := range tests { + t.Run(fmt.Sprint(i), func(t *testing.T) { + expect.EQ(t, file.Join(test.elems...), test.want) + }) + } +} diff --git a/file/s3file/awserr.go b/file/s3file/awserr.go new file mode 100644 index 00000000..20d6c72b --- /dev/null +++ b/file/s3file/awserr.go @@ -0,0 +1,102 @@ +package s3file + +import ( + "fmt" + + "github.com/aws/aws-sdk-go/aws/awserr" + awsrequest "github.com/aws/aws-sdk-go/aws/request" + "github.com/aws/aws-sdk-go/service/s3" + "github.com/grailbio/base/errors" +) + +// Annotate interprets err as an AWS request error and returns a version of it +// annotated with severity and kind from the errors package. The optional args +// are passed to errors.E. +func annotate(err error, ids s3RequestIDs, retry *retryPolicy, args ...interface{}) error { + e := func(prefixArgs ...interface{}) error { + msgs := append(prefixArgs, args...) + msgs = append(msgs, "awsrequestID:", ids.String()) + if retry.waitErr != nil { + msgs = append(msgs, fmt.Sprintf("[waitErr=%v]", retry.waitErr)) + } + msgs = append(msgs, fmt.Sprintf("[retries=%d, start=%v]", retry.retries, retry.startTime)) + return errors.E(msgs...) + } + aerr, ok := getAWSError(err) + if !ok { + return e(err) + } + if awsrequest.IsErrorThrottle(err) { + return e(err, errors.Temporary, errors.Unavailable) + } + if awsrequest.IsErrorRetryable(err) { + return e(err, errors.Temporary) + } + // The underlying error was an S3 error. 
Try to classify it.
+ // Best guess based on Amazon's descriptions:
+ switch aerr.Code() {
+ // Code NotFound is not documented, but it's what the API actually returns.
+ case s3.ErrCodeNoSuchBucket, s3.ErrCodeNoSuchKey, "NoSuchVersion", "NotFound":
+ return e(err, errors.NotExist)
+ case awsrequest.CanceledErrorCode:
+ return e(err, errors.Canceled)
+ case "AccessDenied":
+ return e(err, errors.NotAllowed)
+ case "InvalidRequest", "InvalidArgument", "EntityTooSmall", "EntityTooLarge", "KeyTooLong", "MethodNotAllowed":
+ return e(err, errors.Fatal)
+ case "ExpiredToken", "AccountProblem", "ServiceUnavailable", "TokenRefreshRequired", "OperationAborted":
+ return e(err, errors.Unavailable)
+ case "PreconditionFailed":
+ return e(err, errors.Precondition)
+ case "SlowDown":
+ return e(err, errors.Temporary, errors.Unavailable)
+ }
+ return e(err)
+}
+
+func getAWSError(err error) (awsError awserr.Error, found bool) {
+ errors.Visit(err, func(err error) {
+ if err == nil || awsError != nil {
+ return
+ }
+ if e, ok := err.(awserr.Error); ok {
+ found = true
+ awsError = e
+ }
+ })
+ return
+}
+
+type s3RequestIDs struct {
+ amzRequestID string
+ amzID2 string
+}
+
+func (ids s3RequestIDs) String() string {
+ return fmt.Sprintf("x-amz-request-id: %s, x-amz-id-2: %s", ids.amzRequestID, ids.amzID2)
+}
+
+// This is the same as awsrequest.WithGetResponseHeader, except that it doesn't
+// crash when the request fails w/o receiving an HTTP response.
+//
+// TODO(saito) Revert once awsrequest.WithGetResponseHeaders starts acting more
+// gracefully. 
+func withGetResponseHeaderWithNilCheck(key string, val *string) awsrequest.Option { + return func(r *awsrequest.Request) { + r.Handlers.Complete.PushBack(func(req *awsrequest.Request) { + *val = "(no HTTP response)" + if req.HTTPResponse != nil && req.HTTPResponse.Header != nil { + *val = req.HTTPResponse.Header.Get(key) + } + }) + } +} + +func (ids *s3RequestIDs) captureOption() awsrequest.Option { + h0 := withGetResponseHeaderWithNilCheck("x-amz-request-id", &ids.amzRequestID) + h1 := withGetResponseHeaderWithNilCheck("x-amz-id-2", &ids.amzID2) + return func(r *awsrequest.Request) { + h0(r) + h1(r) + } +} diff --git a/file/s3file/bucketcache.go b/file/s3file/bucketcache.go index ee29590d..8f1ac57f 100644 --- a/file/s3file/bucketcache.go +++ b/file/s3file/bucketcache.go @@ -6,77 +6,81 @@ package s3file import ( "context" - "sync" + "fmt" + "time" "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/credentials" + awsrequest "github.com/aws/aws-sdk-go/aws/request" + "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/s3" "github.com/aws/aws-sdk-go/service/s3/s3iface" + "github.com/aws/aws-sdk-go/service/s3/s3manager" + "github.com/grailbio/base/file" + "github.com/grailbio/base/sync/loadingcache" ) -// bucketCache is a singleton cache manager. -type bucketCache struct { - mu sync.Mutex - cache map[string]string // maps S3 bucket to region (e.g., "us-east-2"). +// bucketRegionCacheDuration is chosen fairly arbitrarily. We expect region changes to be +// extremely rare (deleting a bucket, then recreating elsewhere) so a long time seems fine. +const bucketRegionCacheDuration = time.Hour + +// FindBucketRegion locates the AWS region in which bucket is located. +// The lookup is cached internally. +// +// It assumes the region is in the "aws" partition, not other partitions like "aws-us-gov". 
+// See: https://docs.aws.amazon.com/AmazonS3/latest/userguide/UsingBucket.html +func FindBucketRegion(ctx context.Context, bucket string) (string, error) { + return globalBucketRegionCache.locate(ctx, bucket) } -var bCache = bucketCache{ - cache: make(map[string]string), +type bucketRegionCache struct { + cache loadingcache.Map + // getBucketRegionWithClient indirectly references s3manager.GetBucketRegionWithClient to + // allow unit testing. + getBucketRegionWithClient func(ctx aws.Context, svc s3iface.S3API, bucket string, opts ...awsrequest.Option) (string, error) } -// Find finds the region of the bucket using the given s3client. The client need -// not be in the same region as bucket. -func (c *bucketCache) find(ctx context.Context, client s3iface.S3API, bucket string) (string, error) { - c.mu.Lock() - val, ok := c.cache[bucket] - c.mu.Unlock() - if ok { // Common case - return val, nil +var ( + globalBucketRegionCache = bucketRegionCache{ + getBucketRegionWithClient: s3manager.GetBucketRegionWithClient, } - resp, err := client.GetBucketLocationWithContext(ctx, - &s3.GetBucketLocationInput{Bucket: aws.String(bucket)}) + bucketRegionClient = s3.New( + session.Must(session.NewSessionWithOptions(session.Options{ + Config: aws.Config{ + // This client is only used for looking up bucket locations, which doesn't + // require any credentials. + Credentials: credentials.AnonymousCredentials, + // Note: This region is just used to infer the relevant AWS partition (group of + // regions). This would fail for, say, "aws-us-gov", but we only use "aws". + // See: https://docs.aws.amazon.com/AmazonS3/latest/userguide/UsingBucket.html + Region: aws.String("us-west-2"), + }, + SharedConfigState: session.SharedConfigDisable, + })), + ) +) + +func (c *bucketRegionCache) locate(ctx context.Context, bucket string) (string, error) { + var region string + err := c.cache. + GetOrCreate(bucket). 
+ GetOrLoad(ctx, ®ion, func(ctx context.Context, opts *loadingcache.LoadOpts) (err error) { + opts.CacheFor(bucketRegionCacheDuration) + policy := newBackoffPolicy([]s3iface.S3API{bucketRegionClient}, file.Opts{}) + for { + var ids s3RequestIDs + region, err = c.getBucketRegionWithClient(ctx, + bucketRegionClient, bucket, ids.captureOption()) + if err == nil { + return nil + } + if !policy.shouldRetry(ctx, err, fmt.Sprintf("locate region: %s", bucket)) { + return annotate(err, ids, &policy) + } + } + }) if err != nil { return "", err } - // nil location means us-east-1. - // https://docs.aws.amazon.com/AmazonS3/latest/API/RESTBucketGETlocation.html - region := "us-east-1" - if s := aws.StringValue(resp.LocationConstraint); s != "" { - region = s - } - c.mu.Lock() - c.cache[bucket] = region - c.mu.Unlock() return region, nil } - -// Set overrides the region for the provided bucket. -func (c *bucketCache) set(bucket, region string) { - c.mu.Lock() - c.cache[bucket] = region - c.mu.Unlock() -} - -// Invalidate removes the cached bucket-to-region mapping. -func (c *bucketCache) invalidate(bucket string) { - c.mu.Lock() - delete(c.cache, bucket) - c.mu.Unlock() -} - -// GetBucketRegion finds the AWS region for the S3 bucket and inserts it in the -// cache. "client" is used to issue the GetBucketRegion S3 call. It doesn't need -// to be in the region for the "bucket". -func GetBucketRegion(ctx context.Context, client s3iface.S3API, bucket string) (string, error) { - return bCache.find(ctx, client, bucket) -} - -// InvalidateBucketRegion removes the cache entry for bucket, if it exists. -func InvalidateBucketRegion(bucket string) { - bCache.invalidate(bucket) -} - -// SetBucketRegion sets a bucket's region, overriding region discovery and -// defaults. 
-func SetBucketRegion(bucket, region string) { - bCache.set(bucket, region) -} diff --git a/file/s3file/bucketcache_test.go b/file/s3file/bucketcache_test.go index 3e20e6e9..de06eec8 100644 --- a/file/s3file/bucketcache_test.go +++ b/file/s3file/bucketcache_test.go @@ -2,59 +2,55 @@ // Use of this source code is governed by the Apache-2.0 // license that can be found in the LICENSE file. -// +build arc-ignore phabricator-ignore - -package s3file_test +package s3file import ( "context" "flag" - "os" "testing" "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/s3" - "github.com/grailbio/base/file/s3file" + awsrequest "github.com/aws/aws-sdk-go/aws/request" + "github.com/aws/aws-sdk-go/service/s3/s3iface" + "github.com/grailbio/base/errors" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -var ( - manualFlag = flag.Bool("run-manual-test", false, "If true, run tests that access AWS.") -) +var awsFlag = flag.Bool("aws", false, "If true, run tests that access AWS.") -func maybeSkipManualTest(t *testing.T) { - if *manualFlag { - return - } - if os.Getenv("TEST_TMPDIR") == "" { +func TestBucketRegion(t *testing.T) { + if !*awsFlag { + t.Skipf("skipping %s, pass -aws to run", t.Name()) return } - t.Skip("not enabled") -} - -func getBucketRegion(t *testing.T, ctx context.Context, bucket string) string { - sess, err := session.NewSession(&aws.Config{ - MaxRetries: aws.Int(10), - Region: aws.String("us-east-1"), - }) - require.NoError(t, err) - client := s3.New(sess) - region, err := s3file.GetBucketRegion(ctx, client, bucket) - require.NoError(t, err) - return region -} - -func TestBucketRegion(t *testing.T) { - maybeSkipManualTest(t) ctx := context.Background() - region := getBucketRegion(t, ctx, "grail-ysaito") + region := findBucketRegion(t, ctx, "grail-ccga2-evaluation-runs") require.Equal(t, region, "us-west-2") - region = getBucketRegion(t, ctx, "grail-test-us-east-1") + region = 
findBucketRegion(t, ctx, "grail-test-us-east-1") require.Equal(t, region, "us-east-1") - region = getBucketRegion(t, ctx, "grail-test-us-east-2") + region = findBucketRegion(t, ctx, "grail-test-us-east-2") require.Equal(t, region, "us-east-2") } + +func findBucketRegion(t *testing.T, ctx context.Context, bucket string) string { + region, err := FindBucketRegion(ctx, bucket) + require.NoError(t, err) + return region +} + +func TestBucketRegionCancellation(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + _ = <-ctx.Done() + cache := bucketRegionCache{ + getBucketRegionWithClient: func(aws.Context, s3iface.S3API, string, ...awsrequest.Option) (string, error) { + return "", errors.E(errors.Temporary, "test transient error") + }, + } + _, err := cache.locate(ctx, "grail-ccga2-evaluation-runs") + assert.Contains(t, err.Error(), context.Canceled.Error()) +} diff --git a/file/s3file/clientprovider.go b/file/s3file/clientprovider.go deleted file mode 100644 index 94a74fc4..00000000 --- a/file/s3file/clientprovider.go +++ /dev/null @@ -1,138 +0,0 @@ -// Copyright 2018 GRAIL, Inc. All rights reserved. -// Use of this source code is governed by the Apache-2.0 -// license that can be found in the LICENSE file. - -package s3file - -import ( - "context" - "sync" - - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/s3" - "github.com/aws/aws-sdk-go/service/s3/s3iface" - - "github.com/grailbio/base/log" -) - -const ( - defaultRegion = "us-west-2" - defaultMaxRetries = 25 -) - -// ClientProvider is responsible for creating an S3 client object. Get() is -// called whenever s3File needs to access a file. The provider should cache and -// reuse the client objects, if needed. The implementation must be thread safe. -type ClientProvider interface { - // Get returns S3 clients that can be used to perform "op" on "path". 
- // - // "op" is an S3 operation name, without the "s3:" prefix; for example - // "PutObject" or "ListBucket". The full list of operations is defined in - // https://docs.aws.amazon.com/AmazonS3/latest/dev/using-with-s3-actions.html - // - // Path is a full URL of form "s3://bucket/key". This method may be called - // concurrently from multiple threads. - // - // Usually Get() returns one S3 client object on success. If it returns - // multiple clients, the s3 file implementation will try each client in order, - // until the operation succeeds. - // - // REQUIRES: Get returns either >=1 clients, or a non-nil error. - Get(ctx context.Context, op, path string) ([]s3iface.S3API, error) - - // NotifyResult is called to inform that using "client" to perform "op" on - // "path" resulted in the given error (err is nil if the op succeeded). The - // provider should use it to optimize the list of clients to return in Get in - // a future. - // - // Parameter "client" is one of the clients returned by the Get call. - NotifyResult(ctx context.Context, op, path string, client s3iface.S3API, err error) -} - -type regionCache struct { - session *session.Session - clients []s3iface.S3API -} - -// NewDefaultProvider creates a trivial ClientProvider that uses AWS -// session.NewSession() -// (https://docs.aws.amazon.com/sdk-for-go/api/aws/session/). -// -// opts is passed to NewSession. The exception is opts.Config.Region, which will -// be be overwritten to point to the actual bucket location. 
-func NewDefaultProvider(opts session.Options) ClientProvider { - region := defaultRegion - if opts.Config.Region != nil { - region = *opts.Config.Region - } - return &defaultProvider{ - opts: opts, - defaultRegion: region, - regions: make(map[string]*regionCache), - } -} - -type defaultProvider struct { - opts session.Options - defaultRegion string - - mu sync.Mutex - regions map[string]*regionCache - mruRegion *regionCache -} - -// REQUIRES: p.mu is locked -func (p *defaultProvider) getRegion(region string) (*regionCache, error) { - c, ok := p.regions[region] - if !ok { - opts := p.opts - opts.Config.Region = ®ion - s, err := session.NewSessionWithOptions(opts) - if err != nil { - return nil, err - } - client := s3.New(s) - c = ®ionCache{ - session: s, - clients: []s3iface.S3API{client}, - } - p.regions[region] = c - } - p.mruRegion = c - return c, nil -} - -func (p *defaultProvider) getBucketRegion(ctx context.Context, bucket string) string { - p.mu.Lock() - rc := p.mruRegion - if rc == nil { - var err error - if rc, err = p.getRegion(p.defaultRegion); err != nil { - log.Error.Printf("getcketregion: Failed to create client in default region %s: %v", p.defaultRegion, err) - p.mu.Unlock() - return p.defaultRegion - } - } - p.mu.Unlock() - region, err := GetBucketRegion(ctx, rc.clients[0], bucket) - if err != nil { - log.Printf("getbucketregion %s: %v. 
using %v", bucket, err, p.defaultRegion) - return p.defaultRegion - } - return region -} - -func (p *defaultProvider) Get(ctx context.Context, op, path string) ([]s3iface.S3API, error) { - _, bucket, _, err := ParseURL(path) - if err != nil { - return nil, err - } - region := p.getBucketRegion(ctx, bucket) - p.mu.Lock() - c, err := p.getRegion(region) - p.mu.Unlock() - return c.clients, err -} - -func (p *defaultProvider) NotifyResult(ctx context.Context, op, path string, client s3iface.S3API, err error) { -} diff --git a/file/s3file/file.go b/file/s3file/file.go new file mode 100644 index 00000000..a5181307 --- /dev/null +++ b/file/s3file/file.go @@ -0,0 +1,392 @@ +package s3file + +import ( + "context" + "fmt" + "io" + "time" + + "github.com/aws/aws-sdk-go/service/s3/s3iface" + "github.com/grailbio/base/errors" + "github.com/grailbio/base/file" + "github.com/grailbio/base/ioctx" +) + +// s3File implements file.File interface. +// +// Operations on a file are internally implemented by a goroutine running handleRequests, +// which reads requests from s3file.reqCh and sends responses to request.ch. +// +// s3File's API methods (Read, Seek, etc.) are implemented by: +// - Create a chan response. +// - Construct a request{} object describing the operation and send it to reqCh. +// - Wait for a message from either the response channel or context.Done(), +// whichever comes first. +type s3File struct { + name string // "s3://bucket/key/.." + clientsForAction clientsForActionFunc + mode accessMode + opts file.Opts + + bucket string // bucket part of "name". + key string // key part "name". + + // info is file metadata. Set at construction if mode == readonly, otherwise nil. + info *s3Info + + // reqCh transports user operations (like Read) to the worker goroutine (handleRequests). + // This allows respecting context cancellation (regardless of what underlying AWS SDK operations + // do). 
It also guards subsequent fields; they are only accessed by the handleRequests + // goroutine. + reqCh chan request + + // readerState is used for Reader(), which shares state across multiple callers. + readerState + + // Used by files opened for writing. + uploader *s3Uploader +} + +// Name returns the name of the file. +func (f *s3File) Name() string { + return f.name +} + +func (f *s3File) String() string { + return f.name +} + +// s3Info implements file.Info interface. +type s3Info struct { + name string + size int64 + modTime time.Time + etag string // = GetObjectOutput.ETag +} + +func (i *s3Info) Name() string { return i.name } +func (i *s3Info) Size() int64 { return i.size } +func (i *s3Info) ModTime() time.Time { return i.modTime } +func (i *s3Info) ETag() string { return i.etag } + +func (f *s3File) Stat(ctx context.Context) (file.Info, error) { + if f.mode != readonly { + return nil, errors.E(errors.NotSupported, f.name, "stat for writeonly file not supported") + } + if f.info == nil { + panic(f) + } + return f.info, nil +} + +type ( + reader struct { + f *s3File + *readerState + } + readerState struct { + position int64 + bodyReader chunkReaderCache + } + defaultReader struct { + ctx context.Context + f *s3File + } +) + +func (r reader) Read(ctx context.Context, p []byte) (int, error) { + // TODO: Defensively guard against the underlying http body reader not respecting context + // cancellation. Note that the handleRequests mechanism guards against this for its + // operations (in addition to synchronizing), but that's not true here. + // Such defense may be appropriate here, or deeper in the stack. 
+	n, err := r.f.readAt(ctx, &r.bodyReader, p, r.position)
+	r.position += int64(n)
+	return n, err
+}
+
+func (r *readerState) Close(ctx context.Context) error {
+	r.bodyReader.close()
+	return nil
+}
+
+func (f *s3File) OffsetReader(offset int64) ioctx.ReadCloser {
+	return reader{f, &readerState{position: offset}}
+}
+
+func (r defaultReader) Read(p []byte) (int, error) {
+	res := r.f.runRequest(r.ctx, request{
+		reqType: readRequest,
+		buf:     p,
+	})
+	return res.n, res.err
+}
+
+func (r defaultReader) Seek(offset int64, whence int) (int64, error) {
+	res := r.f.runRequest(r.ctx, request{
+		reqType: seekRequest,
+		off:     offset,
+		whence:  whence,
+	})
+	return res.off, res.err
+}
+
+// Reader returns the default reader. There is only one default reader state for the entire file,
+// and all objects returned by Reader share it.
+// TODO: Consider deprecating this in favor of NewReader.
+func (f *s3File) Reader(ctx context.Context) io.ReadSeeker {
+	if f.mode != readonly {
+		return file.NewError(fmt.Errorf("reader %v: file is not opened in read mode", f.name))
+	}
+	return defaultReader{ctx, f}
+}
+
+// s3Writer implements io.Writer for an s3File opened in write mode.
+type s3Writer struct { + ctx context.Context + f *s3File +} + +func (w *s3Writer) Write(p []byte) (n int, err error) { + if len(p) == 0 { + return 0, nil + } + res := w.f.runRequest(w.ctx, request{ + reqType: writeRequest, + buf: p, + }) + return res.n, res.err +} + +func (f *s3File) Writer(ctx context.Context) io.Writer { + if f.mode != writeonly { + return file.NewError(fmt.Errorf("writer %v: file is not opened in write mode", f.name)) + } + return &s3Writer{ctx: ctx, f: f} +} + +func (f *s3File) Close(ctx context.Context) error { + err := f.runRequest(ctx, request{reqType: closeRequest}).err + close(f.reqCh) + return err +} + +func (f *s3File) Discard(ctx context.Context) { + if f.mode != writeonly { + return + } + _ = f.runRequest(ctx, request{reqType: abortRequest}) + close(f.reqCh) +} + +type requestType int + +const ( + seekRequest requestType = iota + readRequest + statRequest + writeRequest + closeRequest + abortRequest +) + +type request struct { + ctx context.Context // context passed to Read, Seek, Close, etc. + reqType requestType + + // For Read and Write + buf []byte + + // For Seek + off int64 + whence int + + // For sending the response + ch chan response +} + +type response struct { + n int // # of bytes read. Set only by Read. + off int64 // Seek location. Set only by Seek. + info *s3Info // Set only by Stat. + signedURL string // Set only by Presign. + err error // Any error + uploader *s3Uploader +} + +func (f *s3File) handleRequests() { + for req := range f.reqCh { + switch req.reqType { + case statRequest: + f.handleStat(req) + case seekRequest: + f.handleSeek(req) + case readRequest: + f.handleRead(req) + case writeRequest: + f.handleWrite(req) + case closeRequest: + f.handleClose(req) + case abortRequest: + f.handleAbort(req) + default: + panic(fmt.Sprintf("Illegal request: %+v", req)) + } + close(req.ch) + } +} + +// Send a request to the handleRequests goroutine and wait for a response. 
The +// caller must set all the necessary fields in req, except ctx and ch, which are +// filled by this method. On ctx timeout or cancellation, returns a response +// with non-nil err field. +func (f *s3File) runRequest(ctx context.Context, req request) response { + resCh := make(chan response, 1) + req.ctx = ctx + req.ch = resCh + f.reqCh <- req + select { + case res := <-resCh: + return res + case <-ctx.Done(): + return response{err: errors.E(errors.Canceled)} + } +} + +func (f *s3File) handleStat(req request) { + ctx := req.ctx + clients, err := f.clientsForAction(ctx, "GetObject", f.bucket, f.key) + if err != nil { + req.ch <- response{err: errors.E(err, fmt.Sprintf("s3file.stat %v", f.name))} + return + } + policy := newBackoffPolicy(clients, f.opts) + info, err := stat(ctx, clients, policy, f.name, f.bucket, f.key) + if err != nil { + req.ch <- response{err: err} + return + } + f.info = info + req.ch <- response{err: nil} +} + +// Seek implements io.Seeker +func (f *s3File) handleSeek(req request) { + if f.info == nil { + panic("stat not filled") + } + var newPosition int64 + switch req.whence { + case io.SeekStart: + newPosition = req.off + case io.SeekCurrent: + newPosition = f.position + req.off + case io.SeekEnd: + newPosition = f.info.size + req.off + default: + req.ch <- response{off: f.position, err: fmt.Errorf("s3file.seek(%s,%d,%d): illegal whence", f.name, req.off, req.whence)} + return + } + if newPosition < 0 { + req.ch <- response{off: f.position, err: fmt.Errorf("s3file.seek(%s,%d,%d): out-of-bounds seek", f.name, req.off, req.whence)} + return + } + if newPosition == f.position { + req.ch <- response{off: f.position} + return + } + f.position = newPosition + req.ch <- response{off: f.position} +} + +func (f *s3File) readAt( + ctx context.Context, + readerCache *chunkReaderCache, + buf []byte, + off int64, +) (int, error) { + if f.mode != readonly { + return 0, errors.E(errors.NotAllowed, "not opened for read") + } + if f.info == nil { + 
panic("stat not filled") + } + + reader, cleanUp, err := readerCache.getOrCreate(ctx, func() (*chunkReaderAt, error) { + clients, err := f.clientsForAction(ctx, "GetObject", f.bucket, f.key) + if err != nil { + return nil, errors.E(err, "getting clients") + } + return &chunkReaderAt{ + name: f.name, + bucket: f.bucket, + key: f.key, + newRetryPolicy: func() retryPolicy { + return newBackoffPolicy(append([]s3iface.S3API{}, clients...), f.opts) + }, + }, nil + }) + if err != nil { + return 0, err + } + defer cleanUp() + + var n int + // Note: We allow seeking past EOF, consistent with io.Seeker.Seek's documentation. We simply + // return EOF in this situation. + if bytesUntilEOF := f.info.size - off; bytesUntilEOF <= 0 { + err = io.EOF + } else { + // Because we know the size of the object, pass a smaller buffer to the + // chunk reader to save it the effort of trying to fill it (with + // parallel reads). This is an optimization that does not affect + // correctness. + // TODO: Consider how to move this optimization into the chunk reader + // itself, possibly by optionally passing in the size/metadata. + if len(buf) > int(bytesUntilEOF) { + buf = buf[:bytesUntilEOF] + } + var info s3Info + n, info, err = reader.ReadAt(ctx, buf, off) + if err != nil && err != io.EOF { + err = errors.E(err, fmt.Sprintf("s3file.read %v", f.name)) + } else if info == (s3Info{}) { + // Maybe EOF or len(req.buf) == 0. + } else if f.info.etag != info.etag { + // Note: If err was io.EOF, we intentionally drop that in favor of flagging ETag mismatch. 
+ err = eTagChangedError(f.name, f.info.etag, info.etag) + } + } + return n, err +} + +func (f *s3File) handleRead(req request) { + n, err := reader{f, &f.readerState}.Read(req.ctx, req.buf) + req.ch <- response{n: n, err: err} +} + +func (f *s3File) handleWrite(req request) { + f.uploader.write(req.buf) + req.ch <- response{n: len(req.buf), err: nil} +} + +func (f *s3File) handleClose(req request) { + var err error + if f.uploader != nil { + err = f.uploader.finish() + } + errors.CleanUpCtx(req.ctx, f.readerState.Close, &err) + if err != nil { + err = errors.E(err, "s3file.close", f.name) + } + f.clientsForAction = nil + req.ch <- response{err: err} +} + +func (f *s3File) handleAbort(req request) { + err := f.uploader.abort() + if err != nil { + err = errors.E(err, "s3file.abort", f.name) + } + f.clientsForAction = nil + req.ch <- response{err: err} +} diff --git a/file/s3file/file_chunk_read.go b/file/s3file/file_chunk_read.go new file mode 100644 index 00000000..09d9183f --- /dev/null +++ b/file/s3file/file_chunk_read.go @@ -0,0 +1,322 @@ +package s3file + +import ( + "context" + "fmt" + "io" + "path/filepath" + "sync" + "sync/atomic" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/service/s3" + "github.com/aws/aws-sdk-go/service/s3/s3iface" + "github.com/grailbio/base/errors" + "github.com/grailbio/base/file/internal/s3bufpool" + "github.com/grailbio/base/file/s3file/internal/autolog" + "github.com/grailbio/base/log" + "github.com/grailbio/base/traverse" +) + +type ( + // chunkReaderAt is similar to ioctx.ReaderAt except it is not concurrency-safe. + // It's currently used to implement S3-recommended read parallelism for large reads, though + // clients of s3file still only see the non-parallel io.Reader API. + // TODO: Expose concurrency-safe ReaderAt API to clients. + chunkReaderAt struct { + // name is redundant with (bucket, key). + name, bucket, key, versionID string + // newRetryPolicy creates retry policies. 
It must be concurrency- and goroutine-safe. + newRetryPolicy func() retryPolicy + + // previousR is a body reader open from a previous ReadAt. It's an optimization for + // clients that do many small reads. It may be nil (before first read, after errors, etc.). + previousR *posReader + // chunks is used locally within ReadAt. It's stored here only to reduce allocations. + chunks []readChunk + } + readChunk struct { + // s3Offset is the position of this *chunk* in the coordinates of the S3 object. + // That is, dst[0] will eventually contain s3Object[s3Offset]. + s3Offset int64 + // dst contains the chunk's data after read. After read, dstN < len(dst) iff there was an + // error or EOF. + dst []byte + // dstN tracks how much of dst is already filled. + dstN int + // r is the current reader for this chunk. It may be nil or at the wrong position for + // this chunk's state; then we'd need a new reader. + r *posReader + } + + // posReader wraps the S3 SDK's reader with retries and remembers its offset in the S3 object. + posReader struct { + rc io.ReadCloser + offset int64 + // ids is set when posReader is opened. + ids s3RequestIDs + // info is set when posReader is opened, unless there's an error or EOF. + info s3Info + } +) + +// ReadChunkBytes is the size for individual S3 API read operations, guided by S3 docs: +// As a general rule, when you download large objects within a Region from Amazon S3 to +// Amazon EC2, we suggest making concurrent requests for byte ranges of an object at the +// granularity of 8–16 MB. +// https://web.archive.org/web/20220325121400/https://docs.aws.amazon.com/AmazonS3/latest/userguide/optimizing-performance-design-patterns.html +func ReadChunkBytes() int { return s3bufpool.BufBytes } + +// ReadAt is not concurrency-safe. +// s3Info may be empty if no object metadata is fetched (zero-sized request, error). 
+func (r *chunkReaderAt) ReadAt(ctx context.Context, dst []byte, offset int64) (int, s3Info, error) { + if len(dst) == 0 { + return 0, s3Info{}, nil + } + r.chunks = r.chunks[:0] + for buf, bufOff := dst, offset; len(buf) > 0; { + size := len(buf) + if size > s3bufpool.BufBytes { + size = s3bufpool.BufBytes + } + r.chunks = append(r.chunks, readChunk{ + s3Offset: bufOff, + dst: buf[:size:size], + }) + bufOff += int64(size) + buf = buf[size:] + } + + // The first chunk gets to try to use a previously-opened reader (best-effort). + // Note: If len(r.chunks) == 1 we're both reusing a saved reader and saving it again. + r.chunks[0].r, r.previousR = r.previousR, nil + defer func() { + r.previousR = r.chunks[len(r.chunks)-1].r + }() + + var ( + infoMu sync.Mutex + info s3Info + ) + // TODO: traverse (or other common lib) support for exiting on first error to reduce latency. + err := traverse.Each(len(r.chunks), func(chunkIdx int) (err error) { + chunk := &r.chunks[chunkIdx] + policy := r.newRetryPolicy() + + defer func() { + if err != nil { + err = annotate(err, chunk.r.maybeIDs(), &policy) + } + }() + // Leave the last chunk's reader open for future reuse. + if chunkIdx < len(r.chunks)-1 { + defer func() { chunk.r.Close(); chunk.r = nil }() + } + + metric := metrics.Op("read").Start() + defer metric.Done() + + attemptLoop: + for attempt := 0; ; attempt++ { + switch err { + case nil: // Initial attempt. + case io.EOF, io.ErrUnexpectedEOF: + // In rare cases the S3 SDK returns EOF for chunks that are not actually at EOF. + // To work around this, we ignore EOF errors, and keep reading as long as the + // object metadata size field says we're not done. See BXDS-2220 for details. 
+ // See also: https://github.com/aws/aws-sdk-go/issues/4510 + default: + if !policy.shouldRetry(ctx, err, r.name) { + break attemptLoop + } + } + err = nil + remainingBuf := chunk.dst[chunk.dstN:] + if len(remainingBuf) == 0 { + break + } + + if attempt > 0 { + metric.Retry() + } + + rangeStart := chunk.s3Offset + int64(chunk.dstN) + switch { + case chunk.r != nil && chunk.r.offset == rangeStart: + // We're ready to read. + case chunk.r != nil: + chunk.r.Close() + fallthrough + default: + chunk.r, err = newPosReader(ctx, policy.client(), r.name, r.bucket, r.key, r.versionID, rangeStart) + if err == io.EOF { + // rangeStart is at or past EOF, so this chunk is done. + err = nil + break attemptLoop + } + if err != nil { + continue + } + } + + var size int64 + infoMu.Lock() + if info == (s3Info{}) { + info = chunk.r.info + } else if info.etag != chunk.r.info.etag { + err = eTagChangedError(r.name, info.etag, chunk.r.info.etag) + } + size = info.size + infoMu.Unlock() + if err != nil { + continue + } + + bytesUntilEOF := size - chunk.s3Offset - int64(chunk.dstN) + if bytesUntilEOF <= 0 { + break + } + if bytesUntilEOF < int64(len(remainingBuf)) { + remainingBuf = remainingBuf[:bytesUntilEOF] + } + var n int + n, err = io.ReadFull(chunk.r, remainingBuf) + chunk.dstN += n + if err == nil { + break + } + // Discard our reader after an error. This error is often due to throttling + // (especially connection reset), so we want to retry with a new HTTP request which + // may go to a new host. 
+			chunk.r.Close()
+			chunk.r = nil
+		}
+		metric.Bytes(chunk.dstN)
+		return err
+	})
+
+	var nBytes int
+	for _, chunk := range r.chunks {
+		nBytes += chunk.dstN
+		if chunk.dstN < len(chunk.dst) {
+			if err == nil {
+				err = io.EOF
+			}
+			break
+		}
+	}
+	return nBytes, info, err
+}
+
+func eTagChangedError(name, oldETag, newETag string) error {
+	return errors.E(errors.Precondition, fmt.Sprintf(
+		"read %v: ETag changed from %v to %v", name, oldETag, newETag))
+}
+
+func (r *chunkReaderAt) Close() { r.previousR.Close() }
+
+var (
+	nOpenPos     int32
+	nOpenPosOnce sync.Once
+)
+
+func newPosReader(
+	ctx context.Context,
+	client s3iface.S3API,
+	name, bucket, key, versionID string,
+	offset int64,
+) (*posReader, error) {
+	nOpenPosOnce.Do(func() {
+		autolog.Register(func() {
+			log.Printf("s3file open posReader: %d", atomic.LoadInt32(&nOpenPos))
+		})
+	})
+	r := posReader{offset: offset}
+	input := s3.GetObjectInput{
+		Bucket: aws.String(bucket),
+		Key:    aws.String(key),
+		Range:  aws.String(fmt.Sprintf("bytes=%d-", r.offset)),
+	}
+	if versionID != "" {
+		input.VersionId = aws.String(versionID)
+	}
+	output, err := client.GetObjectWithContext(ctx, &input, r.ids.captureOption())
+	if err != nil {
+		if output.Body != nil {
+			if errClose := output.Body.Close(); errClose != nil {
+				log.Printf("s3file.newPosReader: ignoring body close error: %v", errClose)
+			}
+		}
+		if awsErr, ok := getAWSError(err); ok && awsErr.Code() == "InvalidRange" {
+			// Since we're reading many chunks in parallel, some can be past the end of
+			// the object, resulting in range errors. Treat these as EOF.
+ err = io.EOF + } + return nil, err + } + _ = atomic.AddInt32(&nOpenPos, 1) + if output.ContentLength == nil || output.ETag == nil || output.LastModified == nil { + return nil, errors.E("s3file.newPosReader: object missing metadata (ContentLength, ETag, LastModified)") + } + if *output.ContentLength < 0 { + // We do not expect AWS to return negative ContentLength, but we are + // defensive, as things may otherwise break very confusingly for + // callers. + return nil, io.EOF + } + r.info = s3Info{ + name: filepath.Base(name), + size: offset + *output.ContentLength, + modTime: *output.LastModified, + etag: *output.ETag, + } + r.rc = output.Body + return &r, nil +} + +// Read usually delegates to the underlying reader, except: (&posReader{}).Read is valid and +// always at EOF; nil.Read panics. +func (p *posReader) Read(dst []byte) (int, error) { + if p.rc == nil { + return 0, io.EOF + } + n, err := p.rc.Read(dst) + p.offset += int64(n) + return n, err +} + +// Close usually delegates to the underlying reader, except: (&posReader{}).Close +// and nil.Close do nothing. +func (p *posReader) Close() { + if p == nil || p.rc == nil { + return + } + _ = atomic.AddInt32(&nOpenPos, -1) + if err := p.rc.Close(); err != nil { + // Note: Since the caller is already done reading from p.rc, we don't expect this error to + // indicate a problem with the correctness of past Reads, instead signaling some resource + // leakage (network connection, buffers, etc.). We can't retry the resource release: + // * io.Closer does not define behavior for multiple Close calls and + // s3.GetObjectOutput.Body doesn't say anything implementation-specific. + // * Body may be a net/http.Response.Body [1] but the standard library doesn't say + // anything about multiple Close either (and even if it did, we shouldn't rely on the + // AWS SDK's implementation details in all cases or in the future). 
+ // Without a retry opportunity, it seems like callers could either ignore the potential + // leak, or exit the OS process. We assume, for now, that callers won't want to do the + // latter, so we hide the error. (This could eventually lead to OS process exit due to + // resource exhaustion, so arguably this hiding doesn't add much harm, though of course it + // may be confusing.) We could consider changing this in the future, especially if we notice + // such resource leaks in real programs. + // + // [1] https://github.com/aws/aws-sdk-go/blob/e842504a6323096540dc3defdc7cb357d8749893/private/protocol/rest/unmarshal.go#L89-L90 + log.Printf("s3file.posReader.Close: ignoring body close error: %v", err) + } +} + +// maybeIDs returns ids if available, otherwise zero. p == nil is allowed. +func (p *posReader) maybeIDs() s3RequestIDs { + if p == nil { + return s3RequestIDs{} + } + return p.ids +} diff --git a/file/s3file/file_write.go b/file/s3file/file_write.go new file mode 100644 index 00000000..bbe6d698 --- /dev/null +++ b/file/s3file/file_write.go @@ -0,0 +1,264 @@ +package s3file + +import ( + "bytes" + "context" + "fmt" + "sort" + "sync" + "time" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/service/s3" + "github.com/aws/aws-sdk-go/service/s3/s3iface" + "github.com/grailbio/base/errors" + "github.com/grailbio/base/file" + "github.com/grailbio/base/log" +) + +// A helper class for driving s3manager.Uploader through an io.Writer-like +// interface. Its write() method will feed data incrementally to the uploader, +// and finish() will wait for all the uploads to finish. +type s3Uploader struct { + ctx context.Context + client s3iface.S3API + path, bucket, key string + opts file.Opts + s3opts Options + uploadID string + createTime time.Time // time of file.Create() call + // curBuf is only accessed by the handleRequests thread. 
+ curBuf *[]byte + nextPartNum int64 + + bufPool sync.Pool + reqCh chan uploadChunk + err errors.Once + sg sync.WaitGroup + mu sync.Mutex + parts []*s3.CompletedPart +} + +type uploadChunk struct { + client s3iface.S3API + uploadID string + partNum int64 + buf *[]byte +} + +const uploadParallelism = 16 + +// UploadPartSize is the size of a chunk during multi-part uploads. It is +// exposed only for unittests. +var UploadPartSize = 16 << 20 + +func newUploader(ctx context.Context, clientsForAction clientsForActionFunc, opts Options, path, bucket, key string, fileOpts file.Opts) (*s3Uploader, error) { + clients, err := clientsForAction(ctx, "PutObject", bucket, key) + if err != nil { + return nil, errors.E(err, "s3file.write", path) + } + params := &s3.CreateMultipartUploadInput{ + Bucket: aws.String(bucket), + Key: aws.String(key), + } + // Add any non-default options + if opts.ServerSideEncryption != "" { + params.SetServerSideEncryption(opts.ServerSideEncryption) + } + + u := &s3Uploader{ + ctx: ctx, + path: path, + bucket: bucket, + key: key, + opts: fileOpts, + s3opts: opts, + createTime: time.Now(), + bufPool: sync.Pool{New: func() interface{} { slice := make([]byte, UploadPartSize); return &slice }}, + nextPartNum: 1, + } + policy := newBackoffPolicy(clients, file.Opts{}) + for { + var ids s3RequestIDs + resp, err := policy.client().CreateMultipartUploadWithContext(ctx, + params, ids.captureOption()) + if policy.shouldRetry(ctx, err, path) { + continue + } + if err != nil { + return nil, annotate(err, ids, &policy, "s3file.CreateMultipartUploadWithContext", path) + } + u.client = policy.client() + u.uploadID = *resp.UploadId + if u.uploadID == "" { + panic(fmt.Sprintf("empty uploadID: %+v, awsrequestID: %v", resp, ids)) + } + break + } + + u.reqCh = make(chan uploadChunk, uploadParallelism) + for i := 0; i < uploadParallelism; i++ { + u.sg.Add(1) + go u.uploadThread() + } + return u, nil +} + +func (u *s3Uploader) uploadThread() { + defer u.sg.Done() + for 
chunk := range u.reqCh { + policy := newBackoffPolicy([]s3iface.S3API{chunk.client}, file.Opts{}) + retry: + params := &s3.UploadPartInput{ + Bucket: aws.String(u.bucket), + Key: aws.String(u.key), + Body: bytes.NewReader(*chunk.buf), + UploadId: aws.String(chunk.uploadID), + PartNumber: &chunk.partNum, + } + var ids s3RequestIDs + resp, err := chunk.client.UploadPartWithContext(u.ctx, params, ids.captureOption()) + if policy.shouldRetry(u.ctx, err, u.path) { + goto retry + } + u.bufPool.Put(chunk.buf) + if err != nil { + u.err.Set(annotate(err, ids, &policy, fmt.Sprintf("s3file.UploadPartWithContext s3://%s/%s", u.bucket, u.key))) + continue + } + partNum := chunk.partNum + completed := &s3.CompletedPart{ETag: resp.ETag, PartNumber: &partNum} + u.mu.Lock() + u.parts = append(u.parts, completed) + u.mu.Unlock() + } +} + +// write appends data to file. It can be called only by the request thread. +func (u *s3Uploader) write(buf []byte) { + if len(buf) == 0 { + panic("empty buf in write") + } + for len(buf) > 0 { + if u.curBuf == nil { + u.curBuf = u.bufPool.Get().(*[]byte) + *u.curBuf = (*u.curBuf)[:0] + } + if cap(*u.curBuf) != UploadPartSize { + panic("empty buf") + } + uploadBuf := *u.curBuf + space := uploadBuf[len(uploadBuf):cap(uploadBuf)] + n := len(buf) + if n < len(space) { + copy(space, buf) + *u.curBuf = uploadBuf[0 : len(uploadBuf)+n] + return + } + copy(space, buf) + buf = buf[len(space):] + *u.curBuf = uploadBuf[0:cap(uploadBuf)] + u.reqCh <- uploadChunk{client: u.client, uploadID: u.uploadID, partNum: u.nextPartNum, buf: u.curBuf} + u.nextPartNum++ + u.curBuf = nil + } +} + +func (u *s3Uploader) abort() error { + policy := newBackoffPolicy([]s3iface.S3API{u.client}, file.Opts{}) + for { + var ids s3RequestIDs + _, err := u.client.AbortMultipartUploadWithContext(u.ctx, &s3.AbortMultipartUploadInput{ + Bucket: aws.String(u.bucket), + Key: aws.String(u.key), + UploadId: aws.String(u.uploadID), + }, ids.captureOption()) + if !policy.shouldRetry(u.ctx, 
err, u.path) { + if err != nil { + err = annotate(err, ids, &policy, fmt.Sprintf("s3file.AbortMultiPartUploadWithContext s3://%s/%s", u.bucket, u.key)) + } + return err + } + } +} + +// finish finishes writing. It can be called only by the request thread. +func (u *s3Uploader) finish() error { + if u.curBuf != nil && len(*u.curBuf) > 0 { + u.reqCh <- uploadChunk{client: u.client, uploadID: u.uploadID, partNum: u.nextPartNum, buf: u.curBuf} + u.curBuf = nil + } + close(u.reqCh) + u.sg.Wait() + policy := newBackoffPolicy([]s3iface.S3API{u.client}, file.Opts{}) + if err := u.err.Err(); err != nil { + u.abort() // nolint: errcheck + return err + } + if len(u.parts) == 0 { + // Special case: an empty file. CompleteMultiPartUpload with empty parts causes an error, + // so work around the bug by issuing a separate PutObject request. + u.abort() // nolint: errcheck + for { + input := &s3.PutObjectInput{ + Bucket: aws.String(u.bucket), + Key: aws.String(u.key), + Body: bytes.NewReader(nil), + } + if u.s3opts.ServerSideEncryption != "" { + input.SetServerSideEncryption(u.s3opts.ServerSideEncryption) + } + + var ids s3RequestIDs + _, err := u.client.PutObjectWithContext(u.ctx, input, ids.captureOption()) + if !policy.shouldRetry(u.ctx, err, u.path) { + if err != nil { + err = annotate(err, ids, &policy, fmt.Sprintf("s3file.PutObjectWithContext s3://%s/%s", u.bucket, u.key)) + } + u.err.Set(err) + break + } + } + return u.err.Err() + } + // Common case. Complete the multi-part upload. + closeStartTime := time.Now() + sort.Slice(u.parts, func(i, j int) bool { // Parts must be sorted in PartNumber order. 
+ return *u.parts[i].PartNumber < *u.parts[j].PartNumber + }) + params := &s3.CompleteMultipartUploadInput{ + Bucket: aws.String(u.bucket), + Key: aws.String(u.key), + UploadId: aws.String(u.uploadID), + MultipartUpload: &s3.CompletedMultipartUpload{Parts: u.parts}, + } + for { + var ids s3RequestIDs + _, err := u.client.CompleteMultipartUploadWithContext(u.ctx, params, ids.captureOption()) + if aerr, ok := getAWSError(err); ok && aerr.Code() == "NoSuchUpload" { + if u.opts.IgnoreNoSuchUpload { + // Here we managed to upload >=1 part, so the uploadID must have been + // valid some point in the past. + // + // TODO(saito) we could check that upload isn't too old (say <= 7 days), + // or that the file actually exists. + log.Error.Printf("close %s: IgnoreNoSuchUpload is set; ignoring %v %+v", u.path, err, ids) + err = nil + } + } + if !policy.shouldRetry(u.ctx, err, u.path) { + if err != nil { + err = annotate(err, ids, &policy, + fmt.Sprintf("s3file.CompleteMultipartUploadWithContext s3://%s/%s, "+ + "created at %v, started closing at %v, failed at %v", + u.bucket, u.key, u.createTime, closeStartTime, time.Now())) + } + u.err.Set(err) + break + } + } + if u.err.Err() != nil { + u.abort() // nolint: errcheck + } + return u.err.Err() +} diff --git a/file/s3file/internal/autolog/autolog.go b/file/s3file/internal/autolog/autolog.go new file mode 100644 index 00000000..8bd9e602 --- /dev/null +++ b/file/s3file/internal/autolog/autolog.go @@ -0,0 +1,26 @@ +package autolog + +import ( + "flag" + "time" +) + +var autologPeriod = flag.Duration("s3file.autolog_period", 0, + "Interval for logging s3transport metrics. Zero disables logging.") + +// Register configures an internal ticker to periodically call logFn. 
+func Register(logFn func()) { + if *autologPeriod == 0 { + return + } + go func() { + ticker := time.NewTicker(*autologPeriod) + defer ticker.Stop() + for { + select { + case _ = <-ticker.C: + logFn() + } + } + }() +} diff --git a/file/s3file/internal/cmd/resolvetest/main.go b/file/s3file/internal/cmd/resolvetest/main.go new file mode 100644 index 00000000..c3b5dcf2 --- /dev/null +++ b/file/s3file/internal/cmd/resolvetest/main.go @@ -0,0 +1,62 @@ +// resolvetest simply resolves a hostname at an increasing time interval to +// observe the diversity in DNS lookup addresses for the host. +// +// This quick experiment is motivated by the S3 performance guide, which +// recommends using multiple clients with different remote IPs: +// +// Finally, it’s worth paying attention to DNS and double-checking that +// requests are being spread over a wide pool of Amazon S3 IP addresses. DNS +// queries for Amazon S3 cycle through a large list of IP endpoints. But +// caching resolvers or application code that reuses a single IP address do +// not benefit from address diversity and the load balancing that follows from it. 
+// +// http://web.archive.org/web/20200624062712/https://docs.aws.amazon.com/AmazonS3/latest/dev/optimizing-performance-design-patterns.html +package main + +import ( + "bufio" + "fmt" + "net" + "os" + "time" + + "github.com/grailbio/base/log" +) + +func main() { + if len(os.Args) > 2 { + log.Fatal("expect 1 argument: hostname to resolve") + } + host := "us-west-2.s3.amazonaws.com" + if len(os.Args) == 2 { + host = os.Args[1] + } + + last := time.Now() + bufOut := bufio.NewWriter(os.Stdout) + for sleepDuration := time.Millisecond; ; { + now := time.Now() + _, _ = fmt.Fprintf(bufOut, "%.6f:\t", now.Sub(last).Seconds()) + last = now + + ips, err := net.LookupIP(host) + if err != nil { + _, _ = bufOut.WriteString(err.Error()) + } else { + for i, ip := range ips { + if i > 0 { + _ = bufOut.WriteByte(' ') + } + _, _ = bufOut.WriteString(ip.String()) + } + } + + _ = bufOut.WriteByte('\n') + _ = bufOut.Flush() + + time.Sleep(sleepDuration) + if sleepDuration < time.Second { + sleepDuration *= 2 + } + } +} diff --git a/file/s3file/list.go b/file/s3file/list.go new file mode 100644 index 00000000..44c5498a --- /dev/null +++ b/file/s3file/list.go @@ -0,0 +1,201 @@ +package s3file + +import ( + "context" + "fmt" + "strings" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/service/s3" + "github.com/grailbio/base/file" + "github.com/grailbio/base/log" +) + +// List implements file.Implementation interface. 
+func (impl *s3Impl) List(ctx context.Context, dir string, recurse bool) file.Lister { + scheme, bucket, key, err := ParseURL(dir) + if err != nil { + return &s3Lister{ctx: ctx, dir: dir, err: err} + } + if bucket == "" { + if recurse { + return &s3Lister{ctx: ctx, dir: dir, + err: fmt.Errorf("list %s: ListBuckets cannot be combined with recurse option", dir)} + } + clients, clientsErr := impl.clientsForAction(ctx, "ListAllMyBuckets", bucket, key) + if clientsErr != nil { + return &s3Lister{ctx: ctx, dir: dir, err: clientsErr} + } + return &s3BucketLister{ + ctx: ctx, + scheme: scheme, + clients: clients, + } + } + clients, err := impl.clientsForAction(ctx, "ListBucket", bucket, key) + if err != nil { + return &s3Lister{ctx: ctx, dir: dir, err: err} + } + return &s3Lister{ + ctx: ctx, + policy: newBackoffPolicy(clients, file.Opts{}), + dir: dir, + scheme: scheme, + bucket: bucket, + prefix: key, + recurse: recurse, + } +} + +type s3Lister struct { + ctx context.Context + policy retryPolicy + dir, scheme, bucket, prefix string + + object s3Obj + objects []s3Obj + token *string + err error + done bool + recurse bool + + // consecutiveEmptyResponses counts how many times S3's ListObjectsV2WithContext returned + // 0 records (either contents or common prefixes) consecutively. + // Many empty responses would cause Scan to appear to hang, so we log a warning. + consecutiveEmptyResponses int +} + +type s3Obj struct { + obj *s3.Object + cp *string +} + +func (o s3Obj) name() string { + if o.obj == nil { + return *o.cp + } + return *o.obj.Key +} + +// Scan implements Lister.Scan +func (l *s3Lister) Scan() bool { + for { + if l.err != nil { + return false + } + l.err = l.ctx.Err() + if l.err != nil { + return false + } + if len(l.objects) > 0 { + l.object, l.objects = l.objects[0], l.objects[1:] + ll := len(l.prefix) + // Ignore keys whose path component isn't exactly equal to l.prefix. 
For + // example, if l.prefix="foo/bar", then we yield "foo/bar" and + // "foo/bar/baz", but not "foo/barbaz". + keyVal := l.object.name() + if ll > 0 && len(keyVal) > ll { + if l.prefix[ll-1] == '/' { + // Treat prefix "foo/bar/" as "foo/bar". + ll-- + } + if keyVal[ll] != '/' { + continue + } + } + return true + } + if l.done { + return false + } + + var prefix string + if l.showDirs() && !strings.HasSuffix(l.prefix, pathSeparator) && l.prefix != "" { + prefix = l.prefix + pathSeparator + } else { + prefix = l.prefix + } + + req := &s3.ListObjectsV2Input{ + Bucket: aws.String(l.bucket), + ContinuationToken: l.token, + Prefix: aws.String(prefix), + } + + if l.showDirs() { + req.Delimiter = aws.String(pathSeparator) + } + var ids s3RequestIDs + res, err := l.policy.client().ListObjectsV2WithContext(l.ctx, req, ids.captureOption()) + if l.policy.shouldRetry(l.ctx, err, l.dir) { + continue + } + if err != nil { + l.err = annotate(err, ids, &l.policy, fmt.Sprintf("s3file.list s3://%s/%s", l.bucket, l.prefix)) + return false + } + l.token = res.NextContinuationToken + nRecords := len(res.Contents) + if l.showDirs() { + nRecords += len(res.CommonPrefixes) + } + if nRecords > 0 { + l.consecutiveEmptyResponses = 0 + } else { + l.consecutiveEmptyResponses++ + if n := l.consecutiveEmptyResponses; n > 7 && n&(n-1) == 0 { + log.Printf("s3file.list.scan: warning: S3 returned empty response %d consecutive times", n) + } + } + l.objects = make([]s3Obj, 0, nRecords) + for _, objVal := range res.Contents { + l.objects = append(l.objects, s3Obj{obj: objVal}) + } + if l.showDirs() { // add the pseudo Dirs + for _, cpVal := range res.CommonPrefixes { + // Follow the Linux convention that directories do not come back with a trailing / + // when read by ListDir. 
To determine it is a directory, it is necessary to + // call implementation.Stat on the path and check IsDir() + pseudoDirName := *cpVal.Prefix + if strings.HasSuffix(pseudoDirName, pathSeparator) { + pseudoDirName = pseudoDirName[:len(pseudoDirName)-1] + } + l.objects = append(l.objects, s3Obj{cp: &pseudoDirName}) + } + } + + l.done = !aws.BoolValue(res.IsTruncated) + } +} + +// Path implements Lister.Path +func (l *s3Lister) Path() string { + return fmt.Sprintf("%s://%s/%s", l.scheme, l.bucket, l.object.name()) +} + +// Info implements Lister.Info +func (l *s3Lister) Info() file.Info { + if obj := l.object.obj; obj != nil { + return &s3Info{ + size: *obj.Size, + modTime: *obj.LastModified, + etag: *obj.ETag, + } + } + return nil +} + +// IsDir implements Lister.IsDir +func (l *s3Lister) IsDir() bool { + return l.object.cp != nil +} + +// Err returns an error, if any. +func (l *s3Lister) Err() error { + return l.err +} + +// showDirs controls whether CommonPrefixes are returned during a scan +func (l *s3Lister) showDirs() bool { + return !l.recurse +} diff --git a/file/s3file/list_bucket.go b/file/s3file/list_bucket.go new file mode 100644 index 00000000..c2e43d5a --- /dev/null +++ b/file/s3file/list_bucket.go @@ -0,0 +1,106 @@ +package s3file + +import ( + "context" + "fmt" + "sort" + "sync" + + "github.com/aws/aws-sdk-go/service/s3" + "github.com/aws/aws-sdk-go/service/s3/s3iface" + "github.com/grailbio/base/errors" + "github.com/grailbio/base/file" + "github.com/grailbio/base/log" + "github.com/grailbio/base/traverse" +) + +type s3BucketLister struct { + ctx context.Context + clients []s3iface.S3API + scheme string + + err error + listed bool + bucket string + buckets []string +} + +func (l *s3BucketLister) Scan() bool { + if !l.listed { + l.buckets, l.err = combineClientBuckets(l.ctx, l.clients) + l.listed = true + } + if l.err != nil || len(l.buckets) == 0 { + return false + } + l.bucket, l.buckets = l.buckets[0], l.buckets[1:] + return true +} + +// 
combineClientBuckets returns the union of buckets from each client, since each may have +// different permissions. +func combineClientBuckets(ctx context.Context, clients []s3iface.S3API) ([]string, error) { + var ( + uniqueBucketsMu sync.Mutex + uniqueBuckets = map[string]struct{}{} + ) + err := traverse.Parallel.Each(len(clients), func(clientIdx int) error { + buckets, err := listClientBuckets(ctx, clients[clientIdx]) + if err != nil { + if errors.Is(errors.NotAllowed, err) { + log.Debug.Printf("s3file.listbuckets: ignoring: %v", err) + return nil + } + return err + } + uniqueBucketsMu.Lock() + defer uniqueBucketsMu.Unlock() + for _, bucket := range buckets { + uniqueBuckets[bucket] = struct{}{} + } + return nil + }) + if err != nil { + return nil, err + } + buckets := make([]string, 0, len(uniqueBuckets)) + for bucket := range uniqueBuckets { + buckets = append(buckets, bucket) + } + sort.Strings(buckets) + return buckets, nil +} + +func listClientBuckets(ctx context.Context, client s3iface.S3API) ([]string, error) { + policy := newBackoffPolicy([]s3iface.S3API{client}, file.Opts{}) + for { + var ids s3RequestIDs + res, err := policy.client().ListBucketsWithContext(ctx, &s3.ListBucketsInput{}, ids.captureOption()) + if policy.shouldRetry(ctx, err, "listbuckets") { + continue + } + if err != nil { + return nil, annotate(err, ids, &policy, "s3file.listbuckets") + } + buckets := make([]string, len(res.Buckets)) + for i, bucket := range res.Buckets { + buckets[i] = *bucket.Name + } + return buckets, nil + } +} + +func (l *s3BucketLister) Path() string { + return fmt.Sprintf("%s://%s", l.scheme, l.bucket) +} + +func (l *s3BucketLister) Info() file.Info { return nil } + +func (l *s3BucketLister) IsDir() bool { + return true +} + +// Err returns an error, if any. 
+func (l *s3BucketLister) Err() error { + return l.err +} diff --git a/file/s3file/list_bucket_test.go b/file/s3file/list_bucket_test.go new file mode 100644 index 00000000..dbedea6a --- /dev/null +++ b/file/s3file/list_bucket_test.go @@ -0,0 +1,68 @@ +package s3file + +import ( + "context" + "testing" + + "github.com/aws/aws-sdk-go/aws" + awsrequest "github.com/aws/aws-sdk-go/aws/request" + "github.com/aws/aws-sdk-go/service/s3" + "github.com/aws/aws-sdk-go/service/s3/s3iface" + "github.com/grailbio/testutil/assert" +) + +func TestS3BucketLister(t *testing.T) { + lister := s3BucketLister{ + ctx: context.Background(), + scheme: "s3", + clients: []s3iface.S3API{ + listBucketsFakeClient{}, + listBucketsFakeClient{ + buckets: []*s3.Bucket{ + {Name: aws.String("bucketA")}, + {Name: aws.String("bucketC")}, + }, + }, + listBucketsFakeClient{ + buckets: []*s3.Bucket{ + {Name: aws.String("bucketC")}, + {Name: aws.String("bucketB")}, + }, + }, + }, + } + + assert.NoError(t, lister.Err()) + + assert.True(t, lister.Scan()) + assert.EQ(t, lister.Path(), "s3://bucketA") // expect alphabetical order + assert.EQ(t, lister.IsDir(), true) + _ = lister.Info() // expect nothing, but it must not panic + assert.NoError(t, lister.Err()) + + assert.True(t, lister.Scan()) + assert.EQ(t, lister.Path(), "s3://bucketB") + assert.EQ(t, lister.IsDir(), true) + _ = lister.Info() + assert.NoError(t, lister.Err()) + + assert.True(t, lister.Scan()) + assert.EQ(t, lister.Path(), "s3://bucketC") + assert.EQ(t, lister.IsDir(), true) + _ = lister.Info() + assert.NoError(t, lister.Err()) + + assert.False(t, lister.Scan()) + assert.NoError(t, lister.Err()) +} + +type listBucketsFakeClient struct { + buckets []*s3.Bucket // stub response + s3iface.S3API // all other methods panic with nil dereference +} + +func (c listBucketsFakeClient) ListBucketsWithContext( + aws.Context, *s3.ListBucketsInput, ...awsrequest.Option, +) (*s3.ListBucketsOutput, error) { + return &s3.ListBucketsOutput{Buckets: c.buckets}, 
nil +} diff --git a/file/s3file/metrics.go b/file/s3file/metrics.go new file mode 100644 index 00000000..ed8f3856 --- /dev/null +++ b/file/s3file/metrics.go @@ -0,0 +1,149 @@ +package s3file + +import ( + "expvar" + "flag" + "fmt" + "io" + "strings" + "sync" + "time" + + "github.com/grailbio/base/log" +) + +var ( + metricAutologOnce sync.Once + metricAutologPeriod = flag.Duration("s3file.metric_log_period", 0, + "Interval for logging S3 operation metrics. Zero disables logging.") +) + +func metricAutolog() { + metricAutologOnce.Do(func() { + if period := *metricAutologPeriod; period > 0 { + go logMetricsLoop(period) + } + }) +} + +type metricOpMap struct{ m sync.Map } + +func (m *metricOpMap) Op(key string) *metricOp { + var init metricOp + got, _ := m.m.LoadOrStore(key, &init) + return got.(*metricOp) +} + +func (m *metricOpMap) VisitAndReset(f func(string, *metricOp)) { + m.m.Range(func(key, value interface{}) bool { + m.m.Delete(key) + f(key.(string), value.(*metricOp)) + return true + }) +} + +var ( + metrics metricOpMap + metricRemoteAddrs expvar.Map +) + +type metricOp struct { + Count expvar.Int + + Retry1 expvar.Int + Retry2 expvar.Int + Retry4 expvar.Int + Retry8 expvar.Int + + DurationFast expvar.Int + Duration1Ms expvar.Int + Duration10Ms expvar.Int + Duration100Ms expvar.Int + Duration1S expvar.Int + Duration10S expvar.Int + Duration100S expvar.Int + + Bytes expvar.Int +} + +type metricOpProgress struct { + parent *metricOp + start time.Time + retries int // == 0 if first try succeeds +} + +func (m *metricOp) Start() *metricOpProgress { + m.Count.Add(1) + return &metricOpProgress{m, time.Now(), 0} +} + +func (m *metricOpProgress) Retry() { m.retries++ } + +func (m *metricOpProgress) Bytes(b int) { m.parent.Bytes.Add(int64(b)) } + +func (m *metricOpProgress) Done() { + switch { + case m.retries >= 8: + m.parent.Retry8.Add(1) + case m.retries >= 4: + m.parent.Retry4.Add(1) + case m.retries >= 2: + m.parent.Retry2.Add(1) + case m.retries >= 1: + 
m.parent.Retry1.Add(1)
+	}
+
+	took := time.Since(m.start)
+	switch {
+	case took > 100*time.Second:
+		m.parent.Duration100S.Add(1)
+	case took > 10*time.Second:
+		m.parent.Duration10S.Add(1)
+	case took > time.Second:
+		m.parent.Duration1S.Add(1)
+	case took > 100*time.Millisecond:
+		m.parent.Duration100Ms.Add(1)
+	case took > 10*time.Millisecond:
+		m.parent.Duration10Ms.Add(1)
+	case took > 1*time.Millisecond:
+		m.parent.Duration1Ms.Add(1)
+	default:
+		m.parent.DurationFast.Add(1)
+	}
+}
+
+func (m *metricOp) Write(w io.Writer, period time.Duration) (int, error) {
+	perMinute := 60 / period.Seconds()
+	return fmt.Fprintf(w, "n:%d r:%d/%d/%d/%d t:%d/%d/%d/%d/%d/%d/%d mib:%d [/min]",
+		int(float64(m.Count.Value())*perMinute),
+		int(float64(m.Retry1.Value())*perMinute),
+		int(float64(m.Retry2.Value())*perMinute),
+		int(float64(m.Retry4.Value())*perMinute),
+		int(float64(m.Retry8.Value())*perMinute),
+		int(float64(m.DurationFast.Value())*perMinute),
+		int(float64(m.Duration1Ms.Value())*perMinute),
+		int(float64(m.Duration10Ms.Value())*perMinute),
+		int(float64(m.Duration100Ms.Value())*perMinute),
+		int(float64(m.Duration1S.Value())*perMinute),
+		int(float64(m.Duration10S.Value())*perMinute),
+		int(float64(m.Duration100S.Value())*perMinute),
+		int(float64(m.Bytes.Value())/(1<<20)*perMinute),
+	)
+}
+
+func logMetricsLoop(period time.Duration) {
+	ticker := time.NewTicker(period)
+	defer ticker.Stop()
+	var buf strings.Builder
+	// Range over the ticker channel directly; the former single-case
+	// select inside a for loop is redundant (gosimple S1000).
+	for range ticker.C {
+		metrics.VisitAndReset(func(op string, metrics *metricOp) {
+			buf.Reset()
+			fmt.Fprintf(&buf, "s3file metrics: op:%s ", op)
+			_, _ = metrics.Write(&buf, period)
+			log.Print(buf.String())
+		})
+	}
+}
diff --git a/file/s3file/retry.go b/file/s3file/retry.go
new file mode 100644
index 00000000..12cfba2f
--- /dev/null
+++ b/file/s3file/retry.go
@@ -0,0 +1,181 @@
+package s3file
+
+import (
+	"context"
+	"strings"
+	"time"
+
+	awsrequest "github.com/aws/aws-sdk-go/aws/request"
+	
"github.com/aws/aws-sdk-go/service/s3" + "github.com/aws/aws-sdk-go/service/s3/s3iface" + "github.com/grailbio/base/file" + "github.com/grailbio/base/log" + "github.com/grailbio/base/retry" +) + +var ( + // BackoffPolicy defines backoff timing parameters. It's exported for unit tests only. + // TODO(josh): Rename to `RetryPolicy`. + // TODO(josh): Create `retry.ThrottlePolicy` and `retry.AIMDPolicy` and use here. + BackoffPolicy = retry.Jitter(retry.Backoff(500*time.Millisecond, time.Minute, 1.2), 0.2) + + // WithDeadline allows faking context.WithDeadline. It's exported for unit tests only. + WithDeadline = context.WithDeadline + + // MaxRetryDuration defines the max amount of time a request can spend + // retrying on errors. + // + // Requirements: + // + // - The value must be >5 minutes. 5 min is the S3 negative-cache TTL. If + // less than 5 minutes, an Open() call w/ RetryWhenNotFound may fail. + // + // - It must be long enough to allow CompleteMultiPartUpload to finish after a + // retry. The doc says it may take a few minutes even in a successful case. + MaxRetryDuration = 60 * time.Minute +) + +// TODO: Rename to `retrier`. +type retryPolicy struct { + clients []s3iface.S3API + policy retry.Policy + opts file.Opts // passed to Open() or Stat request. + startTime time.Time // the time requested started. + retryDeadline time.Time // when to give up retrying. + retries int + waitErr error // error happened during wait, typically deadline or cancellation. +} + +func newBackoffPolicy(clients []s3iface.S3API, opts file.Opts) retryPolicy { + now := time.Now() + return retryPolicy{ + clients: clients, + policy: BackoffPolicy, + opts: opts, + startTime: now, + retryDeadline: now.Add(MaxRetryDuration), + } +} + +// client returns the s3 client to be use by the caller. +func (r *retryPolicy) client() s3iface.S3API { return r.clients[0] } + +// shouldRetry determines if the caller should retry after seeing the given +// error. 
It will modify r.clients if it thinks the caller should retry with a +// different client. +func (r *retryPolicy) shouldRetry(ctx context.Context, err error, message string) bool { + wait := func() bool { + ctx2, cancel := WithDeadline(ctx, r.retryDeadline) + r.waitErr = retry.Wait(ctx2, r.policy, r.retries) + cancel() + if r.waitErr != nil { + // Context timeout or cancellation + r.clients = nil + return false + } + r.retries++ + return true + } + + if err == nil { + return false + } + if awsrequest.IsErrorRetryable(err) || awsrequest.IsErrorThrottle(err) || otherRetriableError(err) { + // Transient errors. Retry with the same client. + log.Printf("retry %s: %v", message, err) + return wait() + } + aerr, ok := getAWSError(err) + if ok { + if r.opts.RetryWhenNotFound && aerr.Code() == s3.ErrCodeNoSuchKey { + log.Printf("retry %s (not found): %v", message, err) + return wait() + } + + switch aerr.Code() { + case s3.ErrCodeNoSuchBucket, s3.ErrCodeNoSuchKey: + // No point in trying again. + r.clients = nil + return false + case "NotFound": + // GetObject seems to return this error rather ErrCodeNoSuchKey + r.clients = nil + return false + default: + // Possible cases: + // + //- permission errors: we retry using a different client. + // + //- non-retriable errors: we retry using a different client, and it will + // fail again, and we eventually give up. The code it at least correct, if + // suboptimal. + // + // - transient errors we don't yet know. We'll abort when we shouldn't, + // but there's not much we can do. We'll add these errors to the above + // case as we discover them. + } + } + if len(r.clients) <= 1 { + // No more alternate clients to try + r.clients = nil + return false + } + r.clients = r.clients[1:] + return true +} + +// Retriable errors not listed in aws' retry policy. 
+func otherRetriableError(err error) bool { + aerr, ok := getAWSError(err) + if ok && (aerr.Code() == awsrequest.ErrCodeSerialization || + aerr.Code() == awsrequest.ErrCodeRead || + // The AWS SDK method IsErrorRetryable doesn't consider certain errors as retryable + // depending on the underlying cause. (For a detailed explanation as to why, + // see https://github.com/aws/aws-sdk-go/issues/3027) + // In our case, we can safely consider every error of type "RequestError" regardless + // of the underlying cause as a retryable error. + aerr.Code() == "RequestError" || + aerr.Code() == "SlowDown" || + aerr.Code() == "InternalError" || + aerr.Code() == "InternalServerError") { + return true + } + if ok && aerr.Code() == "XAmzContentSHA256Mismatch" { + // Example: + // + // XAmzContentSHA256Mismatch: The provided 'x-amz-content-sha256' header + // does not match what was computed. + // + // Happens sporadically for no discernible reason. Just retry. + return true + } + if ok { + msg := strings.TrimSpace(aerr.Message()) + if strings.HasSuffix(msg, "amazonaws.com: no such host") { + // Example: + // + // RequestError: send request failed caused by: Get + // https://grail-patchcnn.s3.us-west-2.amazonaws.com/key: dial tcp: lookup + // grail-patchcnn.s3.us-west-2.amazonaws.com: no such host + // + // This a DNS lookup error on the client side. This may be + // grail-specific. This error happens after S3 server resolves the bucket + // successfully, and redirects the client to a backend to fetch data. So + // accessing a non-existent bucket will not hit this path. + return true + } + } + msg := err.Error() + if strings.Contains(msg, "resource unavailable") || + strings.Contains(msg, "Service Unavailable") || + // As of v1.42.0, the AWS SDK marks these errors as non-retriable [1]. We think we see these + // errors when an S3 host is throttling us so we actually do want to retry. 
+ // Note: Empirically, the s3transport package's workaround reduces the occurrence of these + // errors in our workloads, but we still see them occasionally. + // + // [1] https://github.com/aws/aws-sdk-go/blob/e04cf0432b79324cae8af9e8e333404c18268137/aws/request/connection_reset_error.go#L9 + strings.Contains(msg, "read: connection reset") { + return true + } + return false +} diff --git a/file/s3file/s3file.go b/file/s3file/s3file.go index a0661aa8..34fa1eab 100644 --- a/file/s3file/s3file.go +++ b/file/s3file/s3file.go @@ -6,32 +6,25 @@ package s3file import ( - "bytes" "context" "fmt" - "io" - "path/filepath" - "sort" + "net/http" "strings" - "sync" "time" "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/awserr" + awsrequest "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/service/s3" "github.com/aws/aws-sdk-go/service/s3/s3iface" - "github.com/grailbio/base/errorreporter" + "github.com/grailbio/base/errors" "github.com/grailbio/base/file" - "github.com/grailbio/base/log" - "github.com/pkg/errors" ) -// Path separator used by s3file. -const pathSeparator = "/" - -type maxRetrier interface { - MaxRetries() int -} +const ( + Scheme = "s3" + pathSeparator = "/" + pathPrefix = "s3://" +) // Options defines options that can be given when creating an s3Impl type Options struct { @@ -41,163 +34,15 @@ type Options struct { } type s3Impl struct { - provider ClientProvider - options Options -} - -// s3Info implements file.Info interface. -type s3Info struct { - name string - size int64 - modTime time.Time - etag string // = GetObjectOutput.ETag -} - -type s3Obj struct { - obj *s3.Object - cp *string -} - -type accessMode int - -const ( - readonly accessMode = iota // file is opened by Open. - writeonly // file is opened by Create. - - // TODO(saito) Stop using s3 upload manager. Implement cross-file throttling - // instead. 
- uploadPartSize = 16 << 20 - uploadParallelism = 16 -) - -// Operations on a file are internally implemented by a goroutine running -// handleRequests. Requests to handleRequests are sent through s3File.reqCh. The -// response to a request is sent through request.ch. -// -// The user-facing s3File methods, such as Read and Seek are implemented in the following way: -// -// - Create a chan response. -// -// - Send a request object through s3File.ch. The response channel is included -// in the request. handleRequests() receives the request, handles the request, -// and sends the response. -// -// - Wait for a message from either the response channel or context.Done(), -// whichever comes first. - -type requestType int - -const ( - seekRequest requestType = iota - readRequest - statRequest - writeRequest - closeRequest - abortRequest -) - -type request struct { - ctx context.Context // context passed to Read, Seek, Close, etc. - reqType requestType - - // For Read and Write - buf []byte - - // For Seek - off int64 - whence int - - // For sending the response - ch chan response -} - -type response struct { - n int // # of bytes read. Set only by Read. - off int64 // Seek location. Set only by Seek. - info *s3Info // Set only by Stat. - err error // Any error - uploader *s3Uploader -} - -// s3File implements file.File interface. -type s3File struct { - name string // "s3://bucket/key/.." - provider ClientProvider // Used to create s3 clients. - mode accessMode - - bucket string // bucket part of "name". - key string // key part "name". - - reqCh chan request - - // The following fields are accessed only by the handleRequests thread. - info *s3Info // File metadata. Filled on demand. - - // Active GetObject body reader. Created by a Read() request. Closed on Seek - // or Close call. - bodyReader io.ReadCloser - - // Seek offset. - // INVARIANT: position >= 0 && (position > 0 ⇒ info != nil) - position int64 - - // Used by files opened for writing. 
- uploader *s3Uploader -} - -type s3Lister struct { - ctx context.Context - clients []s3iface.S3API - dir, scheme, bucket, prefix string - - object s3Obj - objects []s3Obj - token *string - err error - done bool - recurse bool -} - -// s3Reader implements io.ReadSeeker for S3. -type s3Reader struct { - ctx context.Context - f *s3File -} - -// s3Reader implements a placeholder io.Writer for S3. -type s3Writer struct { - ctx context.Context - f *s3File -} - -func shouldRetry(path string, err error, clients *[]s3iface.S3API) bool { - if err == nil { - return false - } - if len(*clients) <= 1 { - // No more alternate clients to try - return false - } - if aerr, ok := err.(awserr.Error); ok { - switch aerr.Code() { - case s3.ErrCodeNoSuchBucket, s3.ErrCodeNoSuchKey: - // No point in trying again. - return false - case "NotFound": - // GetObject seems to return this error rather ErrCodeNoSuchKey - return false - default: - // Should retry with a different ticket. - } - } - *clients = (*clients)[1:] - return true + clientsForAction clientsForActionFunc + options Options } // NewImplementation creates a new file.Implementation for S3. The provider is // called to create s3 client objects. -func NewImplementation(provider ClientProvider, opts Options) file.Implementation { - return &s3Impl{provider, opts} +func NewImplementation(provider SessionProvider, opts Options) file.Implementation { + metricAutolog() + return &s3Impl{newClientCache(provider).forAction, opts} } // Run handler in a separate goroutine, then wait for either the handler to @@ -212,11 +57,38 @@ func runRequest(ctx context.Context, handler func() response) response { case res := <-ch: return res case <-ctx.Done(): - return response{err: fmt.Errorf("Request cancelled")} + return response{err: errors.E(errors.Canceled)} } } -func (impl *s3Impl) internalOpen(ctx context.Context, path string, mode accessMode) (file.File, error) { +// String implements a human-readable description. 
+func (impl *s3Impl) String() string { return "s3" } + +// Open opens a file for reading. The provided path should be of form +// "bucket/key..." +func (impl *s3Impl) Open(ctx context.Context, path string, opts ...file.Opts) (file.File, error) { + f, err := impl.internalOpen(ctx, path, readonly, opts...) + res := f.runRequest(ctx, request{reqType: statRequest}) + if res.err != nil { + return nil, res.err + } + return f, err +} + +// Create opens a file for writing. +func (impl *s3Impl) Create(ctx context.Context, path string, opts ...file.Opts) (file.File, error) { + return impl.internalOpen(ctx, path, writeonly, opts...) +} + +type accessMode int + +const ( + readonly accessMode = iota // file is opened by Open. + writeonly // file is opened by Create. +) + +func (impl *s3Impl) internalOpen(ctx context.Context, path string, mode accessMode, optsList ...file.Opts) (*s3File, error) { + opts := mergeFileOpts(optsList) _, bucket, key, err := ParseURL(path) if err != nil { return nil, err @@ -224,7 +96,7 @@ func (impl *s3Impl) internalOpen(ctx context.Context, path string, mode accessMo var uploader *s3Uploader if mode == writeonly { resp := runRequest(ctx, func() response { - u, err := newUploader(ctx, impl.provider, impl.options, path, bucket, key) + u, err := newUploader(ctx, impl.clientsForAction, impl.options, path, bucket, key, opts) return response{uploader: u, err: err} }) if resp.err != nil { @@ -233,727 +105,100 @@ func (impl *s3Impl) internalOpen(ctx context.Context, path string, mode accessMo uploader = resp.uploader } f := &s3File{ - name: path, - mode: mode, - provider: impl.provider, - bucket: bucket, - key: key, - uploader: uploader, - reqCh: make(chan request, 16), + name: path, + mode: mode, + opts: opts, + clientsForAction: impl.clientsForAction, + bucket: bucket, + key: key, + uploader: uploader, + reqCh: make(chan request, 16), } go f.handleRequests() - return f, nil -} - -// Open opens a file for reading. 
The provided path should be of form -// "bucket/key..." -func (impl *s3Impl) Open(ctx context.Context, path string) (file.File, error) { - return impl.internalOpen(ctx, path, readonly) + return f, err } -// Create opens a file for writing. -func (impl *s3Impl) Create(ctx context.Context, path string) (file.File, error) { - return impl.internalOpen(ctx, path, writeonly) -} - -// String implements a human-readable description. -func (impl *s3Impl) String() string { return "s3" } - -// List implements file.Implementation interface. -func (impl *s3Impl) List(ctx context.Context, dir string, recurse bool) file.Lister { - scheme, bucket, key, err := ParseURL(dir) - if err != nil { - return &s3Lister{ctx: ctx, dir: dir, err: err} - } - clients, err := impl.provider.Get(ctx, "ListBucket", dir) - if err != nil { - return &s3Lister{ctx: ctx, dir: dir, err: err} - } - return &s3Lister{ - ctx: ctx, - clients: clients, - dir: dir, - scheme: scheme, - bucket: bucket, - prefix: key, - recurse: recurse, - } -} - -// Stat implements file.Implementation interface. -func (impl *s3Impl) Stat(ctx context.Context, path string) (file.Info, error) { +// Remove implements file.Implementation interface. 
+func (impl *s3Impl) Remove(ctx context.Context, path string) error { resp := runRequest(ctx, func() response { _, bucket, key, err := ParseURL(path) if err != nil { return response{err: err} } - clients, err := impl.provider.Get(ctx, "GetObject", path) + clients, err := impl.clientsForAction(ctx, "DeleteObject", bucket, key) if err != nil { - return response{err: err} + return response{err: errors.E(err, "s3file.remove", path)} } + policy := newBackoffPolicy(clients, file.Opts{}) for { - resp, err := clients[0].HeadObjectWithContext(ctx, &s3.HeadObjectInput{ - Bucket: aws.String(bucket), - Key: aws.String(key), - }) - if shouldRetry(path, err, &clients) { + var ids s3RequestIDs + _, err = policy.client().DeleteObjectWithContext(ctx, &s3.DeleteObjectInput{Bucket: aws.String(bucket), Key: aws.String(key)}, + ids.captureOption()) + if policy.shouldRetry(ctx, err, path) { continue } if err != nil { - return response{err: err} + err = annotate(err, ids, &policy, "s3file.remove", path) } - if *resp.ETag == "" { - return response{err: fmt.Errorf("stat %v: file does not exist", path)} - } - return response{info: &s3Info{ - name: filepath.Base(path), - size: *resp.ContentLength, - modTime: *resp.LastModified, - etag: *resp.ETag, - }} + return response{err: err} } }) - return resp.info, resp.err + return resp.err } -// Remove implements file.Implementation interface. -func (impl *s3Impl) Remove(ctx context.Context, path string) error { +// Presign implements file.Implementation interface. 
+func (impl *s3Impl) Presign(ctx context.Context, path, method string, expiry time.Duration) (string, error) { resp := runRequest(ctx, func() response { _, bucket, key, err := ParseURL(path) if err != nil { return response{err: err} } - clients, err := impl.provider.Get(ctx, "DeleteObject", path) - if err != nil { - return response{err: err} - } - for { - _, err = clients[0].DeleteObjectWithContext(ctx, &s3.DeleteObjectInput{Bucket: aws.String(bucket), Key: aws.String(key)}) - if shouldRetry(path, err, &clients) { - continue + var action string + var getRequestFn func(client s3iface.S3API) *awsrequest.Request + switch method { + case http.MethodGet: + action = "GetObject" + getRequestFn = func(client s3iface.S3API) *awsrequest.Request { + req, _ := client.GetObjectRequest(&s3.GetObjectInput{Bucket: &bucket, Key: &key}) + return req + } + case http.MethodPut: + action = "PutObject" + getRequestFn = func(client s3iface.S3API) *awsrequest.Request { + req, _ := client.PutObjectRequest(&s3.PutObjectInput{Bucket: &bucket, Key: &key}) + return req + } + case http.MethodDelete: + action = "DeleteObject" + getRequestFn = func(client s3iface.S3API) *awsrequest.Request { + req, _ := client.DeleteObjectRequest(&s3.DeleteObjectInput{Bucket: &bucket, Key: &key}) + return req } - return response{err: err} - } - }) - return resp.err -} - -func maxRetries(clients []s3iface.S3API) int { - for _, client := range clients { - if s, ok := client.(maxRetrier); ok && s.MaxRetries() > 0 { - return s.MaxRetries() - } - } - return defaultMaxRetries -} - -func (f *s3File) handleRequests() { - for req := range f.reqCh { - switch req.reqType { - case seekRequest: - f.handleSeek(req) - case readRequest: - f.handleRead(req) - case statRequest: - f.handleStat(req) - case writeRequest: - f.handleWrite(req) - case closeRequest: - f.handleClose(req) - case abortRequest: - f.handleAbort(req) default: - panic(fmt.Sprintf("Illegal request: %+v", req)) - } - close(req.ch) - } -} - -// Name returns the 
name of the file. -func (f *s3File) Name() string { - return f.name -} - -func (f *s3File) Close(ctx context.Context) error { - err := f.runRequest(ctx, request{reqType: closeRequest}).err - close(f.reqCh) - f.provider = nil - return err -} - -func (f *s3File) Discard(ctx context.Context) error { - if f.mode != writeonly { - return fmt.Errorf("discard %v: file is not opened in write mode", f.name) - } - err := f.runRequest(ctx, request{reqType: abortRequest}).err - close(f.reqCh) - f.provider = nil - return err -} - -func (f *s3File) String() string { - return "s3://" + f.name -} - -// Send a request to the handleRequests goroutine and wait for a response. The -// caller must set all the necessary fields in req, except ctx and ch, which are -// filled by this method. On ctx timeout or cancellation, returns a response -// with non-nil err field. -func (f *s3File) runRequest(ctx context.Context, req request) response { - resCh := make(chan response, 1) - req.ctx = ctx - req.ch = resCh - f.reqCh <- req - select { - case res := <-resCh: - return res - case <-ctx.Done(): - return response{err: fmt.Errorf("Request cancelled")} - } -} - -func (f *s3File) Stat(ctx context.Context) (file.Info, error) { - res := f.runRequest(ctx, request{reqType: statRequest}) - if res.err != nil { - return nil, res.err - } - return res.info, nil -} - -func (f *s3File) handleStat(req request) { - if err := f.maybeFillInfo(req.ctx); err != nil { - req.ch <- response{err: err} - return - } - if f.info == nil { - panic(fmt.Sprintf("failed to fill stats in %+v", f)) - } - req.ch <- response{info: f.info} -} - -func newInfo(path string, output *s3.GetObjectOutput) *s3Info { - return &s3Info{ - name: filepath.Base(path), - size: *output.ContentLength, - modTime: *output.LastModified, - etag: *output.ETag, - } -} - -func (f *s3File) maybeFillInfo(ctx context.Context) error { - if f.info != nil { - return nil - } - clients, err := f.provider.Get(ctx, "GetObject", f.name) - if err != nil { - return 
err - } - for { - output, err := clients[0].GetObjectWithContext(ctx, &s3.GetObjectInput{ - Bucket: aws.String(f.bucket), - Key: aws.String(f.key)}) - if shouldRetry(f.name, err, &clients) { - continue + return response{err: errors.E(errors.NotSupported, "s3file.presign: unsupported http method", method)} } + clients, err := impl.clientsForAction(ctx, action, bucket, key) if err != nil { - return err - } - if output.Body == nil { - panic("GetObject with nil Body") - } - output.Body.Close() // nolint: errcheck - if *output.ETag == "" { - return fmt.Errorf("read %v: File does not exist", f.name) - } - f.info = newInfo(f.name, output) - return nil - } -} - -func (f *s3File) Reader(ctx context.Context) io.ReadSeeker { - if f.mode != readonly { - return file.NewErrorReader(fmt.Errorf("reader %v: file is not opened in read mode", f.name)) - } - return &s3Reader{ctx: ctx, f: f} -} - -func (f *s3File) Writer(ctx context.Context) io.Writer { - if f.mode != writeonly { - return file.NewErrorWriter(fmt.Errorf("writer %v: file is not opened in write mode", f.name)) - } - return &s3Writer{ctx: ctx, f: f} -} - -// Seek implements io.Seeker -func (r *s3Reader) Seek(offset int64, whence int) (int64, error) { - res := r.f.runRequest(r.ctx, request{ - reqType: seekRequest, - off: offset, - whence: whence, - }) - return res.off, res.err -} - -// Seek implements io.Seeker -func (f *s3File) handleSeek(req request) { - if err := f.maybeFillInfo(req.ctx); err != nil { - req.ch <- response{off: f.position, err: err} - return - } - var newPosition int64 - switch req.whence { - case io.SeekStart: - newPosition = req.off - case io.SeekCurrent: - newPosition = f.position + req.off - case io.SeekEnd: - newPosition = f.info.size + req.off - default: - req.ch <- response{off: f.position, err: fmt.Errorf("illegal whence: %d", req.whence)} - return - } - if newPosition < 0 { - req.ch <- response{off: f.position, err: fmt.Errorf("out-of-bounds seek")} - return - } - if newPosition == f.position { - 
req.ch <- response{off: f.position} - } - f.position = newPosition - if f.bodyReader != nil { - f.bodyReader.Close() // nolint: errcheck - f.bodyReader = nil - } - req.ch <- response{off: f.position} -} - -// Read implements io.Reader -func (r *s3Reader) Read(p []byte) (n int, err error) { - res := r.f.runRequest(r.ctx, request{ - reqType: readRequest, - buf: p, - }) - return res.n, res.err -} - -func (f *s3File) startGetObjectRequest(ctx context.Context, client s3iface.S3API) error { - if f.bodyReader != nil { - panic("get request still active") - } - input := &s3.GetObjectInput{ - Bucket: aws.String(f.bucket), - Key: aws.String(f.key), - } - if f.position > 0 { - // We either seeked or read before. So f.info must have been set. - if f.info == nil { - panic(fmt.Sprintf("read %v: nil info: %+v", f.name, f)) - } - if f.position >= f.info.size { - return io.EOF + return response{err: err} } - input.Range = aws.String(fmt.Sprintf("bytes=%d-", f.position)) - } - output, err := client.GetObjectWithContext(ctx, input) - if err != nil { - return err - } - if *output.ETag == "" { - output.Body.Close() // nolint: errcheck - return fmt.Errorf("read %v: File does not exist", f.name) - } - if f.info != nil && f.info.etag != *output.ETag { - output.Body.Close() // nolint: errcheck - return fmt.Errorf("read %v: File version changed from %v to %v", f.name, f.info.etag, *output.ETag) - } - f.bodyReader = output.Body // take ownership - if f.info == nil { - f.info = newInfo(f.name, output) - } - return nil -} - -// Read implements io.Reader -func (f *s3File) handleRead(req request) { - buf := req.buf - clients, err := f.provider.Get(req.ctx, "GetObject", f.name) - if err != nil { - req.ch <- response{err: err} - return - } - maxRetries := maxRetries(clients) - retries := 0 - for len(buf) > 0 { - if f.bodyReader == nil { - err = f.startGetObjectRequest(req.ctx, clients[0]) - if shouldRetry(f.name, err, &clients) { + policy := newBackoffPolicy(clients, file.Opts{}) + for { + var ids 
s3RequestIDs + req := getRequestFn(policy.client()) + req.ApplyOptions(ids.captureOption()) + url, err := req.Presign(expiry) + if policy.shouldRetry(ctx, err, path) { continue } if err != nil { - break - } - } - var n int - n, err = f.bodyReader.Read(buf) - if n > 0 { - buf = buf[n:] - f.position += int64(n) - } - if err != nil { - f.bodyReader.Close() // nolint: errcheck - f.bodyReader = nil - if err != io.EOF { - err = errors.WithStack(err) - retries++ - if retries <= maxRetries { - log.Error.Printf("s3read %v: retrying (%d) GetObject after error %v", - f.name, retries, err) - continue - } + return response{err: annotate(err, ids, &policy, fmt.Sprintf("s3file.presign %s", path))} } - break + return response{signedURL: url} } - } - totalBytesRead := len(req.buf) - len(buf) - req.ch <- response{n: totalBytesRead, err: err} -} - -func (f *s3File) handleWrite(req request) { - f.uploader.write(req.buf) - req.ch <- response{n: len(req.buf), err: nil} -} - -func (o s3Obj) name() string { - if o.obj == nil { - return *o.cp - } - return *o.obj.Key -} - -type uploadChunk struct { - client s3iface.S3API - uploadID string - partNum int64 - buf []byte -} - -// A helper class for driving s3manager.Uploader through an io.Writer-like -// interface. Its write() method will feed data incrementally to the uploader, -// and finish() will wait for all the uploads to finish. -type s3Uploader struct { - ctx context.Context - client s3iface.S3API - bucket, key string - uploadID string - - // curBuf is only accessed by the handleRequests thread. 
- curBuf []byte - nextPartNum int64 - - bufPool sync.Pool - reqCh chan uploadChunk - err errorreporter.T - sg sync.WaitGroup - mu sync.Mutex - parts []*s3.CompletedPart -} - -func newUploader(ctx context.Context, provider ClientProvider, opts Options, path, bucket, key string) (*s3Uploader, error) { - clients, err := provider.Get(ctx, "PutObject", path) - if err != nil { - return nil, err - } - params := &s3.CreateMultipartUploadInput{ - Bucket: aws.String(bucket), - Key: aws.String(key), - } - // Add any non-default options - if opts.ServerSideEncryption != "" { - params.SetServerSideEncryption(opts.ServerSideEncryption) - } - - u := &s3Uploader{ - ctx: ctx, - bucket: bucket, - key: key, - bufPool: sync.Pool{New: func() interface{} { return make([]byte, uploadPartSize) }}, - nextPartNum: 1, - } - for { - resp, err := clients[0].CreateMultipartUploadWithContext(ctx, params) - if shouldRetry(path, err, &clients) { - continue - } - if err != nil { - return nil, err - } - u.client = clients[0] - u.uploadID = *resp.UploadId - if u.uploadID == "" { - panic(fmt.Sprintf("empty uploadID: %+v", resp)) - } - break - } - - u.reqCh = make(chan uploadChunk, uploadParallelism) - for i := 0; i < uploadParallelism; i++ { - u.sg.Add(1) - go u.uploadThread() - } - return u, nil -} - -func (u *s3Uploader) uploadThread() { - defer u.sg.Done() - for chunk := range u.reqCh { - params := &s3.UploadPartInput{ - Bucket: aws.String(u.bucket), - Key: aws.String(u.key), - Body: bytes.NewReader(chunk.buf), - UploadId: aws.String(chunk.uploadID), - PartNumber: &chunk.partNum, - } - resp, err := chunk.client.UploadPartWithContext(u.ctx, params) - u.bufPool.Put(chunk.buf) - if err != nil { - u.err.Set(err) - continue - } - partNum := chunk.partNum - completed := &s3.CompletedPart{ETag: resp.ETag, PartNumber: &partNum} - u.mu.Lock() - u.parts = append(u.parts, completed) - u.mu.Unlock() - } -} - -// write appends data to file. It can be called only by the request thread. 
-func (u *s3Uploader) write(buf []byte) { - if len(buf) == 0 { - panic("empty buf in write") - } - for len(buf) > 0 { - if len(u.curBuf) == 0 { - u.curBuf = u.bufPool.Get().([]byte) - u.curBuf = u.curBuf[:0] - } - if cap(u.curBuf) != uploadPartSize { - panic("empty buf") - } - space := u.curBuf[len(u.curBuf):cap(u.curBuf)] - n := len(buf) - if n < len(space) { - copy(space, buf) - u.curBuf = u.curBuf[0 : len(u.curBuf)+n] - return - } - copy(space, buf) - buf = buf[len(space):] - u.curBuf = u.curBuf[0:cap(u.curBuf)] - u.reqCh <- uploadChunk{client: u.client, uploadID: u.uploadID, partNum: u.nextPartNum, buf: u.curBuf} - u.nextPartNum++ - u.curBuf = nil - } -} - -func (u *s3Uploader) abort() error { - _, err := u.client.AbortMultipartUploadWithContext(u.ctx, &s3.AbortMultipartUploadInput{ - Bucket: aws.String(u.bucket), - Key: aws.String(u.key), - UploadId: aws.String(u.uploadID), }) - return err + return resp.signedURL, resp.err } -// finish finishes writing. It can be called only by the request thread. -func (u *s3Uploader) finish() error { - if len(u.curBuf) > 0 { - u.reqCh <- uploadChunk{client: u.client, uploadID: u.uploadID, partNum: u.nextPartNum, buf: u.curBuf} - u.curBuf = nil - } - close(u.reqCh) - u.sg.Wait() - if u.err.Err() == nil { - if len(u.parts) == 0 { - // Special case: an empty file. CompleteMUltiPartUpload with empty parts causes an error, - // so work around the bug by issuing a separate PutObject request. - u.abort() // nolint: errcheck - _, err := u.client.PutObjectWithContext(u.ctx, &s3.PutObjectInput{ - Bucket: aws.String(u.bucket), - Key: aws.String(u.key), - Body: bytes.NewReader(nil), - }) - u.err.Set(err) - } else { - // Parts must be sorted in PartNumber order. 
- sort.Slice(u.parts, func(i, j int) bool { - return *u.parts[i].PartNumber < *u.parts[j].PartNumber - }) - params := &s3.CompleteMultipartUploadInput{ - Bucket: aws.String(u.bucket), - Key: aws.String(u.key), - UploadId: aws.String(u.uploadID), - MultipartUpload: &s3.CompletedMultipartUpload{Parts: u.parts}, - } - _, err := u.client.CompleteMultipartUploadWithContext(u.ctx, params) - u.err.Set(err) - } - } - if u.err.Err() != nil { - u.abort() // nolint: errcheck - } - return u.err.Err() -} - -func (w *s3Writer) Write(p []byte) (n int, err error) { - if len(p) == 0 { - return 0, nil - } - res := w.f.runRequest(w.ctx, request{ - reqType: writeRequest, - buf: p, - }) - return res.n, res.err -} - -func (f *s3File) handleClose(req request) { - var err error - if f.uploader != nil { - err = f.uploader.finish() - } else if f.bodyReader != nil { - if e := f.bodyReader.Close(); e != nil && err == nil { - err = e - } - } - req.ch <- response{err: err} -} - -func (f *s3File) handleAbort(req request) { - err := f.uploader.abort() - req.ch <- response{err: err} -} - -// Scan implements Lister.Scan -func (l *s3Lister) Scan() bool { - for { - if l.err != nil { - return false - } - l.err = l.ctx.Err() - if l.err != nil { - return false - } - if len(l.objects) > 0 { - l.object, l.objects = l.objects[0], l.objects[1:] - ll := len(l.prefix) - // Ignore keys whose path component isn't exactly equal to l.prefix. For - // example, if l.prefix="foo/bar", then we yield "foo/bar" and - // "foo/bar/baz", but not "foo/barbaz". - keyVal := l.object.name() - if ll > 0 && len(keyVal) > ll { - if l.prefix[ll-1] == '/' { - // Treat prefix "foo/bar/" as "foo/bar". 
- ll-- - } - if keyVal[ll] != '/' { - continue - } - } - return true - } - if l.done { - return false - } - - var prefix string - if l.showDirs() && !strings.HasSuffix(l.prefix, pathSeparator) && l.prefix != "" { - prefix = l.prefix + pathSeparator - } else { - prefix = l.prefix - } - - req := &s3.ListObjectsV2Input{ - Bucket: aws.String(l.bucket), - ContinuationToken: l.token, - Prefix: aws.String(prefix), - } - - if l.showDirs() { - req.Delimiter = aws.String(pathSeparator) - } - - res, err := l.clients[0].ListObjectsV2WithContext(l.ctx, req) - if shouldRetry(l.dir, err, &l.clients) { - continue - } - if err != nil { - l.err = err - return false - } - l.token = res.NextContinuationToken - l.objects = make([]s3Obj, 0, len(res.Contents)+len(res.CommonPrefixes)) - for _, objVal := range res.Contents { - l.objects = append(l.objects, s3Obj{obj: objVal}) - } - if l.showDirs() { // add the pseudo Dirs - for _, cpVal := range res.CommonPrefixes { - // Follow the Linux convention that directories do not come back with a trailing / - // when read by ListDir. To determine it is a directory, it is necessary to - // call implementation.Stat on the path and check IsDir() - pseudoDirName := *cpVal.Prefix - if strings.HasSuffix(pseudoDirName, pathSeparator) { - pseudoDirName = pseudoDirName[:len(pseudoDirName)-1] - } - l.objects = append(l.objects, s3Obj{cp: &pseudoDirName}) - } - } - - l.done = len(l.objects) == 0 || !aws.BoolValue(res.IsTruncated) - } -} - -// Path implements Lister.Path -func (l *s3Lister) Path() string { - return fmt.Sprintf("%s://%s/%s", l.scheme, l.bucket, l.object.name()) -} - -// Info implements Lister.Info -func (l *s3Lister) Info() file.Info { - if obj := l.object.obj; obj != nil { - - return &s3Info{ - size: *obj.Size, - modTime: *obj.LastModified, - etag: *obj.ETag, - } - } - return nil -} - -// IsDir implements Lister.IsDir -func (l *s3Lister) IsDir() bool { - return l.object.cp != nil -} - -// Err returns an error, if any. 
-func (l *s3Lister) Err() error { - return l.err -} - -// Object returns the last object that was scanned. -func (l *s3Lister) Object() s3Obj { - return l.object -} - -// showDirs controls whether CommonPrefixes are returned during a scan -func (l *s3Lister) showDirs() bool { - return !l.recurse -} - -func (i *s3Info) Name() string { return i.name } -func (i *s3Info) Size() int64 { return i.size } -func (i *s3Info) ModTime() time.Time { return i.modTime } - // ParseURL parses a path of form "s3://grail-bucket/dir/file" and returns // ("s3", "grail-bucket", "dir/file", nil). func ParseURL(url string) (scheme, bucket, key string, err error) { @@ -968,3 +213,14 @@ func ParseURL(url string) (scheme, bucket, key string, err error) { } return scheme, parts[0], parts[1], nil } + +func mergeFileOpts(opts []file.Opts) (o file.Opts) { + switch len(opts) { + case 0: + case 1: + o = opts[0] + default: + panic(fmt.Sprintf("More than one options specified: %+v", opts)) + } + return +} diff --git a/file/s3file/s3file_test.go b/file/s3file/s3file_test.go index 54a05eb9..2f090136 100644 --- a/file/s3file/s3file_test.go +++ b/file/s3file/s3file_test.go @@ -2,39 +2,55 @@ // Use of this source code is governed by the Apache-2.0 // license that can be found in the LICENSE file. 
-package s3file_test +//go:build !unit +// +build !unit + +package s3file import ( "context" + "crypto/md5" "crypto/sha256" - "errors" "flag" "fmt" "io" "io/ioutil" - "math" "math/rand" + "net/http" + "runtime/debug" "strings" + "sync" + "sync/atomic" "testing" "time" - "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/awserr" + awsrequest "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/service/s3/s3iface" + "github.com/grailbio/base/errors" + "github.com/grailbio/base/file" + "github.com/grailbio/base/file/internal/s3bufpool" "github.com/grailbio/base/file/internal/testutil" - "github.com/grailbio/base/file/s3file" + "github.com/grailbio/base/file/s3file/s3transport" + "github.com/grailbio/base/log" + "github.com/grailbio/base/retry" "github.com/grailbio/testutil/assert" "github.com/grailbio/testutil/s3test" ) var ( s3BucketFlag = flag.String("s3-bucket", "", "If set, run a unittest against a real S3 bucket named in this flag") - profileFlag = flag.String("profile", "default", "If set, use the named profile in ~/.aws") + s3DirFlag = flag.String("s3-dir", "", "S3 directory under -s3-bucket used by some unittests") ) type failingContentAt struct { - prob float64 // probability of failing requests - rand *rand.Rand - content []byte + prob float64 // probability of failing requests + content []byte + failWithErr error + + randMu sync.Mutex + rand *rand.Rand } func doReadAt(src []byte, off64 int64, dest []byte) (int, error) { @@ -62,12 +78,17 @@ func doWriteAt(src []byte, off64 int64, dest *[]byte) (int, error) { } func (c *failingContentAt) ReadAt(p []byte, off64 int64) (int, error) { - if p := c.rand.Float64(); p < c.prob { - return 0, fmt.Errorf("failingContentAt synthetic error") + c.randMu.Lock() + pr := c.rand.Float64() + c.randMu.Unlock() + if pr < c.prob { + return 0, c.failWithErr } n := len(p) if n > 1 { + c.randMu.Lock() n = 1 + c.rand.Intn(n-1) + c.randMu.Unlock() } return 
doReadAt(c.content, off64, p[:n]) } @@ -80,6 +101,10 @@ func (c *failingContentAt) Size() int64 { return int64(len(c.content)) } +func (c *failingContentAt) Checksum() string { + return fmt.Sprintf("%x", md5.Sum(c.content)) +} + type pausingContentAt struct { ready chan bool content []byte @@ -101,41 +126,87 @@ func (c *pausingContentAt) Size() int64 { return int64(len(c.content)) } -type testProvider struct { - clients []s3iface.S3API -} - -func (p *testProvider) Get(ctx context.Context, op, path string) ([]s3iface.S3API, error) { - return p.clients, nil +func (c *pausingContentAt) Checksum() string { + return fmt.Sprintf("%x", md5.Sum(c.content)) } -func (p *testProvider) NotifyResult(ctx context.Context, op, path string, client s3iface.S3API, err error) { +func newImpl(clients ...s3iface.S3API) *s3Impl { + return &s3Impl{ + clientsForAction: func(_ context.Context, _, _, _ string) ([]s3iface.S3API, error) { + return clients, nil + }, + } } func newClient(t *testing.T) *s3test.Client { return s3test.NewClient(t, "b") } -func permErrorClient(t *testing.T) s3iface.S3API { +func errorClient(t *testing.T, err error) s3iface.S3API { c := s3test.NewClient(t, "b") - c.Err = errors.New("test permission error") + c.Err = func(api string, input interface{}) error { + return err + } return c } func TestS3(t *testing.T) { - provider := &testProvider{clients: []s3iface.S3API{permErrorClient(t), newClient(t)}} ctx := context.Background() - impl := s3file.NewImplementation(provider, s3file.Options{}) - testutil.TestAll(ctx, t, impl, "s3://b/dir") + impl := newImpl( + errorClient(t, awserr.New( + "", // TODO(swami): Use an AWS error code that represents a permission error. 
+ "test permission error", + nil, + )), + newClient(t), + ) + testutil.TestStandard(ctx, t, impl, "s3://b/dir") + t.Run("readat", func(t *testing.T) { + testutil.TestConcurrentOffsetReads(ctx, t, impl, "s3://b/dir/readats.txt") + }) } -func TestListBucketRoot(t *testing.T) { - provider := &testProvider{clients: []s3iface.S3API{newClient(t)}} +func TestS3WithRetries(t *testing.T) { + tearDown := setZeroBackoffPolicy() + defer tearDown() + ctx := context.Background() - impl := s3file.NewImplementation(provider, s3file.Options{}) + for iter := 0; iter < 50; iter++ { + randIntsC := make(chan int) + go func() { + r := rand.New(rand.NewSource(int64(iter))) + for { + randIntsC <- r.Intn(20) + } + }() + client := newClient(t) + client.Err = func(api string, input interface{}) error { + switch <-randIntsC { + case 0: + return awserr.New(awsrequest.ErrCodeSerialization, "injected serialization failure", nil) + case 1: + return awserr.New("RequestError", "send request failed", readConnResetError{}) + } + return nil + } + impl := newImpl(client) + testutil.TestStandard(ctx, t, impl, "s3://b/dir") + t.Run("readat", func(t *testing.T) { + testutil.TestConcurrentOffsetReads(ctx, t, impl, "s3://b/dir/readats.txt") + }) + } +} - f, err := impl.Create(ctx, "s3://b/0.txt") +// WriteFile creates a file with the given contents. Path should be of form +// s3://bucket/key. 
+func writeFile(ctx context.Context, t *testing.T, impl file.Implementation, path, data string) { + f, err := impl.Create(ctx, path) assert.NoError(t, err) - _, err = f.Writer(ctx).Write([]byte("data")) + _, err = f.Writer(ctx).Write([]byte(data)) assert.NoError(t, err) assert.NoError(t, f.Close(ctx)) +} +func TestListBucketRoot(t *testing.T) { + ctx := context.Background() + impl := newImpl(newClient(t)) + writeFile(ctx, t, impl, "s3://b/0.txt", "data") l := impl.List(ctx, "s3://b", true) assert.True(t, l.Scan(), "err: %v", l.Err()) @@ -144,10 +215,21 @@ func TestListBucketRoot(t *testing.T) { assert.NoError(t, l.Err()) } +type readConnResetError struct{} + +func (c readConnResetError) Temporary() bool { return false } +func (c readConnResetError) Error() string { return "read: connection reset" } + func TestErrors(t *testing.T) { - provider := &testProvider{clients: []s3iface.S3API{permErrorClient(t)}} ctx := context.Background() - impl := s3file.NewImplementation(provider, s3file.Options{}) + impl := newImpl( + errorClient(t, + awserr.New("", // TODO(swami): Use an AWS error code that represents a permission error. 
+ fmt.Sprintf("test permission error: %s", string(debug.Stack())), + nil, + ), + ), + ) _, err := impl.Create(ctx, "s3://b/junk0.txt") assert.Regexp(t, err, "test permission error") @@ -160,53 +242,131 @@ func TestErrors(t *testing.T) { assert.Regexp(t, l.Err(), "test permission error") } -func TestRetryAfterError(t *testing.T) { +func TestTransientErrors(t *testing.T) { + impl := newImpl(errorClient(t, awserr.New("RequestError", "send request failed", readConnResetError{}))) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + _, err := impl.Stat(ctx, "s3://b/junk0.txt") + assert.True(t, errors.Is(errors.Canceled, err), "expected cancellation") + + ctx, cancel = context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + _, err = impl.Stat(ctx, "s3://b/junk0.txt") + assert.Regexp(t, err, "ran out of time while waiting") +} + +func TestWriteRetryAfterError(t *testing.T) { + tearDown := setZeroBackoffPolicy() + defer tearDown() + client := newClient(t) - setContent := func(path string, prob float64, data string) { - c := &failingContentAt{ - prob: prob, - rand: rand.New(rand.NewSource(0)), - content: []byte(data), + impl := newImpl(client) + ctx := context.Background() + for i := 0; i < 10; i++ { + r := rand.New(rand.NewSource(0)) + client.Err = func(api string, input interface{}) error { + if r.Intn(3) == 0 { + fmt.Printf("write: api %s\n", api) + return awserr.New(awsrequest.ErrCodeSerialization, "test failure", nil) + } + return nil } - checksum := sha256.Sum256(c.content) - client.SetFileContentAt(path, c, fmt.Sprintf("%x", checksum[:])) + writeFile(ctx, t, impl, "s3://b/0.txt", "data") } +} - var contents string - { - l := []string{} - for i := 0; i < 1000; i++ { - l = append(l, fmt.Sprintf("D%d", i)) - } - contents = strings.Join(l, ",") +func TestReadRetryAfterError(t *testing.T) { + for errIdx, failWithErr := range []error{ + fmt.Errorf("failingContentAt synthetic error"), + readConnResetError{}, + } { + 
t.Run(fmt.Sprintf("error_%d", errIdx), func(t *testing.T) { + tearDown := setZeroBackoffPolicy() + defer tearDown() + + client := newClient(t) + setContent := func(path string, prob float64, data string) { + c := &failingContentAt{ + prob: prob, + rand: rand.New(rand.NewSource(0)), + content: []byte(data), + failWithErr: failWithErr, + } + checksum := sha256.Sum256(c.content) + client.SetFileContentAt(path, c, fmt.Sprintf("%x", checksum[:])) + } + + var contents string + { + l := []string{} + for i := 0; i < 1000; i++ { + l = append(l, fmt.Sprintf("D%d", i)) + } + contents = strings.Join(l, ",") + } + // Exercise parallel reading including partial last chunk. + tearDownRCB := setReadChunkBytes() + defer tearDownRCB() + + assert.GT(t, len(contents)%ReadChunkBytes(), 0) + + impl := newImpl(client) + ctx := context.Background() + + setContent("junk0.txt", 0.3, contents) + for i := 0; i < 10; i++ { + f, err := impl.Open(ctx, "b/junk0.txt") + assert.NoError(t, err) + r := f.Reader(ctx) + data, err := ioutil.ReadAll(r) + assert.NoError(t, err) + assert.EQ(t, contents, string(data)) + assert.NoError(t, f.Close(ctx)) + } + + // Simulate exhausting all allowed retries. Since the number of retries is unrestricted, + // the request is capped by MaxRetryDuration. To avoid a flaky time dependency, instead + // of using an actual deadline we just cancel the context. 
+ tearDown = setFakeWithDeadline() + defer tearDown() + setContent("junk1.txt", 1.0 /*fail everything*/, contents) + { + f, err := impl.Open(ctx, "b/junk1.txt") + assert.NoError(t, err) + r := f.Reader(ctx) + _, err = ioutil.ReadAll(r) + assert.Regexp(t, err, failWithErr.Error()) + assert.NoError(t, f.Close(ctx)) + } + }) } +} - provider := &testProvider{clients: []s3iface.S3API{client}} - impl := s3file.NewImplementation(provider, s3file.Options{}) - ctx := context.Background() +func TestRetryWhenNotFound(t *testing.T) { + client := s3test.NewClient(t, "b") - setContent("junk0.txt", 0.3, contents) - for i := 0; i < 10; i++ { - client.NumMaxRetries = math.MaxInt32 - f, err := impl.Open(ctx, "b/junk0.txt") - assert.NoError(t, err) - r := f.Reader(ctx) - data, err := ioutil.ReadAll(r) - assert.NoError(t, err) - assert.EQ(t, contents, string(data)) - assert.NoError(t, f.Close(ctx)) - } + impl := newImpl(client) - setContent("junk1.txt", 1.0 /*fail everything*/, contents) - { - client.NumMaxRetries = 10 - f, err := impl.Open(ctx, "b/junk1.txt") + ctx := context.Background() + // By default, there is no retry. 
+ _, err := impl.Open(ctx, "s3://b/file.txt") + assert.Regexp(t, err, "NoSuchKey") + + doneCh := make(chan bool) + go func() { + _, err := impl.Open(ctx, "s3://b/file.txt", file.Opts{RetryWhenNotFound: true}) assert.NoError(t, err) - r := f.Reader(ctx) - _, err = ioutil.ReadAll(r) - assert.Regexp(t, err, "failingContentAt synthetic error") - assert.NoError(t, f.Close(ctx)) + doneCh <- true + }() + time.Sleep(1 * time.Second) + select { + case <-doneCh: + t.Fatal("should not reach here") + default: } + writeFile(ctx, t, impl, "s3://b/file.txt", "data") + fmt.Println("wrote file") + <-doneCh } func TestCancellation(t *testing.T) { @@ -221,8 +381,7 @@ func TestCancellation(t *testing.T) { c0 := setContent("test0.txt", "hello") _ = setContent("test1.txt", "goodbye") - provider := &testProvider{clients: []s3iface.S3API{client}} - impl := s3file.NewImplementation(provider, s3file.Options{}) + impl := newImpl(client) { c0.ready <- true // Reading c0 completes immediately. @@ -243,30 +402,247 @@ func TestCancellation(t *testing.T) { defer cancel() r := f.Reader(ctx) _, err = ioutil.ReadAll(r) - assert.Regexp(t, err, "Request cancelled") - assert.Regexp(t, f.Close(ctx), "Request cancelled") + assert.True(t, errors.Is(errors.Canceled, err), "expected cancellation") + assert.True(t, errors.Is(errors.Canceled, f.Close(ctx)), "expected cancellation") } } -func TestAWS(t *testing.T) { +func testOverwriteWhileReading(t *testing.T, impl file.Implementation, pathPrefix string) { + ctx := context.Background() + path := pathPrefix + "/test.txt" + writeFile(ctx, t, impl, path, "test0") + f, err := impl.Open(ctx, path) + assert.NoError(t, err) + + r := f.Reader(ctx) + data, err := ioutil.ReadAll(r) + assert.NoError(t, err) + assert.EQ(t, "test0", string(data)) + + _, err = r.Seek(0, io.SeekStart) + assert.NoError(t, err) + + writeFile(ctx, t, impl, path, "test0") + + data, err = ioutil.ReadAll(r) + assert.NoError(t, err) + assert.EQ(t, "test0", string(data)) + + _, err = r.Seek(0, 
io.SeekStart) + assert.NoError(t, err) + writeFile(ctx, t, impl, path, "test1") + _, err = ioutil.ReadAll(r) + assert.True(t, errors.Is(errors.Precondition, err), "err=%v", err) +} + +func TestWriteLargeFile(t *testing.T) { + // Reduce the upload chunk size to issue concurrent upload requests to S3. + oldUploadPartSize := UploadPartSize + UploadPartSize = 128 + defer func() { + UploadPartSize = oldUploadPartSize + }() + + ctx := context.Background() + impl := newImpl(s3test.NewClient(t, "b")) + path := "s3://b/test.txt" + f, err := impl.Create(ctx, path) + assert.NoError(t, err) + r := rand.New(rand.NewSource(0)) + var want []byte + const iters = 400 + for i := 0; i < iters; i++ { + n := r.Intn(1024) + 100 + data := make([]byte, n) + n, err := r.Read(data) + assert.EQ(t, n, len(data)) + assert.NoError(t, err) + n, err = f.Writer(ctx).Write(data) + assert.EQ(t, n, len(data)) + assert.NoError(t, err) + want = append(want, data...) + } + assert.NoError(t, f.Close(ctx)) + + // Read the file back and verify contents. + f, err = impl.Open(ctx, path) + assert.NoError(t, err) + got := make([]byte, len(want)) + n, _ := f.Reader(ctx).Read(got) + assert.EQ(t, n, len(want)) + assert.EQ(t, got, want) + assert.NoError(t, f.Close(ctx)) +} + +func TestOverwriteWhileReading(t *testing.T) { + impl := newImpl(s3test.NewClient(t, "b")) + testOverwriteWhileReading(t, impl, "s3://b/test") +} + +func TestNotExist(t *testing.T) { + impl := newImpl(s3test.NewClient(t, "b")) + ctx := context.Background() + // The s3test client fails tests for requests that attempt to + // access buckets other than the one specified, so we can + // test only missing keys here. + _, err := impl.Open(ctx, "b/notexist") + assert.True(t, errors.Is(errors.NotExist, err)) +} + +func realBucketProviderOrSkip(t *testing.T) SessionProvider { if *s3BucketFlag == "" { t.Skip("Skipping. 
Set -s3-bucket to run the test.") } - provider := s3file.NewDefaultProvider(session.Options{Profile: *profileFlag}) + return NewDefaultProvider( + aws.NewConfig().WithHTTPClient(s3transport.DefaultClient()), + ) +} + +func TestOverwriteWhileReadingAWS(t *testing.T) { + provider := realBucketProviderOrSkip(t) + impl := NewImplementation(provider, Options{}) + testOverwriteWhileReading(t, impl, fmt.Sprintf("s3://%s/tmp/testoverwrite", *s3BucketFlag)) +} + +func TestPresignRequestsAWS(t *testing.T) { + provider := realBucketProviderOrSkip(t) + impl := NewImplementation(provider, Options{}) ctx := context.Background() - impl := s3file.NewImplementation(provider, s3file.Options{}) - testutil.TestAll(ctx, t, impl, "s3://"+*s3BucketFlag+"/tmp") + const content = "file for testing presigned URLs\n" + path := fmt.Sprintf("s3://%s/tmp/testpresigned", *s3BucketFlag) + + // Write the dummy file. + url, err := impl.Presign(ctx, path, "PUT", time.Minute) + if err != nil { + t.Fatal(err) + } + req, err := http.NewRequest(http.MethodPut, url, strings.NewReader(content)) + if err != nil { + t.Fatal(err) + } + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + resp.Body.Close() + + // Read the dummy file. + url, err = impl.Presign(ctx, path, "GET", time.Minute) + if err != nil { + t.Fatal(err) + } + resp, err = http.Get(url) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + respBytes, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + if content != string(respBytes) { + t.Errorf("got: %q, want: %q", string(respBytes), content) + } + + // Delete the dummy file. 
+ url, err = impl.Presign(ctx, path, "DELETE", time.Minute) + if err != nil { + t.Fatal(err) + } + req, err = http.NewRequest(http.MethodDelete, url, strings.NewReader("")) + if err != nil { + t.Fatal(err) + } + resp, err = http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + resp.Body.Close() + if _, err := impl.Stat(ctx, path); !errors.Is(errors.NotExist, err) { + t.Errorf("got: %v\nwant an error of kind NotExist", err) + } +} + +func TestAWS(t *testing.T) { + provider := realBucketProviderOrSkip(t) + ctx := context.Background() + impl := NewImplementation(provider, Options{}) + testutil.TestStandard(ctx, t, impl, "s3://"+*s3BucketFlag+"/tmp") + t.Run("readat", func(t *testing.T) { + testutil.TestConcurrentOffsetReads(ctx, t, impl, "s3://"+*s3BucketFlag+"/tmp") + }) +} + +func TestConcurrentUploadsAWS(t *testing.T) { + provider := realBucketProviderOrSkip(t) + impl := NewImplementation(provider, Options{}) + + if *s3DirFlag == "" { + t.Skip("Skipping. Set -s3-bucket and -s3-dir to run the test.") + } + path := fmt.Sprintf("s3://%s/%s/test.txt", *s3BucketFlag, *s3DirFlag) + ctx := context.Background() + + upload := func() { + f, err := impl.Create(ctx, path, file.Opts{IgnoreNoSuchUpload: true}) + if err != nil { + log.Panic(err) + } + _, err = f.Writer(ctx).Write([]byte("hello")) + if err != nil { + log.Panic(err) + } + if err := f.Close(ctx); err != nil { + log.Panic(err) + } + } + + wg := sync.WaitGroup{} + n := uint64(0) + for i := 0; i < 4000; i++ { + wg.Add(1) + go func() { + upload() + if x := atomic.AddUint64(&n, 1); x%100 == 0 { + log.Printf("%d done", x) + } + wg.Done() + }() + } + wg.Wait() } func ExampleParseURL() { - scheme, bucket, key, err := s3file.ParseURL("s3://grail-bucket/dir/file") + scheme, bucket, key, err := ParseURL("s3://grail-bucket/dir/file") fmt.Printf("scheme: %s, bucket: %s, key: %s, err: %v\n", scheme, bucket, key, err) - scheme, bucket, key, err = s3file.ParseURL("s3://grail-bucket/dir/") + scheme, bucket, key, err = 
ParseURL("s3://grail-bucket/dir/") fmt.Printf("scheme: %s, bucket: %s, key: %s, err: %v\n", scheme, bucket, key, err) - scheme, bucket, key, err = s3file.ParseURL("s3://grail-bucket") + scheme, bucket, key, err = ParseURL("s3://grail-bucket") fmt.Printf("scheme: %s, bucket: %s, key: %s, err: %v\n", scheme, bucket, key, err) // Output: // scheme: s3, bucket: grail-bucket, key: dir/file, err: // scheme: s3, bucket: grail-bucket, key: dir/, err: // scheme: s3, bucket: grail-bucket, key: , err: } + +func setZeroBackoffPolicy() (tearDown func()) { + oldPolicy := BackoffPolicy + BackoffPolicy = retry.Backoff(0, 0, 1.0) + return func() { BackoffPolicy = oldPolicy } +} + +func setReadChunkBytes() (tearDown func()) { + old := s3bufpool.BufBytes + s3bufpool.SetBufSize(100) + return func() { s3bufpool.SetBufSize(old) } +} + +func setFakeWithDeadline() (tearDown func()) { + old := WithDeadline + WithDeadline = func(ctx context.Context, deadline time.Time) (context.Context, context.CancelFunc) { + ctx, cancel := context.WithDeadline(ctx, deadline) + cancel() + return ctx, cancel + } + return func() { WithDeadline = old } +} diff --git a/file/s3file/s3transport/dns.go b/file/s3file/s3transport/dns.go new file mode 100644 index 00000000..46b58f26 --- /dev/null +++ b/file/s3file/s3transport/dns.go @@ -0,0 +1,49 @@ +package s3transport + +import ( + "net" + "sync" + "time" +) + +const dnsCacheTime = 5 * time.Second + +type resolverCacheEntry struct { + result []net.IP + resolvedAt time.Time +} + +type resolver struct { + lookupIP func(host string) ([]net.IP, error) + now func() time.Time + cacheMu sync.Mutex + cache map[string]resolverCacheEntry +} + +func newResolver(lookupIP func(host string) ([]net.IP, error), now func() time.Time) *resolver { + return &resolver{ + lookupIP: lookupIP, + now: now, + cache: map[string]resolverCacheEntry{}, + } +} + +var defaultResolver = newResolver(net.LookupIP, time.Now) + +func (r *resolver) LookupIP(host string) ([]net.IP, error) { + 
r.cacheMu.Lock() + entry, ok := r.cache[host] + r.cacheMu.Unlock() + now := r.now() + if ok && now.Sub(entry.resolvedAt) < dnsCacheTime { + return entry.result, nil + } + ips, err := r.lookupIP(host) + if err != nil { + return nil, err + } + r.cacheMu.Lock() + r.cache[host] = resolverCacheEntry{ips, now} + r.cacheMu.Unlock() + return ips, nil +} diff --git a/file/s3file/s3transport/dns_test.go b/file/s3file/s3transport/dns_test.go new file mode 100644 index 00000000..8be27291 --- /dev/null +++ b/file/s3file/s3transport/dns_test.go @@ -0,0 +1,53 @@ +package s3transport + +import ( + "fmt" + "net" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestResolver(t *testing.T) { + var ( + gotHost string + stubIP []net.IP + stubError error + stubNow time.Time + ) + stubLookupIP := func(host string) ([]net.IP, error) { + gotHost = host + return stubIP, stubError + } + r := newResolver(stubLookupIP, func() time.Time { return stubNow }) + + stubIP, stubError = []net.IP{{1, 2, 3, 4}, {10, 20, 30, 40}}, nil + stubNow = time.Unix(1600000000, 0) + gotIP, gotError := r.LookupIP("s3.example.com") + assert.Equal(t, "s3.example.com", gotHost) + assert.NoError(t, gotError) + assert.Equal(t, []net.IP{{1, 2, 3, 4}, {10, 20, 30, 40}}, gotIP) + + stubIP, stubError = nil, fmt.Errorf("stub err") + stubNow = stubNow.Add(dnsCacheTime - 1) + gotHost = "should not be called" + gotIP, gotError = r.LookupIP("s3.example.com") + assert.Equal(t, "should not be called", gotHost) + assert.NoError(t, gotError) + assert.Equal(t, []net.IP{{1, 2, 3, 4}, {10, 20, 30, 40}}, gotIP) + + stubIP, stubError = []net.IP{{5, 6, 7, 8}}, nil + gotIP, gotError = r.LookupIP("s3-us-west-2.example.com") + assert.Equal(t, "s3-us-west-2.example.com", gotHost) + assert.NoError(t, gotError) + assert.Equal(t, []net.IP{{5, 6, 7, 8}}, gotIP) + + stubIP, stubError = []net.IP{{21, 22, 23, 24}}, nil + gotHost = "" + stubNow = stubNow.Add(2) + gotIP, gotError = r.LookupIP("s3.example.com") + assert.Equal(t, 
"s3.example.com", gotHost) + assert.NoError(t, gotError) + assert.Equal(t, []net.IP{{21, 22, 23, 24}}, gotIP) +} diff --git a/file/s3file/s3transport/expiring_map.go b/file/s3file/s3transport/expiring_map.go new file mode 100644 index 00000000..85d49e97 --- /dev/null +++ b/file/s3file/s3transport/expiring_map.go @@ -0,0 +1,109 @@ +package s3transport + +import ( + "net" + "sync" + "time" + + "github.com/grailbio/base/file/s3file/internal/autolog" + "github.com/grailbio/base/log" +) + +const ( + // expireAfter balances saving seen IPs to distribute ongoing load vs. tying up resources + // for a long time. Given that DNS provides new S3 IP addresses every few seconds, retaining + // for an hour means I/O intensive batch jobs can maintain hundreds of S3 peers. But, an API server + // with weeks of uptime won't accrete huge numbers of old records. + expireAfter = time.Hour + // expireLoopEvery controls how frequently the expireAfter threshold is tested, so it controls + // "slack" in expireAfter. The loop takes locks that block requests, so it should not be too + // frequent (relative to request rate). + expireLoopEvery = time.Minute +) + +type expiringMap struct { + now func() time.Time + + mu sync.Mutex + // elems is URL host -> string(net.IP) -> last seen. 
+ elems map[string]map[string]time.Time +} + +func newExpiringMap(runPeriodic runPeriodic, now func() time.Time) *expiringMap { + s := expiringMap{now: now, elems: map[string]map[string]time.Time{}} + go runPeriodic(expireLoopEvery, s.expireOnce) + autolog.Register(s.logOnce) + return &s +} + +func (s *expiringMap) AddAndGet(host string, newIPs []net.IP) (allIPs []net.IP) { + now := s.now() + s.mu.Lock() + defer s.mu.Unlock() + ips, ok := s.elems[host] + if !ok { + ips = map[string]time.Time{} + s.elems[host] = ips + } + for _, ip := range newIPs { + ips[string(ip)] = now + } + for ip := range ips { + allIPs = append(allIPs, net.IP(ip)) + } + return +} + +func (s *expiringMap) expireOnce(now time.Time) { + earliestUnexpiredTime := now.Add(-expireAfter) + s.mu.Lock() + for host, ips := range s.elems { + deleteBefore(ips, earliestUnexpiredTime) + if len(ips) == 0 { + delete(s.elems, host) + } + } + s.mu.Unlock() +} + +func deleteBefore(times map[string]time.Time, threshold time.Time) { + for key, time := range times { + if time.Before(threshold) { + delete(times, key) + } + } +} + +func (s *expiringMap) logOnce() { + s.mu.Lock() + var ( + hosts = len(s.elems) + ips, hostIPMax int + ) + for _, e := range s.elems { + ips += len(e) + if len(e) > hostIPMax { + hostIPMax = len(e) + } + } + s.mu.Unlock() + log.Printf("s3file transport: hosts:%d ips:%d hostipmax:%d", hosts, ips, hostIPMax) +} + +// runPeriodic runs the given func with the given period. 
+type runPeriodic func(time.Duration, func(time.Time)) + +func runPeriodicForever() runPeriodic { + return func(period time.Duration, tick func(time.Time)) { + ticker := time.NewTicker(period) + defer ticker.Stop() + for { + select { + case now := <-ticker.C: + tick(now) + } + } + } +} + +func noOpRunPeriodic(time.Duration, func(time.Time)) {} diff --git a/file/s3file/s3transport/expiring_map_test.go b/file/s3file/s3transport/expiring_map_test.go new file mode 100644 index 00000000..975ba0b9 --- /dev/null +++ b/file/s3file/s3transport/expiring_map_test.go @@ -0,0 +1,40 @@ +package s3transport + +import ( + "net" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestExpiringMap(t *testing.T) { + ips := func(is ...byte) (ret []net.IP) { + for _, i := range is { + ret = append(ret, net.IP{i, i, i, i}) + } + return + } + var stubNow time.Time + + m := newExpiringMap(noOpRunPeriodic, func() time.Time { return stubNow }) + + stubNow = time.Unix(1600000000, 0) + assert.ElementsMatch(t, ips(0, 1), m.AddAndGet("s3.example.com", ips(0, 1))) + + stubNow = stubNow.Add(expireAfter / 2) + assert.ElementsMatch(t, ips(0, 1), m.AddAndGet("s3.example.com", ips(0))) + assert.ElementsMatch(t, ips(0, 1, 3), m.AddAndGet("s3.example.com", ips(3))) + + stubNow = stubNow.Add(expireAfter/2 + 2) + m.expireOnce(stubNow) // Drop ips(1). + assert.ElementsMatch(t, ips(0, 3, 4), m.AddAndGet("s3.example.com", ips(4))) + assert.ElementsMatch(t, ips(100), m.AddAndGet("s3-2.example.com", ips(100))) + + stubNow = stubNow.Add(expireAfter/2 + 2) + m.expireOnce(stubNow) // Drop ips(0, 3). + assert.ElementsMatch(t, ips(4), m.AddAndGet("s3.example.com", nil)) + assert.ElementsMatch(t, ips(100), m.AddAndGet("s3-2.example.com", nil)) + + m.logOnce() // No assertions other than it shouldn't panic. 
+} diff --git a/file/s3file/s3transport/transport.go b/file/s3file/s3transport/transport.go new file mode 100644 index 00000000..ad8819b8 --- /dev/null +++ b/file/s3file/s3transport/transport.go @@ -0,0 +1,153 @@ +package s3transport + +import ( + "crypto/tls" + "fmt" + "io" + "math/rand" + "net" + "net/http" + "sort" + "sync" + "time" + + "github.com/grailbio/base/file/s3file/internal/autolog" + "github.com/grailbio/base/log" +) + +// T is an http.RoundTripper specialized for S3. See https://github.com/aws/aws-sdk-go/issues/3739. +type T struct { + factory func() *http.Transport + + hostRTsMu sync.Mutex + hostRTs map[string]http.RoundTripper + + nOpenConnsPerIPMu sync.Mutex + nOpenConnsPerIP map[string]int + + hostIPs *expiringMap +} + +var ( + stdDefaultTransport = http.DefaultTransport.(*http.Transport) + httpTransport = &http.Transport{ + DialContext: (&net.Dialer{ + Timeout: 30 * time.Second, // Copied from http.DefaultTransport. + KeepAlive: 30 * time.Second, // Copied from same. + }).DialContext, + ForceAttemptHTTP2: false, // S3 doesn't support HTTP2. + MaxIdleConns: 200, // Keep many peers for future bursts. + MaxIdleConnsPerHost: 4, // But limit connections to each. + IdleConnTimeout: expireAfter + 2*expireLoopEvery, // Keep until we forget the peer. + TLSClientConfig: &tls.Config{}, + TLSHandshakeTimeout: stdDefaultTransport.TLSHandshakeTimeout, + ExpectContinueTimeout: stdDefaultTransport.ExpectContinueTimeout, + } + + defaultOnce sync.Once + defaultT *T + defaultClient *http.Client +) + +func defaults() (*T, *http.Client) { + defaultOnce.Do(func() { + defaultT = New(httpTransport.Clone) + defaultClient = &http.Client{Transport: defaultT} + }) + return defaultT, defaultClient +} + +// Default returns an http.RoundTripper with recommended settings. 
+func Default() *T { t, _ := defaults(); return t } + +// DefaultClient returns an *http.Client that uses the http.RoundTripper +// returned by Default (suitable for general use, analogous to +// "net/http".DefaultClient). +func DefaultClient() *http.Client { _, c := defaults(); return c } + +// New constructs *T using factory to create internal transports. Each call to factory() +// must return a separate http.Transport and they must not share TLSClientConfig. +func New(factory func() *http.Transport) *T { + t := T{ + factory: factory, + hostRTs: map[string]http.RoundTripper{}, + hostIPs: newExpiringMap(runPeriodicForever(), time.Now), + nOpenConnsPerIP: map[string]int{}, + } + autolog.Register(func() { + var nOpen []int + t.nOpenConnsPerIPMu.Lock() + for _, n := range t.nOpenConnsPerIP { + nOpen = append(nOpen, n) + } + t.nOpenConnsPerIPMu.Unlock() + sort.Sort(sort.Reverse(sort.IntSlice(nOpen))) + log.Printf("s3file transport: open RTs per IP: %v", nOpen) + }) + return &t +} + +func (t *T) RoundTrip(req *http.Request) (*http.Response, error) { + host := req.URL.Hostname() + + ips, err := defaultResolver.LookupIP(host) + if err != nil { + if req.Body != nil { + _ = req.Body.Close() + } + return nil, fmt.Errorf("s3transport: lookup ip: %w", err) + } + ips = t.hostIPs.AddAndGet(host, ips) + + hostReq := req.Clone(req.Context()) + hostReq.Host = host + // TODO: Consider other load balancing strategies. 
+ ip := ips[rand.Intn(len(ips))].String() + hostReq.URL.Host = ip + + hostRT := t.hostRoundTripper(host) + resp, err := hostRT.RoundTrip(hostReq) + if resp != nil { + t.addOpenConnsPerIP(ip, 1) + resp.Body = &rcOnClose{resp.Body, func() { t.addOpenConnsPerIP(ip, -1) }} + } + return resp, err +} + +func (t *T) hostRoundTripper(host string) http.RoundTripper { + t.hostRTsMu.Lock() + defer t.hostRTsMu.Unlock() + if rt, ok := t.hostRTs[host]; ok { + return rt + } + transport := t.factory() + // We modify request URL to contain an IP, but server certificates list hostnames, so we + // configure our client to check against original hostname. + if transport.TLSClientConfig == nil { + transport.TLSClientConfig = &tls.Config{} + } + transport.TLSClientConfig.ServerName = host + t.hostRTs[host] = transport + return transport +} + +func (t *T) addOpenConnsPerIP(ip string, add int) { + t.nOpenConnsPerIPMu.Lock() + t.nOpenConnsPerIP[ip] += add + t.nOpenConnsPerIPMu.Unlock() +} + +type rcOnClose struct { + io.ReadCloser + onClose func() +} + +func (r *rcOnClose) Close() error { + // In rare cases, this Close() is called a second time, with a call stack from the AWS SDK's + // cleanup code. + if r.onClose != nil { + defer r.onClose() + } + r.onClose = nil + return r.ReadCloser.Close() +} diff --git a/file/s3file/s3transport/transport_test.go b/file/s3file/s3transport/transport_test.go new file mode 100644 index 00000000..571172a3 --- /dev/null +++ b/file/s3file/s3transport/transport_test.go @@ -0,0 +1,2 @@ +// s3transport is exercised in s3file's *AWS integration tests. +package s3transport diff --git a/file/s3file/session_provider.go b/file/s3file/session_provider.go new file mode 100644 index 00000000..bc794927 --- /dev/null +++ b/file/s3file/session_provider.go @@ -0,0 +1,172 @@ +// Copyright 2018 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache-2.0 +// license that can be found in the LICENSE file. 
+ +package s3file + +import ( + "context" + "fmt" + "runtime" + "sync" + "sync/atomic" + "time" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/s3" + "github.com/aws/aws-sdk-go/service/s3/s3iface" + "github.com/grailbio/base/errors" +) + +const ( + defaultRegion = "us-west-2" + clientCacheGarbageCollectionInterval = 10 * time.Minute +) + +type ( + // SessionProvider provides Sessions for making AWS API calls. Get() is called whenever s3file + // needs to access a file. The provider should cache and reuse the sessions, if needed. + // The implementation must be thread safe. + SessionProvider interface { + // Get returns AWS sessions that can be used to perform in.S3IAMAction on + // s3://{in.bucket}/{in.key}. + // + // s3file maintains an internal cache keyed by *session.Session that is only pruned + // occasionally. Get() is called for every S3 operation so it should be very fast. Caching + // (that is, reusing *session.Session whenever possible) is strongly encouraged. + // + // Get() must return >= 1 session, or error. If > 1, the S3 operation will be tried + // on each session in unspecified order until it succeeds. + // + // Note: Some implementations will not need SessionProviderInput and can just ignore it. + // + // TODO: Consider passing chan<- *session.Session (implementer sends and then closes) + // so s3file can try credentials as soon as they're available. + Get(_ context.Context, in SessionProviderInput) ([]*session.Session, error) + } + SessionProviderInput struct { + // S3IAMAction is an action name from this list: + // https://docs.aws.amazon.com/service-authorization/latest/reference/list_amazons3.html + // + // Note: There is no `s3:` prefix. 
+ // + // Note: This is different from the notion of "action" in the S3 API documentation: + // https://docs.aws.amazon.com/AmazonS3/latest/API/API_Operations.html + // Some names, like GetObject, appear in both; others, like HeadObject, do not. + S3IAMAction string + // Bucket and Key describe the API operation to be performed, if applicable. + Bucket, Key string + } + + constSessionProvider struct { + session *session.Session + err error + } +) + +// NewDefaultProvider returns a SessionProvider that calls session.NewSession(configs...) once. +func NewDefaultProvider(configs ...*aws.Config) SessionProvider { + session, err := session.NewSession(configs...) + return constSessionProvider{session, err} +} + +func (p constSessionProvider) Get(context.Context, SessionProviderInput) ([]*session.Session, error) { + if p.err != nil { + return nil, p.err + } + return []*session.Session{p.session}, nil +} + +type ( + clientsForActionFunc func(ctx context.Context, s3IAMAction, bucket, key string) ([]s3iface.S3API, error) + // clientCache caches clients for all regions, based on the user's SessionProvider. + clientCache struct { + provider SessionProvider + // clients maps clientCacheKey -> *clientCacheValue. + // TODO: Implement some kind of garbage collection and relax the documented constraint + // that sessions are never released. + clients *sync.Map + } + clientCacheKey struct { + region string + // userSession is the session that the user's SessionProvider returned. + // It may be configured for a different region, so we don't use it directly. + userSession *session.Session + } + clientCacheValue struct { + client *s3.S3 + // usedSinceLastGC is 0 or 1. It's set when this client is used, and acted on by the + // GC goroutine. + // TODO: Use atomic.Bool in go1.19. + usedSinceLastGC int32 + } +) + +func newClientCache(provider SessionProvider) *clientCache { + // According to time.Tick documentation, ticker.Stop must be called to avoid leaking ticker + // memory. 
However, *clientCache is never explicitly "shut down", so we don't have a good way + // to stop the GC loop. Instead, we use a finalizer on *clientCache, and ensure the GC loop + // itself doesn't keep *clientCache alive. + var ( + clients sync.Map + gcCtx, gcCancel = context.WithCancel(context.Background()) + ) + go func() { + ticker := time.NewTicker(clientCacheGarbageCollectionInterval) + defer ticker.Stop() + for { + select { + case <-gcCtx.Done(): + return + case <-ticker.C: + } + clients.Range(func(keyAny, valueAny any) bool { + key := keyAny.(clientCacheKey) + value := valueAny.(*clientCacheValue) + if atomic.SwapInt32(&value.usedSinceLastGC, 0) == 0 { + // Note: Concurrent goroutines could mark this client as used between our query + // and delete. That's fine; we'll just construct a new client next time. + clients.Delete(key) + } + return true + }) + } + }() + // Note: Declare *clientCache after the GC loop to help ensure the latter doesn't keep a + // reference to the former. + cc := clientCache{provider, &clients} + runtime.SetFinalizer(&cc, func(any) { gcCancel() }) + return &cc +} + +func (c *clientCache) forAction(ctx context.Context, s3IAMAction, bucket, key string) ([]s3iface.S3API, error) { + // TODO: Consider using some better default, like current region if we're in EC2. + region := defaultRegion + if bucket != "" { // bucket is empty when listing buckets, for example. 
+ var err error + region, err = FindBucketRegion(ctx, bucket) + if err != nil { + return nil, errors.E(err, fmt.Sprintf("locating region for bucket %s", bucket)) + } + } + sessions, err := c.provider.Get(ctx, SessionProviderInput{S3IAMAction: s3IAMAction, Bucket: bucket, Key: key}) + if err != nil { + return nil, errors.E(err, fmt.Sprintf("getting sessions from provider %T", c.provider)) + } + clients := make([]s3iface.S3API, len(sessions)) + for i, session := range sessions { + key := clientCacheKey{region, session} + obj, ok := c.clients.Load(key) + if !ok { + obj, _ = c.clients.LoadOrStore(key, &clientCacheValue{ + client: s3.New(session, &aws.Config{Region: ®ion}), + usedSinceLastGC: 1, + }) + } + value := obj.(*clientCacheValue) + clients[i] = value.client + atomic.StoreInt32(&value.usedSinceLastGC, 1) + } + return clients, nil +} diff --git a/file/s3file/stat.go b/file/s3file/stat.go new file mode 100644 index 00000000..deac2517 --- /dev/null +++ b/file/s3file/stat.go @@ -0,0 +1,79 @@ +package s3file + +import ( + "context" + "path/filepath" + "strings" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/service/s3" + "github.com/aws/aws-sdk-go/service/s3/s3iface" + "github.com/grailbio/base/errors" + "github.com/grailbio/base/file" +) + +// Stat implements file.Implementation interface. 
+func (impl *s3Impl) Stat(ctx context.Context, path string, opts ...file.Opts) (file.Info, error) { + _, bucket, key, err := ParseURL(path) + if err != nil { + return nil, errors.E(errors.Invalid, "could not parse", path, err) + } + resp := runRequest(ctx, func() response { + clients, err := impl.clientsForAction(ctx, "GetObject", bucket, key) + if err != nil { + return response{err: err} + } + policy := newBackoffPolicy(clients, mergeFileOpts(opts)) + info, err := stat(ctx, clients, policy, path, bucket, key) + if err != nil { + return response{err: err} + } + return response{info: info} + }) + return resp.info, resp.err +} + +func stat(ctx context.Context, clients []s3iface.S3API, policy retryPolicy, path, bucket, key string) (*s3Info, error) { + if key == "" { + return nil, errors.E(errors.Invalid, "cannot stat with empty S3 key", path) + } + metric := metrics.Op("stat").Start() + defer metric.Done() + for { + var ids s3RequestIDs + output, err := policy.client().HeadObjectWithContext(ctx, + &s3.HeadObjectInput{ + Bucket: aws.String(bucket), + Key: aws.String(key), + }, + ids.captureOption(), + ) + if policy.shouldRetry(ctx, err, path) { + metric.Retry() + continue + } + if err != nil { + return nil, annotate(err, ids, &policy, "s3file.stat", path) + } + if output.ETag == nil || *output.ETag == "" { + return nil, errors.E("s3file.stat: empty ETag", path, errors.NotExist, "awsrequestID:", ids.String()) + } + if output.ContentLength == nil { + return nil, errors.E("s3file.stat: nil ContentLength", path, errors.NotExist, "awsrequestID:", ids.String()) + } + if *output.ContentLength == 0 && strings.HasSuffix(path, "/") { + // Assume this is a directory marker: + // https://web.archive.org/web/20190424231712/https://docs.aws.amazon.com/AmazonS3/latest/user-guide/using-folders.html + return nil, errors.E("s3file.stat: directory marker at path", path, errors.NotExist, "awsrequestID:", ids.String()) + } + if output.LastModified == nil { + return nil, 
errors.E("s3file.stat: nil LastModified", path, errors.NotExist, "awsrequestID:", ids.String()) + } + return &s3Info{ + name: filepath.Base(path), + size: *output.ContentLength, + modTime: *output.LastModified, + etag: *output.ETag, + }, nil + } +} diff --git a/file/s3file/versions.go b/file/s3file/versions.go new file mode 100644 index 00000000..9f6c4e35 --- /dev/null +++ b/file/s3file/versions.go @@ -0,0 +1,58 @@ +package s3file + +import ( + "context" + "fmt" + + "github.com/grailbio/base/errors" + "github.com/grailbio/base/file" + "github.com/grailbio/base/file/addfs" + "github.com/grailbio/base/file/fsnode" + "github.com/grailbio/base/grail/biofs/biofseventlog" +) + +type versionsFunc struct{} + +var ( + VersionsFunc versionsFunc + _ addfs.PerNodeFunc = VersionsFunc +) + +func (versionsFunc) Apply(ctx context.Context, node fsnode.T) ([]fsnode.T, error) { + biofseventlog.UsedFeature("s3.versions.func") + // For now, we rely on gfilefs's fsnode.T implementations saving the underlying path as Sys(). + // This is temporary (BXDS-1030). When we fix that, we'll need to make this detect the + // concrete type for S3-backed fsnode.T's instead of looking for Sys(). That'll likely require + // refactoring such as merging gfilefs into this package. 
+ path, ok := node.Info().Sys().(string) + if !ok { + return nil, nil + } + scheme, bucket, key, err := ParseURL(path) + if err != nil || scheme != Scheme { + return nil, nil + } + implIface := file.FindImplementation(Scheme) + impl, ok := implIface.(*s3Impl) + if !ok { + return nil, errors.E(errors.Precondition, fmt.Sprintf("unrecognized s3 impl: %T", implIface)) + } + var ( + q = s3Query{impl, bucket, key} + gen fsnode.ChildrenGenerator + ) + switch node.(type) { + case fsnode.Parent: + gen = versionsDirViewGen{q} + case fsnode.Leaf: + gen = versionsObjViewGen{q} + default: + return nil, errors.E(errors.Precondition, fmt.Sprintf("unrecognized node: %T", node)) + } + return []fsnode.T{ + fsnode.NewParent( + fsnode.NewDirInfo("versions").WithCacheableFor(fsnode.CacheableFor(node)), + gen, + ), + }, nil +} diff --git a/file/s3file/versions_leaf.go b/file/s3file/versions_leaf.go new file mode 100644 index 00000000..1be7be29 --- /dev/null +++ b/file/s3file/versions_leaf.go @@ -0,0 +1,126 @@ +package s3file + +import ( + "context" + "os" + "sync/atomic" + "unsafe" + + "github.com/aws/aws-sdk-go/service/s3/s3iface" + "github.com/grailbio/base/errors" + "github.com/grailbio/base/file" + "github.com/grailbio/base/file/fsnode" + "github.com/grailbio/base/grail/biofs/biofseventlog" + "github.com/grailbio/base/ioctx" + "github.com/grailbio/base/ioctx/fsctx" +) + +type ( + versionsLeaf struct { + fsnode.FileInfo + s3Query + versionID string + } + versionsFile struct { + versionsLeaf + + // readOffset is the cursor for Read(). 
+ readOffset int64 + + reader chunkReaderCache + } +) + +var ( + _ fsnode.Leaf = versionsLeaf{} + _ fsctx.File = (*versionsFile)(nil) + _ ioctx.ReaderAt = (*versionsFile)(nil) +) + +func (n versionsLeaf) FSNodeT() {} + +func (n versionsLeaf) OpenFile(ctx context.Context, flag int) (fsctx.File, error) { + biofseventlog.UsedFeature("s3.versions.open") + return &versionsFile{versionsLeaf: n}, nil +} + +func (f *versionsFile) Stat(ctx context.Context) (os.FileInfo, error) { + return f.FileInfo, nil +} + +func (f *versionsFile) Read(ctx context.Context, dst []byte) (int, error) { + n, err := f.ReadAt(ctx, dst, f.readOffset) + f.readOffset += int64(n) + return n, err +} + +func (f *versionsFile) ReadAt(ctx context.Context, dst []byte, offset int64) (int, error) { + reader, cleanUp, err := f.reader.getOrCreate(ctx, func() (*chunkReaderAt, error) { + clients, err := f.impl.clientsForAction(ctx, "GetObjectVersion", f.bucket, f.key) + if err != nil { + return nil, errors.E(err, "getting clients") + } + return &chunkReaderAt{ + name: f.path(), bucket: f.bucket, key: f.key, versionID: f.versionID, + newRetryPolicy: func() retryPolicy { + return newBackoffPolicy(append([]s3iface.S3API{}, clients...), file.Opts{}) + }, + }, nil + }) + if err != nil { + return 0, err + } + defer cleanUp() + // TODO: Consider checking s3Info for ETag changes. + n, _, err := reader.ReadAt(ctx, dst, offset) + return n, err +} + +func (f *versionsFile) Close(ctx context.Context) error { + f.reader.close() + return nil +} + +type chunkReaderCache struct { + // available is idle (for some goroutine to use). Goroutines set available = nil before + // using it to "acquire" it, then return it after their operation (if available == nil then). + // If the caller only uses one thread, we'll end up creating and reusing just one + // *chunkReaderAt for all operations. + available unsafe.Pointer // *chunkReaderAt +} + +// get constructs a reader. cleanUp must be called iff error is nil. 
+func (c *chunkReaderCache) getOrCreate( + ctx context.Context, create func() (*chunkReaderAt, error), +) ( + reader *chunkReaderAt, cleanUp func(), err error, +) { + trySaveReader := func() { + if atomic.CompareAndSwapPointer(&c.available, nil, unsafe.Pointer(reader)) { + return + } + reader.Close() + } + + reader = (*chunkReaderAt)(atomic.SwapPointer(&c.available, nil)) + if reader != nil { + return reader, trySaveReader, nil + } + + reader, err = create() + if err != nil { + if reader != nil { + reader.Close() + } + return nil, nil, err + } + + return reader, trySaveReader, nil +} + +func (c *chunkReaderCache) close() { + reader := (*chunkReaderAt)(atomic.SwapPointer(&c.available, nil)) + if reader != nil { + reader.Close() + } +} diff --git a/file/s3file/versions_list.go b/file/s3file/versions_list.go new file mode 100644 index 00000000..8e682135 --- /dev/null +++ b/file/s3file/versions_list.go @@ -0,0 +1,297 @@ +package s3file + +import ( + "context" + "strings" + "time" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/service/s3" + "github.com/grailbio/base/errors" + "github.com/grailbio/base/file" + "github.com/grailbio/base/file/fsnode" + "github.com/grailbio/base/grail/biofs/biofseventlog" + "github.com/grailbio/base/must" +) + +// s3Query is a generic description of an S3 object or prefix. +type s3Query struct { + impl *s3Impl + // bucket must be non-empty. + bucket string + // key is either an S3 object's key or a key prefix (optionally ending with pathSeparator). + // "" is allowed and refers to the root of the bucket. + key string +} + +func (q s3Query) path() string { return pathPrefix + q.bucket + pathSeparator + q.key } + +// TODO: Dedupe with gfilefs. +const fileInfoCacheFor = 1 * time.Hour + +type ( + // versionsDirViewGen lists all the versions of all the (direct child) objects in a single S3 + // "directory" (that is, single bucket with a single key prefix). 
+	//
+	// Note: We implement fsnode.ChildrenGenerator rather than fsnode.Iterator because it reduces
+	// implementation complexity. We need to parse three separate fields from listing responses
+	// so the implementation is a bit verbose, and Child()/Children() differences introduce edge
+	// cases we should test. But, we'll probably want to do this eventually.
+	versionsDirViewGen struct{ s3Query }
+
+	// versionsObjViewGen lists the versions of an S3 object. Each version of the object is accessible
+	// via a child node. Additionally, if there are other S3 object versions that have this path as
+	// a prefix (or, in directory terms, if there used to be a directory with the same name as this
+	// file), a dir/ child provides access to those.
+	//
+	// Scheme:
+	//   vVERSION_ID/ for each version
+	//   vVERSION_ID (empty file) to mark deletion time
+	//   dir/ for children, if there used to be a "directory" with this name
+	// TODO:
+	//   @DATE/ -> VERSION_ID/ for each version
+	//   latest/ -> VERSION_ID/
+	//   0/, 1/, etc. -> VERSION_ID/
+	//
+	// Note: We implement fsnode.ChildrenGenerator rather than fsnode.Iterator because it reduces
+	// implementation complexity and we expect number of versions per object to be relatively
+	// modest in practice. If we see performance problems, we can make it more sophisticated.
+ versionsObjViewGen struct{ s3Query } +) + +var ( + _ fsnode.ChildrenGenerator = versionsDirViewGen{} + _ fsnode.ChildrenGenerator = versionsObjViewGen{} + + objViewDirInfo = fsnode.NewDirInfo("dir").WithCacheableFor(fileInfoCacheFor) +) + +func (g versionsDirViewGen) GenerateChildren(ctx context.Context) ([]fsnode.T, error) { + biofseventlog.UsedFeature("s3.versions.dirview") + dirPrefix := g.key + if dirPrefix != "" { + dirPrefix = g.key + pathSeparator + } + iterator, err := newVersionsIterator(ctx, g.impl, g.s3Query, s3.ListObjectVersionsInput{ + Bucket: aws.String(g.bucket), + Delimiter: aws.String(pathSeparator), + Prefix: aws.String(dirPrefix), + }) + if err != nil { + return nil, err + } + var ( + dirChildren = map[string]fsnode.T{} + objChildren = map[string][]fsnode.T{} + ) + for iterator.HasNextPage() { + out, err := iterator.NextPage(ctx) + if err != nil { + return nil, err + } + for _, common := range out.CommonPrefixes { + name := (*common.Prefix)[len(dirPrefix):] + name = name[:len(name)-len(pathSeparator)] + if name == "" { + // Note: S3 keys may have multiple trailing `/`s leading to name == "". + // For now, we skip these, making them inaccessible to users. + // TODO: Better mapping of S3 key semantics onto fsnode.T, for example recursively + // listing "key//" so we can merge those children into "key/"'s. + // See also: BXDS-2039 for the non-version listing case. + continue + } + q := g.s3Query + q.key = dirPrefix + name + dirChildren[name] = fsnode.NewParent( + fsnode.NewDirInfo(name).WithCacheableFor(fileInfoCacheFor), + versionsDirViewGen{q}) + } + for _, del := range out.DeleteMarkers { + if *del.Key == dirPrefix { + continue // Skip directory markers. + } + name := (*del.Key)[len(dirPrefix):] + objChildren[name] = append(objChildren[name], newDeleteChild(del)) + } + for _, version := range out.Versions { + if *version.Key == dirPrefix { + continue // Skip directory markers. 
+ } + q := g.s3Query + q.key = *version.Key + name := q.key[len(dirPrefix):] + objChildren[name] = append(objChildren[name], newVersionChild(q, version)) + } + } + merged := make([]fsnode.T, 0, len(dirChildren)+len(objChildren)) + for name, child := range dirChildren { + if _, ok := objChildren[name]; ok { + // If a name was used both for files and directories, prefer files here, because + // the user can find the directory view under {name}/dir/. + continue + } + merged = append(merged, child) + } + for name, children := range objChildren { + merged = append(merged, fsnode.NewParent( + fsnode.NewDirInfo(name).WithCacheableFor(fileInfoCacheFor), + fsnode.ConstChildren(children...), + )) + } + return merged, nil +} + +func (g versionsObjViewGen) GenerateChildren(ctx context.Context) ([]fsnode.T, error) { + biofseventlog.UsedFeature("s3.versions.objview") + iterator, err := newVersionsIterator(ctx, g.impl, g.s3Query, s3.ListObjectVersionsInput{ + Bucket: aws.String(g.bucket), + Delimiter: aws.String(pathSeparator), + Prefix: aws.String(g.key), + }) + if err != nil { + return nil, err + } + var ( + versions []fsnode.T + hasOtherChildren bool + ) + for iterator.HasNextPage() { + out, err := iterator.NextPage(ctx) + if err != nil { + return nil, err + } + if len(out.CommonPrefixes) > 0 { + hasOtherChildren = true + } + for _, del := range out.DeleteMarkers { + if *del.Key != g.key { + hasOtherChildren = true + // del is in a "subdirectory" of a previous directory version of our object. + // We don't render those here; instead we just add the dir/ child below. + continue + // Note: It seems like S3 returns delete markers in sorted order, but the API + // docs don't explicitly state this for ListObjectVersions [1] as they do for + // ListObjectsV2 [2], so we `continue` instead of `break`. We're still assuming + // API response pages are so ordered, though, because the alternative is unworkable. + // TODO: Ask AWS for explicit documentation on versions ordering. 
+ // + // [1] https://docs.aws.amazon.com/AmazonS3/latest/API/API_ListObjectVersions.html + // [2] https://docs.aws.amazon.com/AmazonS3/latest/API/API_ListObjectsV2.html + } + versions = append(versions, newDeleteChild(del)) + } + for _, version := range out.Versions { + if *version.Key != g.key { + hasOtherChildren = true + continue // See delete marker note. + } + versions = append(versions, newVersionChild(g.s3Query, version)) + } + } + if hasOtherChildren { + versions = append(versions, fsnode.NewParent(objViewDirInfo, versionsDirViewGen{g.s3Query})) + } + return versions, nil +} + +func newVersionChild(q s3Query, v *s3.ObjectVersion) fsnode.Parent { + must.Truef(len(q.key) > 0, "creating child for %#v, %v", q, v) + name := q.key + if idx := strings.LastIndex(name, pathSeparator); idx >= 0 { + name = name[idx+len(pathSeparator):] + } + dirName := "v" + sanitizePathElem(*v.VersionId) + // Some S3 storage classes don't allow immediate, direct access (for example, requiring restore + // first). They also have very different cost profiles and users may not know about these and + // accidentally, expensively download many objects (especially with biofs, where it's easy + // to run `grep`, etc.). We have a best-effort allowlist and block others for now. + // TODO: Refine this UX. Maybe add a README.txt describing these properties and suggesting + // using the AWS console for unsupported objects. + // TODO: Consider supporting Glacier restoration. + // + // Note: This field's `enum` tag [1] names an enum with only one value (standard class [2]). + // However, as of this writing, we're seeing the API return more values, like DEEP_ARCHIVE. + // We assume it can take any value in s3.ObjectStorageClass* instead. + // TODO: Verify this, report to AWS, etc. 
+ // [1] https://pkg.go.dev/github.com/aws/aws-sdk-go@v1.42.0/service/s3#ObjectVersion.StorageClass + // [2] https://pkg.go.dev/github.com/aws/aws-sdk-go@v1.42.0/service/s3#ObjectVersionStorageClassStandard + switch *v.StorageClass { + default: + dirName += "." + *v.StorageClass + return fsnode.NewParent( + fsnode.NewDirInfo(dirName).WithModTime(*v.LastModified), + fsnode.ConstChildren()) + case + s3.ObjectStorageClassStandard, + s3.ObjectStorageClassReducedRedundancy, + s3.ObjectStorageClassStandardIa, + s3.ObjectStorageClassOnezoneIa, + s3.ObjectStorageClassIntelligentTiering: + return fsnode.NewParent( + fsnode.NewDirInfo(dirName).WithModTime(*v.LastModified), + fsnode.ConstChildren( + versionsLeaf{ + FileInfo: fsnode.NewRegInfo(name).WithSize(*v.Size).WithModTime(*v.LastModified), + s3Query: q, + versionID: *v.VersionId, + }, + ), + ) + } +} + +func newDeleteChild(del *s3.DeleteMarkerEntry) fsnode.T { + return fsnode.ConstLeaf( + fsnode.NewRegInfo("v"+sanitizePathElem(*del.VersionId)).WithModTime(*del.LastModified), + nil) +} + +type versionsIterator struct { + in s3.ListObjectVersionsInput + eof bool + policy retryPolicy + path string +} + +func newVersionsIterator( + ctx context.Context, + impl *s3Impl, + q s3Query, + in s3.ListObjectVersionsInput, +) (*versionsIterator, error) { + clients, err := impl.clientsForAction(ctx, "ListVersions", q.bucket, q.key) + if err != nil { + return nil, errors.E(err, "getting clients") + } + policy := newBackoffPolicy(clients, file.Opts{}) + return &versionsIterator{in: in, policy: policy, path: q.path()}, nil +} + +func (it *versionsIterator) HasNextPage() bool { return !it.eof } + +func (it *versionsIterator) NextPage(ctx context.Context) (*s3.ListObjectVersionsOutput, error) { + for { + var ids s3RequestIDs + out, err := it.policy.client().ListObjectVersionsWithContext(ctx, &it.in, ids.captureOption()) + if err == nil { + it.in.KeyMarker = out.NextKeyMarker + it.in.VersionIdMarker = out.NextVersionIdMarker + if 
!*out.IsTruncated { + it.eof = true + } + return out, nil + } + if !it.policy.shouldRetry(ctx, err, it.path) { + it.eof = true + return nil, annotate(err, ids, &it.policy, "s3file.versionsRootNode.Child", it.path) + } + } +} + +func sanitizePathElem(s string) string { + // TODO: Consider being stricter. S3 guarantees very little about version IDs: + // https://docs.aws.amazon.com/AmazonS3/latest/userguide/versioning-workflows.html#version-ids + // TODO: Implement more robust replacement (with some escape char, etc.) so that we cannot + // introduce collisions. + return strings.ReplaceAll(s, "/", "_") +} diff --git a/file/util.go b/file/util.go index 17125c6b..646f50ab 100644 --- a/file/util.go +++ b/file/util.go @@ -13,9 +13,9 @@ import ( ) // ReadFile reads the given file and returns the contents. A successful call -// returns err == nil, not err == EOF. -func ReadFile(ctx context.Context, path string) ([]byte, error) { - in, err := Open(ctx, path) +// returns err == nil, not err == EOF. Arg opts is passed to file.Open. +func ReadFile(ctx context.Context, path string, opts ...Opts) ([]byte, error) { + in, err := Open(ctx, path, opts...) if err != nil { return nil, err } diff --git a/fileio/close.go b/fileio/close.go new file mode 100644 index 00000000..6d17d3b2 --- /dev/null +++ b/fileio/close.go @@ -0,0 +1,46 @@ +package fileio + +import ( + "fmt" + "io" + + "github.com/grailbio/base/errors" +) + +type named interface { + // Name returns the path name. + Name() string +} + +// CloseAndReport returns a defer-able helper that calls f.Close and reports errors, if any, +// to *err. Pass your function's named return error. Example usage: +// +// func processFile(filename string) (_ int, err error) { +// f, err := os.Open(filename) +// if err != nil { ... } +// defer fileio.CloseAndReport(f, &err) +// ... +// } +// +// If your function returns with an error, any f.Close error will be chained appropriately. +// +// Deprecated: Use errors.CleanUp directly. 
+func CloseAndReport(f io.Closer, err *error) { + errors.CleanUp(f.Close, err) +} + +// MustClose is a defer-able function that calls f.Close and panics on error. +// +// Example: +// f, err := os.Open(filename) +// if err != nil { panic(err) } +// defer fileio.MustClose(f) +// ... +func MustClose(f io.Closer) { + if err := f.Close(); err != nil { + if n, ok := f.(named); ok { + panic(fmt.Sprintf("close %s: %v", n.Name(), err)) + } + panic(err) + } +} diff --git a/fileio/names.go b/fileio/names.go index 8f0916a3..d33ca209 100644 --- a/fileio/names.go +++ b/fileio/names.go @@ -31,6 +31,10 @@ const ( GrailRIOPackedCompressedAndEncrypted // JSON text file JSON + // Zstd format. + // https://facebook.github.io/zstd/ + // https://tools.ietf.org/html/rfc8478 + Zstd ) var lookup = map[string]FileType{ @@ -42,6 +46,7 @@ var lookup = map[string]FileType{ ".grail-rpk-gz": GrailRIOPackedCompressed, ".grail-rpk-gz-kd": GrailRIOPackedCompressedAndEncrypted, ".json": JSON, + ".zst": Zstd, } // StorageAPI represents the Storage API required to access a file. @@ -53,8 +58,6 @@ const ( LocalAPI StorageAPI = iota // S3API represents an Amazon S3 API. 
S3API - - s3Separator = '/' ) // DetermineAPI determines the Storage API that stores the file diff --git a/flock/flock_test.go b/flock/flock_test.go new file mode 100644 index 00000000..7e490a51 --- /dev/null +++ b/flock/flock_test.go @@ -0,0 +1,77 @@ +package flock_test + +import ( + "context" + "io/ioutil" + "sync/atomic" + "testing" + "time" + + "github.com/grailbio/base/flock" + "github.com/grailbio/testutil/assert" +) + +func TestLock(t *testing.T) { + tempDir, err := ioutil.TempDir("", "") + if err != nil { + t.Fatal(err) + } + + lockPath := tempDir + "/lock" + lock := flock.New(lockPath) + + // Test uncontended locks + ctx := context.Background() + for i := 0; i < 3; i++ { + assert.NoError(t, lock.Lock(ctx)) + assert.NoError(t, lock.Unlock()) + } + + assert.NoError(t, lock.Lock(ctx)) + locked := int64(0) + doneCh := make(chan struct{}) + go func() { + assert.NoError(t, lock.Lock(ctx)) + atomic.StoreInt64(&locked, 1) + assert.NoError(t, lock.Unlock()) + atomic.StoreInt64(&locked, 2) + doneCh <- struct{}{} + }() + + time.Sleep(500 * time.Millisecond) + if atomic.LoadInt64(&locked) != 0 { + t.Errorf("locked=%d", locked) + } + + assert.NoError(t, lock.Unlock()) + <-doneCh + if atomic.LoadInt64(&locked) != 2 { + t.Errorf("locked=%d", locked) + } +} + +func TestLockContext(t *testing.T) { + tempDir, err := ioutil.TempDir("", "") + if err != nil { + t.Fatal(err) + } + lockPath := tempDir + "/lock" + + lock := flock.New(lockPath) + ctx := context.Background() + ctx2, cancel2 := context.WithCancel(ctx) + assert.NoError(t, lock.Lock(ctx2)) + assert.NoError(t, lock.Unlock()) + + assert.NoError(t, lock.Lock(ctx)) + go func() { + time.Sleep(500 * time.Millisecond) + cancel2() + }() + assert.Regexp(t, lock.Lock(ctx2), "context canceled") + + assert.NoError(t, lock.Unlock()) + // Make sure the lock is in a sane state by cycling lock-unlock again. 
+ assert.NoError(t, lock.Lock(ctx)) + assert.NoError(t, lock.Unlock()) +} diff --git a/flock/flock_unix.go b/flock/flock_unix.go new file mode 100644 index 00000000..f54185b2 --- /dev/null +++ b/flock/flock_unix.go @@ -0,0 +1,84 @@ +// Copyright 2018 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. + +// Package flock implements a simple POSIX file-based advisory lock. +package flock + +import ( + "context" + "sync" + "syscall" + + "github.com/grailbio/base/log" +) + +type T struct { + name string + fd int + mu sync.Mutex +} + +// New creates an object that locks the given path. +func New(path string) *T { + return &T{name: path} +} + +// Lock locks the file. Iff Lock() returns nil, the caller must call Unlock() +// later. +func (f *T) Lock(ctx context.Context) (err error) { + reqCh := make(chan func() error, 2) + doneCh := make(chan error) + go func() { + var err error + for req := range reqCh { + if err == nil { + err = req() + } + doneCh <- err + } + }() + reqCh <- f.doLock + select { + case <-ctx.Done(): + reqCh <- f.doUnlock + err = ctx.Err() + case err = <-doneCh: + } + close(reqCh) + return err +} + +// Unlock unlocks the file. +func (f *T) Unlock() error { + return f.doUnlock() +} + +func (f *T) doLock() error { + f.mu.Lock() // Serialize the lock within one process. 
+ + var err error + f.fd, err = syscall.Open(f.name, syscall.O_CREAT|syscall.O_RDWR, 0777) + if err != nil { + f.mu.Unlock() + return err + } + err = syscall.Flock(f.fd, syscall.LOCK_EX|syscall.LOCK_NB) + for err == syscall.EWOULDBLOCK || err == syscall.EAGAIN { + log.Printf("waiting for lock %s", f.name) + err = syscall.Flock(f.fd, syscall.LOCK_EX) + } + if err != nil { + f.mu.Unlock() + } + return err +} + +func (f *T) doUnlock() error { + err := syscall.Flock(f.fd, syscall.LOCK_UN) + if err := syscall.Close(f.fd); err != nil { + log.Error.Printf("close %s: %v", f.name, err) + } + f.mu.Unlock() + return err +} diff --git a/go.mod b/go.mod new file mode 100644 index 00000000..b1565b59 --- /dev/null +++ b/go.mod @@ -0,0 +1,54 @@ +module github.com/grailbio/base + +go 1.13 + +require ( + cloud.google.com/go v0.46.3 // indirect + github.com/DataDog/zstd v1.4.1 + github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5 + github.com/aws/aws-sdk-go v1.34.31 + github.com/biogo/store v0.0.0-20190426020002-884f370e325d + github.com/cespare/xxhash/v2 v2.1.0 + github.com/coreos/go-oidc v2.1.0+incompatible + github.com/fullsailor/pkcs7 v0.0.0-20190404230743-d7302db945fa + github.com/go-test/deep v1.0.4 + github.com/gobwas/glob v0.2.3 + github.com/golang/protobuf v1.4.3 + github.com/google/gofuzz v1.1.0 + github.com/google/gops v0.3.6 + github.com/grailbio/testutil v0.0.3 + github.com/grailbio/v23/factories/grail v0.0.0-20190904050408-8a555d238e9a + github.com/hanwen/go-fuse/v2 v2.0.2 + github.com/kardianos/osext v0.0.0-20190222173326-2bc1f35cddc0 // indirect + github.com/keybase/go-keychain v0.0.0-20190828153431-2390ae572545 + github.com/klauspost/compress v1.8.6 + github.com/pkg/errors v0.9.1 + github.com/stretchr/testify v1.6.1 + github.com/willf/bitset v1.1.10 + github.com/yasushi-saito/zlibng v0.0.0-20190922135643-2a860060b80c + go.opencensus.io v0.22.1 // indirect + golang.org/x/crypto v0.0.0-20200820211705-5c72a883971a + golang.org/x/net 
v0.0.0-20201110031124-69a78807bb2b + golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45 + golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9 + golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f + google.golang.org/appengine v1.6.5 // indirect + v.io v0.1.15 + v.io/x/lib v0.1.7 +) + +require ( + github.com/cespare/xxhash v1.1.0 + github.com/gin-gonic/gin v1.6.3 + github.com/gogo/protobuf v1.3.2 // indirect + github.com/google/uuid v1.1.2 + github.com/shirou/gopsutil v2.19.9+incompatible + go.uber.org/zap v1.16.0 + google.golang.org/api v0.10.0 + gopkg.in/inf.v0 v0.9.1 // indirect + k8s.io/api v0.0.0-20181213150558-05914d821849 + k8s.io/apimachinery v0.0.0-20181127025237-2b1284ed4c93 + k8s.io/client-go v10.0.0+incompatible + k8s.io/klog v1.0.0 // indirect + sigs.k8s.io/yaml v1.2.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 00000000..3b13d663 --- /dev/null +++ b/go.sum @@ -0,0 +1,549 @@ +cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +cloud.google.com/go v0.38.0 h1:ROfEUZz+Gh5pa62DJWXSaonyu3StP6EA6lPEXPI6mCo= +cloud.google.com/go v0.38.0/go.mod h1:990N+gfupTy94rShfmMCWGDn0LpTmnzTp2qbd1dvSRU= +cloud.google.com/go v0.44.1/go.mod h1:iSa0KzasP4Uvy3f1mN/7PiObzGgflwredwwASm/v6AU= +cloud.google.com/go v0.44.2/go.mod h1:60680Gw3Yr4ikxnPRS/oxxkBccT6SA1yMk63TGekxKY= +cloud.google.com/go v0.45.1/go.mod h1:RpBamKRgapWJb87xiFSdk4g1CME7QZg3uwTez+TSTjc= +cloud.google.com/go v0.46.3 h1:AVXDdKsrtX33oR9fbCMu/+c1o8Ofjq6Ku/MInaLVg5Y= +cloud.google.com/go v0.46.3/go.mod h1:a6bKKbmY7er1mI7TEI4lsAkts/mkhTSZK8w33B4RAg0= +cloud.google.com/go/bigquery v1.0.1/go.mod h1:i/xbL2UlR5RvWAURpBYZTtm/cXjCha9lbfbpx4poX+o= +cloud.google.com/go/datastore v1.0.0/go.mod h1:LXYbyblFSglQ5pkeyhO+Qmw7ukd3C+pD7TKLgZqpHYE= +cloud.google.com/go/pubsub v1.0.1/go.mod h1:R0Gpsv3s54REJCy4fxDixWD93lHJMoZTyQ2kNxGRt3I= +github.com/BurntSushi/toml v0.3.1 
h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ= +github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= +github.com/DATA-DOG/go-sqlmock v1.3.3/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM= +github.com/DataDog/zstd v1.4.1 h1:3oxKN3wbHibqx897utPC2LTQU4J+IHWWJO+glkAkpFM= +github.com/DataDog/zstd v1.4.1/go.mod h1:1jcaCB/ufaK+sKp1NBhlGmpz41jOoPQ35bpF36t7BBo= +github.com/NYTimes/gziphandler v0.0.0-20170623195520-56545f4a5d46/go.mod h1:3wb06e3pkSAbeQ52E9H9iFoQsEEwGN64994WTCIhntQ= +github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5 h1:TngWCqHvy9oXAN6lEVMRuU21PR1EtLVZJmdB18Gu3Rw= +github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5/go.mod h1:lmUJ/7eu/Q8D7ML55dXQrVaamCz2vxCfdQBasLZfHKk= +github.com/OneOfOne/xxhash v1.2.2 h1:KMrpdQIwFcEqXDklaen+P1axHaj9BSKzvpUUfnHldSE= +github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= +github.com/PuerkitoBio/purell v1.1.1/go.mod h1:c11w/QuzBsJSee3cPx9rAFu61PvFxuPbtSwDGJws/X0= +github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578/go.mod h1:uGdkoq3SwY9Y+13GIhn11/XLaGBb4BfwItxLd5jeuXE= +github.com/StackExchange/wmi v0.0.0-20170410192909-ea383cf3ba6e/go.mod h1:3eOhrUMpNV+6aFIbp5/iudMxNCF27Vw2OZgy4xEx0Fg= +github.com/StackExchange/wmi v0.0.0-20190523213315-cbe66965904d h1:G0m3OIz70MZUWq3EgK3CesDbo8upS2Vm9/P3FtgI+Jk= +github.com/StackExchange/wmi v0.0.0-20190523213315-cbe66965904d/go.mod h1:3eOhrUMpNV+6aFIbp5/iudMxNCF27Vw2OZgy4xEx0Fg= +github.com/ajstarks/svgo v0.0.0-20180226025133-644b8db467af/go.mod h1:K08gAheRH3/J6wwsYMMT4xOr94bZjxIelGM0+d/wbFw= +github.com/asaskevich/govalidator v0.0.0-20190424111038-f61b66f89f4a/go.mod h1:lB+ZfQJz7igIIfQNfa7Ml4HSf2uFQQRzpGGRXenZAgY= +github.com/aws/aws-sdk-go v1.23.14/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo= +github.com/aws/aws-sdk-go v1.23.22/go.mod 
h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo= +github.com/aws/aws-sdk-go v1.34.31 h1:408wh5EHKzxyby8JpYfnn1w3fsF26AIU0o1kbJoRy7E= +github.com/aws/aws-sdk-go v1.34.31/go.mod h1:H7NKnBqNVzoTJpGfLrQkkD+ytBA93eiDYi/+8rV9s48= +github.com/biogo/store v0.0.0-20190426020002-884f370e325d h1:vu2gsANkGtqYaQNXhmAAiJ7b1eKTbX3/aPAvDCasBgE= +github.com/biogo/store v0.0.0-20190426020002-884f370e325d/go.mod h1:Iev9Q3MErcn+w3UOJD/DkEzllvugfdx7bGcMOFhvr/4= +github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= +github.com/cespare/xxhash v1.1.0 h1:a6HrQnmkObjyL+Gs60czilIUGqrzKutQD6XZog3p+ko= +github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc= +github.com/cespare/xxhash/v2 v2.1.0 h1:yTUvW7Vhb89inJ+8irsUqiWjh8iT6sQPZiQzI6ReGkA= +github.com/cespare/xxhash/v2 v2.1.0/go.mod h1:dgIUBU3pDso/gPgZ1osOZ0iQf77oPR28Tjxl5dIMyVM= +github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= +github.com/coreos/go-oidc v2.1.0+incompatible h1:sdJrfw8akMnCuUlaZU3tE/uYXFgfqom8DBE9so9EBsM= +github.com/coreos/go-oidc v2.1.0+incompatible/go.mod h1:CgnwVTmzoESiwO9qyAFEMiHoZ1nMCKZlZ9V6mm3/LKc= +github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/docker/spdystream v0.0.0-20160310174837-449fdfce4d96/go.mod h1:Qh8CwZgvJUkLughtfhJv5dyTYa91l1fOUCrgjqmcifM= +github.com/docopt/docopt-go v0.0.0-20180111231733-ee0de3bc6815/go.mod h1:WwZ+bS3ebgob9U8Nd0kOddGdZWjyMGR8Wziv+TBNwSE= +github.com/elazarl/goproxy v0.0.0-20180725130230-947c36da3153/go.mod h1:/Zj4wYkgs4iZTTu3o/KG3Itv/qCCa8VVMlb3i9OVuzc= 
+github.com/emicklei/go-restful v0.0.0-20170410110728-ff4f55a20633/go.mod h1:otzb+WCGbkyDHkqmQmT5YD2WR4BBwUdeQoFo8l/7tVs= +github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= +github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= +github.com/evanphx/json-patch v4.9.0+incompatible/go.mod h1:50XU6AFN0ol/bzJsmQLiYLvXMP4fmwYFNcr97nuDLSk= +github.com/fogleman/gg v1.2.1-0.20190220221249-0403632d5b90/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k= +github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= +github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= +github.com/fullsailor/pkcs7 v0.0.0-20190404230743-d7302db945fa h1:RDBNVkRviHZtvDvId8XSGPu3rmpmSe+wKRcEWNgsfWU= +github.com/fullsailor/pkcs7 v0.0.0-20190404230743-d7302db945fa/go.mod h1:KnogPXtdwXqoenmZCw6S+25EAm2MkxbG0deNDu4cbSA= +github.com/ghodss/yaml v0.0.0-20150909031657-73d445a93680/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= +github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= +github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= +github.com/gin-gonic/gin v1.6.3 h1:ahKqKTFpO5KTPHxWZjEdPScmYaGtLo8Y4DMHoEsnp14= +github.com/gin-gonic/gin v1.6.3/go.mod h1:75u5sXoLsGZoRN5Sgbi1eraJ4GU3++wFwWzhwvtwp4M= +github.com/go-logr/logr v0.1.0/go.mod h1:ixOQHD9gLJUVQQ2ZOR7zLEifBX6tGkNJF4QyIY7sIas= +github.com/go-logr/logr v0.2.0 h1:QvGt2nLcHH0WK9orKa+ppBPAxREcH364nPUedEpK0TY= +github.com/go-logr/logr v0.2.0/go.mod h1:z6/tIYblkpsD+a4lm/fGIIU9mZ+XfAiaFtq7xTgseGU= +github.com/go-ole/go-ole v1.2.1/go.mod h1:7FAglXiTm7HKlQRDeOQ6ZNUHidzCWXuZWq/1dTyBNF8= +github.com/go-ole/go-ole v1.2.4 h1:nNBDSCOigTSiarFpYE9J/KtEA1IOW4CNeqT9TQDqCxI= +github.com/go-ole/go-ole v1.2.4/go.mod h1:XCwSNxSkXRo4vlyPy93sltvi/qJq0jqQhjqQNIwKuxM= +github.com/go-openapi/jsonpointer 
v0.19.2/go.mod h1:3akKfEdA7DF1sugOqz1dVQHBcuDBPKZGEoHC/NkiQRg= +github.com/go-openapi/jsonpointer v0.19.3/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg= +github.com/go-openapi/jsonreference v0.19.2/go.mod h1:jMjeRr2HHw6nAVajTXJ4eiUwohSTlpa0o73RUL1owJc= +github.com/go-openapi/jsonreference v0.19.3/go.mod h1:rjx6GuL8TTa9VaixXglHmQmIL98+wF9xc8zWvFonSJ8= +github.com/go-openapi/spec v0.19.3/go.mod h1:FpwSN1ksY1eteniUU7X0N/BgJ7a4WvBFVA8Lj9mJglo= +github.com/go-openapi/swag v0.19.2/go.mod h1:POnQmlKehdgb5mhVOsnJFsivZCEZ/vjK9gh66Z9tfKk= +github.com/go-openapi/swag v0.19.5/go.mod h1:POnQmlKehdgb5mhVOsnJFsivZCEZ/vjK9gh66Z9tfKk= +github.com/go-playground/assert/v2 v2.0.1 h1:MsBgLAaY856+nPRTKrp3/OZK38U/wa0CcBYNjji3q3A= +github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= +github.com/go-playground/locales v0.13.0 h1:HyWk6mgj5qFqCT5fjGBuRArbVDfE4hi8+e8ceBS/t7Q= +github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8= +github.com/go-playground/universal-translator v0.17.0 h1:icxd5fm+REJzpZx7ZfpaD876Lmtgy7VtROAbHHXk8no= +github.com/go-playground/universal-translator v0.17.0/go.mod h1:UkSxE5sNxxRwHyU+Scu5vgOQjsIJAF8j9muTVoKLVtA= +github.com/go-playground/validator/v10 v10.2.0 h1:KgJ0snyC2R9VXYN2rneOtQcw5aHQB1Vv0sFl1UcHBOY= +github.com/go-playground/validator/v10 v10.2.0/go.mod h1:uOYAAleCW8F/7oMFd6aG0GOhaH6EGOAJShg8Id5JGkI= +github.com/go-sql-driver/mysql v1.4.1/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= +github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= +github.com/go-test/deep v1.0.3/go.mod h1:wGDj63lr65AM2AQyKZd/NYHGb0R+1RLqB8NKt3aSFNA= +github.com/go-test/deep v1.0.4 h1:u2CU3YKy9I2pmu9pX0eq50wCgjfGIt539SqR7FbHiho= +github.com/go-test/deep v1.0.4/go.mod h1:wGDj63lr65AM2AQyKZd/NYHGb0R+1RLqB8NKt3aSFNA= +github.com/gobwas/glob v0.2.3 h1:A4xDbljILXROh+kObIiy5kIaPYD8e96x1tgBhUI5J+Y= +github.com/gobwas/glob v0.2.3/go.mod 
h1:d3Ez4x06l9bZtSvzIay5+Yzi0fmZzPgnTbPcKjJAkT8= +github.com/gogo/protobuf v1.3.1 h1:DqDEcV5aeaTmdFBePNpYsp3FlcVH/2ISVVM9Qf8PSls= +github.com/gogo/protobuf v1.3.1/go.mod h1:SlYgWuQ5SjCEi6WLHjHCa1yvBfUnHcTbrrZtXPKa29o= +github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= +github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= +github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k= +github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b h1:VKtxabqXZkF25pY9ekfRL6a582T4P37/31XEstQ5p58= +github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= +github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6 h1:ZgQEtGgCBiWRM39fZuwSd1LwSqqSW0hOdXCYYDX0R3I= +github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= +github.com/golang/groupcache v0.0.0-20191002201903-404acd9df4cc h1:55rEp52jU6bkyslZ1+C/7NGfpQsEc6pxGLAGDOctqbw= +github.com/golang/groupcache v0.0.0-20191002201903-404acd9df4cc/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= +github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e h1:1r7pUrabqp18hOBcwBwiTsbnFeTZHV9eER/QT5JVZxY= +github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= +github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= +github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= +github.com/golang/mock v1.3.1/go.mod h1:sBzyDLLjw3U8JLTeZvSv8jJB+tU5PVekmnlKIyFUx0Y= +github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.2 h1:6nsPYzhq5kReh6QImI3k5qWzO4PEbvbIW2cwSfR/6xs= +github.com/golang/protobuf v1.3.2/go.mod 
h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= +github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= +github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= +github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= +github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= +github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= +github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8= +github.com/golang/protobuf v1.4.3 h1:JjCZWpVbqXDqFVmTfYWEVTMIYrL/NPdPSCHPJ0T/raM= +github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= +github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= +github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= +github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= +github.com/google/go-cmp v0.3.0 h1:crn/baboCvb5fXaQ0IJ1SGTsTVrWpDsCWC8EGETZijY= +github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.4.0 h1:xsAVV57WRhGj6kEIi8ReJzQlHHqcBYCElAvkovg3B/4= +github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/gofuzz v1.0.0 h1:A8PeW59pxE9IoFRqBp37U+mSNaQoZ46F1f0f863XSXw= +github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= 
+github.com/google/gofuzz v1.1.0 h1:Hsa8mG0dQ46ij8Sl2AYJDUv1oA9/d6Vk+3LG99Oe02g= +github.com/google/gofuzz v1.1.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/gops v0.3.6 h1:6akvbMlpZrEYOuoebn2kR+ZJekbZqJ28fJXTs84+8to= +github.com/google/gops v0.3.6/go.mod h1:RZ1rH95wsAGX4vMWKmqBOIWynmWisBf4QFdgT/k/xOI= +github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= +github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= +github.com/google/pprof v0.0.0-20190515194954-54271f7e092f/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= +github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= +github.com/google/uuid v1.0.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/uuid v1.1.1 h1:Gkbcsh/GbpXz7lPftLA3P6TYMwjCLYm83jiFQZF/3gY= +github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/uuid v1.1.2 h1:EVhdT+1Kseyi1/pUmXKaFxYsDNy9RQYkMWRH68J/W7Y= +github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= +github.com/googleapis/gax-go/v2 v2.0.5 h1:sjZBwGj9Jlw33ImPtvFviGYvseOtDM7hkSKB7+Tv3SM= +github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk= +github.com/googleapis/gnostic v0.4.1/go.mod h1:LRhVm6pbyptWbWbuZ38d1eyptfvIytN3ir6b65WBswg= +github.com/gorilla/websocket v1.4.1 h1:q7AeDBpnBk8AogcD4DSag/Ukw/KV+YhzLj2bP5HvKCM= +github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/grailbio/base v0.0.1/go.mod h1:wVM2Cq2/HT0rt6WYGQhXJ3CCLkNnGjeAAOPHCZ2IsN0= +github.com/grailbio/v23/factories/grail v0.0.0-20190904050408-8a555d238e9a h1:kAl1x1ErQgs55bcm/WdoKCPny/kIF7COmC+UGQ9GKcM= +github.com/grailbio/v23/factories/grail 
v0.0.0-20190904050408-8a555d238e9a/go.mod h1:2g5HI42KHw+BDBdjLP3zs+WvTHlDK3RoE8crjCl26y4= +github.com/hanwen/go-fuse v1.0.0 h1:GxS9Zrn6c35/BnfiVsZVWmsG803xwE7eVRDvcf/BEVc= +github.com/hanwen/go-fuse v1.0.0/go.mod h1:unqXarDXqzAk0rt98O2tVndEPIpUgLD9+rwFisZH3Ok= +github.com/hanwen/go-fuse/v2 v2.0.2 h1:BtsqKI5RXOqDMnTgpCb0IWgvRgGLJdqYVZ/Hm6KgKto= +github.com/hanwen/go-fuse/v2 v2.0.2/go.mod h1:HH3ygZOoyRbP9y2q7y3+JM6hPL+Epe29IbWaS0UA81o= +github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= +github.com/hashicorp/golang-lru v0.5.1 h1:0hERBMJE1eitiLkihrMvRVBYAkpHzc/J3QdDN+dAcgU= +github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= +github.com/hashicorp/golang-lru v0.5.3/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4= +github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= +github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af h1:pmfjZENx5imkbgOkpRUYLnmbU7UEFbjtDA2hxJ1ichM= +github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k= +github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg= +github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= +github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8= +github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U= +github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= +github.com/json-iterator/go v1.1.9 h1:9yzud/Ht36ygwatGx56VwCZtlI/2AD15T1X2sjSuGns= +github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= +github.com/json-iterator/go v1.1.10 h1:Kz6Cvnvv2wGdaG/V8yMvfkmNiXq9Ya2KUv4rouJJr68= +github.com/json-iterator/go v1.1.10/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= 
+github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU= +github.com/jung-kurt/gofpdf v1.0.3-0.20190309125859-24315acbbda5/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes= +github.com/kardianos/osext v0.0.0-20170510131534-ae77be60afb1 h1:PJPDf8OUfOK1bb/NeTKd4f1QXZItOX389VN3B6qC8ro= +github.com/kardianos/osext v0.0.0-20170510131534-ae77be60afb1/go.mod h1:1NbS8ALrpOvjt0rHPNLyCIeMtbizbir8U//inJ+zuB8= +github.com/kardianos/osext v0.0.0-20190222173326-2bc1f35cddc0 h1:iQTw/8FWTuc7uiaSepXwyf3o52HaUYcV+Tu66S3F5GA= +github.com/kardianos/osext v0.0.0-20190222173326-2bc1f35cddc0/go.mod h1:1NbS8ALrpOvjt0rHPNLyCIeMtbizbir8U//inJ+zuB8= +github.com/keybase/go-keychain v0.0.0-20190828153431-2390ae572545 h1:9w9Fw+XzuDIZFMW5zxNehR97pMWu+1CpcEYK71PyAGg= +github.com/keybase/go-keychain v0.0.0-20190828153431-2390ae572545/go.mod h1:JJNrCn9otv/2QP4D7SMJBgaleKpOf66PnW6F5WGNRIc= +github.com/keybase/go-ps v0.0.0-20161005175911-668c8856d999/go.mod h1:hY+WOq6m2FpbvyrI93sMaypsttvaIL5nhVR92dTMUcQ= +github.com/kisielk/errcheck v1.2.0/go.mod h1:/BMXB+zMLi60iA8Vv6Ksmxu/1UDYcXs4uQLJ+jE2L00= +github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= +github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/klauspost/compress v1.8.1/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0guNDohfE1A= +github.com/klauspost/compress v1.8.6 h1:970MQcQdxX7hfgc/aqmB4a3grW0ivUVV6i1TLkP8CiE= +github.com/klauspost/compress v1.8.6/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0guNDohfE1A= +github.com/klauspost/cpuid v1.2.1 h1:vJi+O/nMdFt0vqm8NZBI6wzALWdA2X+egi0ogNyrC/w= +github.com/klauspost/cpuid v1.2.1/go.mod h1:Pj4uuM528wm8OyEC2QMXAi2YiTZ96dNQPGgoMS4s3ek= +github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pretty v0.2.0/go.mod 
h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/pty v1.1.5/go.mod h1:9r2w37qlBe7rQ6e1fg1S/9xpWHSnaqNdHD3WcMdbPDA= +github.com/kr/pty v1.1.8/go.mod h1:O1sed60cT9XZ5uDucP5qwvh+TE3NnUj51EiZO/lmSfw= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kylelemons/godebug v0.0.0-20170820004349-d65d576e9348 h1:MtvEpTB6LX3vkb4ax0b5D2DHbNAUsen0Gx5wZoq3lV4= +github.com/kylelemons/godebug v0.0.0-20170820004349-d65d576e9348/go.mod h1:B69LEHPfb2qLo0BaaOLcbitczOKLWTsrBG9LczfCD4k= +github.com/leodido/go-urn v1.2.0 h1:hpXL4XnriNwQ/ABnpepYM/1vCLWNDfUNts8dX3xTG6Y= +github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII= +github.com/mailru/easyjson v0.0.0-20190614124828-94de47d64c63/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= +github.com/mailru/easyjson v0.0.0-20190626092158-b2ccc519800e/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= +github.com/mattn/go-isatty v0.0.12 h1:wuysRhFDzyxgEmMf5xjvJ2M9dZoWAXNNr5LSBS7uHXY= +github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= +github.com/mattn/go-sqlite3 v1.11.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= +github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 h1:ZqeYNhU3OHLH3mGKHDcjJRFFRrJa6eAM5H+CtDdOsPc= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742 
h1:Esafd1046DLDQ0W1YjYsBW+p8U2u7vzgW2SQVmlNazg= +github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= +github.com/modern-go/reflect2 v1.0.1 h1:9f412s+6RmYXLWZSEzVVgPGK7C2PphHj5RJrvfx9AWI= +github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= +github.com/munnerz/goautoneg v0.0.0-20120707110453-a547fc61f48d/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= +github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f/go.mod h1:ZdcZmHo+o7JKHSa8/e818NopupXU1YMK5fe1lsApnBw= +github.com/onsi/ginkgo v0.0.0-20170829012221-11459a886d9c/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/ginkgo v1.11.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/gomega v0.0.0-20170829124025-dcabb60a477c/go.mod h1:C1qb7wdrVGGVU+Z6iS04AVkA3Q65CEZX59MT0QO5uiA= +github.com/onsi/gomega v1.7.0/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= +github.com/pborman/uuid v1.2.0 h1:J7Q5mO4ysT1dv8hyrUGHb9+ooztCXu1D8MY8DZYsu3g= +github.com/pborman/uuid v1.2.0/go.mod h1:X/NO0urCmaxf9VXbdlT7C2Yzkj2IKimNn4k+gtPdI/k= +github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35 h1:J9b7z+QKAmPf4YLrFg6oQUotqHQeUNWwkvo7jZp1GLU= +github.com/pquerna/cachecontrol v0.0.0-20180517163645-1555304b9b35/go.mod h1:prYjPmNq4d1NPVmpShWobRqXY3q7Vp+80DqgxxUrUIA= 
+github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= +github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= +github.com/rogpeppe/go-internal v1.3.2/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= +github.com/shirou/gopsutil v0.0.0-20180427012116-c95755e4bcd7/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA= +github.com/shirou/gopsutil v2.18.12+incompatible h1:1eaJvGomDnH74/5cF4CTmTbLHAriGFsTZppLXDX93OM= +github.com/shirou/gopsutil v2.18.12+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA= +github.com/shirou/gopsutil v2.19.9+incompatible h1:IrPVlK4nfwW10DF7pW+7YJKws9NkgNzWozwwWv9FsgY= +github.com/shirou/gopsutil v2.19.9+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA= +github.com/shirou/w32 v0.0.0-20160930032740-bb4de0191aa4/go.mod h1:qsXQc7+bwAM3Q1u/4XEfrquwF8Lw7D7y5cD8CuHnfIc= +github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72 h1:qLC7fQah7D6K1B0ujays3HV9gkFtllcxhzImRR7ArPQ= +github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA= +github.com/spf13/afero v1.2.2/go.mod h1:9ZxEEn6pIJ8Rxe320qSDBk6AsU0r9pR7Q4OcevTdifk= +github.com/spf13/pflag v0.0.0-20170130214245-9ff6c6923cff/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= +github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= +github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= +github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.1.1 h1:2vfRuCMp5sSVIDSqO8oNnWJq7mPa6KVP3iPIwFBuy8A= +github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.2.0 h1:Hbg2NidpLE8veEBkEZTL3CvlkUIVzuU9jDplZO54c48= +github.com/stretchr/objx 
v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/ugorji/go v1.1.7 h1:/68gy2h+1mWMrwZFeD1kQialdSzAb432dtpeJ42ovdo= +github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw= +github.com/ugorji/go/codec v1.1.7 h1:2SvQaVZ1ouYrrKKwoSk2pzd4A9evlKJb9oTL+OaLUSs= +github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLYF3GoBXY= +github.com/vanadium/go-mdns-sd v0.0.0-20181006014439-f1a1ccd1252e h1:pHSeCN6iUoIWXqaMgi9TeKuESVQY1zThuhVjAHq3GpI= +github.com/vanadium/go-mdns-sd v0.0.0-20181006014439-f1a1ccd1252e/go.mod h1:35fXDjvKtzyf89fHHhyTTNLHaG2CkI7u/GvO59PIjP4= +github.com/vitessio/vitess v2.1.1+incompatible/go.mod h1:A11WWLimUfZAYYm8P1I63RryRPP2GdpHRgQcfa++OnQ= +github.com/willf/bitset v1.1.10 h1:NotGKqX0KwQ72NUzqrjZq5ipPNDQex9lo3WpaS8L2sc= +github.com/willf/bitset v1.1.10/go.mod h1:RjeCKbqT1RxIR/KWY6phxZiaY1IyutSBfGjNPySAYV4= +github.com/xlab/treeprint v0.0.0-20180616005107-d6fb6747feb6/go.mod h1:ce1O1j6UtZfjr22oyGxGLbauSBp2YVXpARAosm7dHBg= +github.com/yasushi-saito/zlibng v0.0.0-20190905015749-ec536402779e/go.mod h1:qD8maXXiM82RPOfKUGWetL74si8WnsRS7LNPDWK7byI= +github.com/yasushi-saito/zlibng v0.0.0-20190922135643-2a860060b80c h1:PacAOojZgacR3Rs+QL8Vb8UiySXgKMU/HPqvaOotVqs= +github.com/yasushi-saito/zlibng v0.0.0-20190922135643-2a860060b80c/go.mod h1:fmRgeAuoXV70NcmjNe3PyhylzfGSgyLv9nZaW/I/C7Q= +github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +github.com/yuin/goldmark v1.2.1/go.mod 
h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= +go.opencensus.io v0.21.0 h1:mU6zScU4U1YAFPHEHYk+3JC4SY7JxgkqS10ZOSyksNg= +go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU= +go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8= +go.opencensus.io v0.22.1 h1:8dP3SGL7MPB94crU3bEPplMPe83FI4EouesJUeFHv50= +go.opencensus.io v0.22.1/go.mod h1:Ap50jQcDJrx6rB6VgeeFPtuPIf3wMRvRfrfYDO6+BmA= +go.uber.org/atomic v1.6.0 h1:Ezj3JGmsOnG1MoRWQkPBsKLe9DwWD9QeXzTRzzldNVk= +go.uber.org/atomic v1.6.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= +go.uber.org/multierr v1.5.0 h1:KCa4XfM8CWFCpxXRGok+Q0SS/0XBhMDbHHGABQLvD2A= +go.uber.org/multierr v1.5.0/go.mod h1:FeouvMocqHpRaaGuG9EjoKcStLC43Zu/fmqdUMPcKYU= +go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee h1:0mgffUl7nfd+FpvXMVz4IDEaUSmT1ysygQC7qYo7sG4= +go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee/go.mod h1:vJERXedbb3MVM5f9Ejo0C68/HhF8uaILCdgjnY+goOA= +go.uber.org/zap v1.16.0 h1:uFRZXykJGK9lLY4HtgSw44DnIcAM+kRBP7x5m+NpAOM= +go.uber.org/zap v1.16.0/go.mod h1:MA8QOfq0BHJwdXa996Y4dYkAqRKB8/1K1QMMZVaNZjQ= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20190611184440-5c40567a22f8/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20190829043050-9756ffdc2472/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550 h1:ObdrDkeb4kJdCP557AjRjq69pTHfNouLtWZG7j9rPN8= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod 
h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20200820211705-5c72a883971a h1:vclmkQCjlDX5OydZ9wv8rBCcS0QyQY66Mpf/7BZbInM= +golang.org/x/crypto v0.0.0-20200820211705-5c72a883971a/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= +golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= +golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= +golang.org/x/exp v0.0.0-20190125153040-c74c464bbbf2/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= +golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= +golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= +golang.org/x/exp v0.0.0-20190829153037-c13cbed26979/go.mod h1:86+5VVa7VpoJ4kLfm080zCjGlMRFzhUhsZKEZO7MGek= +golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81/go.mod h1:ux5Hcp/YLpHSI86hEcLt0YII63i6oz57MZXIpbrjZUs= +golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= +golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= +golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= +golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= +golang.org/x/lint v0.0.0-20190301231843-5614ed5bae6f/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= +golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/lint v0.0.0-20190409202823-959b441ac422/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/lint v0.0.0-20190909230951-414d861bb4ac/go.mod 
h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/lint v0.0.0-20190930215403-16217165b5de h1:5hukYrvBGR8/eNkX5mdUezrA6JiaEZDtJb9Ei+1LlBs= +golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/mobile v0.0.0-20190312151609-d3739f865fa6/go.mod h1:z+o9i4GpDbdi3rU15maQ/Ox0txvL9dWGYEHz965HBQE= +golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028/go.mod h1:E/iHnbuqvinMTCcRqshq8CkpyQDoeVncDDYHnLhea+o= +golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc= +golang.org/x/mod v0.1.0/go.mod h1:0QHyrYULN0/3qlju5TqG8bIK38QM8yzMo5ekMj3DlcY= +golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.3.0 h1:RM4zey1++hCTbCVQfnWeKs9/IEsaBLA8vTkd0WVtmH4= +golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190501004415-9ce7a6920f09/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190503192946-f4e77d36d62c/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= +golang.org/x/net 
v0.0.0-20190613194153-d28f0bde5980/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20190827160401-ba9fcec4b297/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200202094626-16171245cfb2 h1:CCH4IOTTfewWjGOlSp+zGcjutRKlBEZQ6wTn8ozI/nI= +golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200226121028-0de0cce0169b h1:0mm1VjtFUOIlE1SbDlwjYaDxZVDP2S5ou6y0gSgXHu8= +golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= +golang.org/x/net v0.0.0-20201021035429-f5854403a974 h1:IX6qOQeG5uLjB/hjjwjedwfjND0hgjPMMyO1RoIXQNI= +golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/net v0.0.0-20201110031124-69a78807bb2b h1:uwuIcX0g4Yl1NC5XAz37xsr2lTtcqevgzYNVt49waME= +golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= +golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= +golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45 h1:SVwTIAaPC2U/AvvLNZ2a7OVsmBpC8L5BlwK1whH3hm0= +golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= +golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod 
h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e h1:vcxGaoTs7kV8m5Np9uUNQin4BrLOthgV7252N8V+FwY= +golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9 h1:SQFwaSi55rU7vdNs9Yr0Z324VNlrF+0wMqRXT4St8ck= +golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20171017063910-8dbc5d05d6ed/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190502145724-3ef323f4f1fd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190507160741-ecd444e8653b/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190606165138-5da285871e9c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190616124812-15dcb6c0061f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190624142023-c5567b49c5d0/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190626221950-04f50cda93cb/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 
+golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200331124033-c3d80250170d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200923182605-d9f96fdee20d h1:L/IKR6COd7ubZrs2oTnTi73IhgqJ71c9s80WsQnh0Es= +golang.org/x/sys v0.0.0-20200923182605-d9f96fdee20d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f h1:+Nyd8tzPX9R7BWHguqsrbFdRx3WQ/1ib8I44HXV5yTA= +golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201112073958-5cba982894dd h1:5CtCZbICpIOFdgO940moixOPjc0178IU44m4EjOO5IY= +golang.org/x/sys v0.0.0-20201112073958-5cba982894dd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= +golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/text v0.3.3 h1:cokOdA+Jmi5PJGXLlLllQSgYigAEfHXJAERHVMaCc2k= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.4 h1:0YWbFKbhXG/wIiuHDSKpS0Iy7FSA+u45VtBMfQcFTTc= +golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/time v0.0.0-20190308202827-9d24e82272b4 h1:SvFZT6jyqRaOeXpc5h/JSfZenJ2O330aBsf7JfSUXmQ= +golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod 
h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20181030221726-6c7e314b6563/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190206041539-40960b6deb8e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= +golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190312151545-0bb0c0a6e846/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190312170243-e65039ee4138/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/tools v0.0.0-20190506145303-2d16b83fe98c/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/tools v0.0.0-20190606124116-d0a3d012864b/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= +golang.org/x/tools v0.0.0-20190614205625-5aca471b1d59/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= +golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= +golang.org/x/tools v0.0.0-20190628153133-6cdbf07be9d0/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= +golang.org/x/tools v0.0.0-20190816200558-6889da9d5479/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20190911174233-4f2ddba30aff/go.mod 
h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191029190741-b9c20aec41a5/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e h1:aZzprAO9/8oim3qStq3wc1Xuxx4QmAGriC4VU4ojemQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= +golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gonum.org/v1/gonum v0.0.0-20180816165407-929014505bf4/go.mod h1:Y+Yx5eoAFn32cQvJDxZx5Dpnq+c3wtXuadVZAcxbbBo= +gonum.org/v1/gonum v0.0.0-20190902003836-43865b531bee/go.mod h1:9mxDZsDKxgMAuccQkewq682L+0eCu4dCN2yonUJTCLU= +gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0/go.mod h1:wa6Ws7BG/ESfp6dHfk7C6KdzKA7wR7u/rKwOGE66zvw= +gonum.org/v1/plot v0.0.0-20190515093506-e2840ee46a6b/go.mod h1:Wt8AAjI+ypCyYX3nZBvf6cAIx93T+c/OS2HFAYskSZc= +google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE= +google.golang.org/api v0.7.0/go.mod h1:WtwebWUNSVBH/HAw79HIFXZNqEvBhG+Ra+ax0hx3E3M= +google.golang.org/api v0.8.0/go.mod h1:o4eAsZoiT+ibD93RtjEohWalFOjRDx6CVaqeizhEnKg= +google.golang.org/api v0.9.0/go.mod 
h1:o4eAsZoiT+ibD93RtjEohWalFOjRDx6CVaqeizhEnKg= +google.golang.org/api v0.10.0 h1:7tmAxx3oKE98VMZ+SBZzvYYWRQ9HODBxmC8mXUsraSQ= +google.golang.org/api v0.10.0/go.mod h1:o4eAsZoiT+ibD93RtjEohWalFOjRDx6CVaqeizhEnKg= +google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= +google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +google.golang.org/appengine v1.5.0 h1:KxkO13IPW4Lslp2bz+KHP2E3gtFlrIGNThxkZQ3g+4c= +google.golang.org/appengine v1.5.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= +google.golang.org/appengine v1.6.1/go.mod h1:i06prIuMbXzDqacNJfV5OdTW448YApPu5ww/cMBSeb0= +google.golang.org/appengine v1.6.5 h1:tycE03LOZYQNhDpS27tcQdAzLCVMaj7QT2SXxebnpCM= +google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= +google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= +google.golang.org/genproto v0.0.0-20190307195333-5fe7a883aa19/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= +google.golang.org/genproto v0.0.0-20190418145605-e7d98fc518a7/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= +google.golang.org/genproto v0.0.0-20190425155659-357c62f0e4bb/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= +google.golang.org/genproto v0.0.0-20190502173448-54afdca5d873 h1:nfPFGzJkUDX6uBmpN/pSw7MbOAWegH5QDQuoXFHedLg= +google.golang.org/genproto v0.0.0-20190502173448-54afdca5d873/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= +google.golang.org/genproto v0.0.0-20190801165951-fa694d86fc64/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= +google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= +google.golang.org/genproto v0.0.0-20190911173649-1774047e7e51 h1:Ex1mq5jaJof+kRnYi3SlYJ8KKa9Ao3NHyIT5XJ1gF6U= +google.golang.org/genproto v0.0.0-20190911173649-1774047e7e51/go.mod 
h1:IbNlFCBrqXvoKpeg0TB2l7cyZUmoaFKYIwrEpbDKLA8= +google.golang.org/genproto v0.0.0-20191007204434-a023cd5227bd h1:84VQPzup3IpKLxuIAZjHMhVjJ8fZ4/i3yUnj3k6fUdw= +google.golang.org/genproto v0.0.0-20191007204434-a023cd5227bd/go.mod h1:n3cpQtvxv34hfy77yVDNjmbRyujviMdxYliBSkLhpCc= +google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013 h1:+kGHl1aib/qcwaRi1CbqBZ1rk19r85MNUf8HaBghugY= +google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= +google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= +google.golang.org/grpc v1.20.1 h1:Hz2g2wirWK7H0qIIhGIqRGTuMwTE8HEKFnDZZ7lm9NU= +google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38= +google.golang.org/grpc v1.21.1 h1:j6XxA85m/6txkUCHvzlV5f+HBNl/1r5cZ2A/3IEFOO8= +google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM= +google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= +google.golang.org/grpc v1.24.0 h1:vb/1TCsVn3DcJlQ0Gs1yB1pKI6Do2/QNwxdKqmc/b0s= +google.golang.org/grpc v1.24.0/go.mod h1:XDChyiUovWa60DnaeDeZmSW86xtLtjtZbwvSiRnRtcA= +google.golang.org/grpc v1.27.0 h1:rRYRFMVgRv6E0D70Skyfsr28tDXIuuPZyWGMPdMcnXg= +google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= +google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= +google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= +google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= +google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= +google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= +google.golang.org/protobuf v1.22.0/go.mod 
h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.23.0 h1:4MY060fB1DLGMB/7MBTLnwQUY6+F09GEiz6SsrNqyzM= +google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.25.0 h1:Ejskq+SyPohKW+1uil0JJMtmHCgJPJ/qWTxr8qp+R4c= +google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= +gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= +gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc= +gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw= +gopkg.in/square/go-jose.v2 v2.3.1 h1:SK5KegNXmKmqE342YYN2qPHEnUYeoMiXXl1poUlI+o4= +gopkg.in/square/go-jose.v2 v2.3.1/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= +gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.4 h1:/eiJrUcujPVeJ3xlSWaiNi3uSVmDGBK1pDHUHAnao1I= +gopkg.in/yaml.v2 v2.2.4/go.mod 
h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10= +gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +honnef.co/go/tools v0.0.1-2019.2.3 h1:3JgtbtFHMiCmsznwGVTUWbgGov+pVqnlf1dEJTNAXeM= +honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= +k8s.io/api v0.0.0-20181213150558-05914d821849 h1:WZFcFPXmLR7g5CxQNmjWv0mg8qulJLxDghbzS4pQtzY= +k8s.io/api v0.0.0-20181213150558-05914d821849/go.mod h1:iuAfoD4hCxJ8Onx9kaTIt30j7jUFS00AXQi6QMi99vA= +k8s.io/api v0.20.2 h1:y/HR22XDZY3pniu9hIFDLpUCPq2w5eQ6aV/VFQ7uJMw= +k8s.io/api v0.20.2/go.mod h1:d7n6Ehyzx+S+cE3VhTGfVNNqtGc/oL9DCdYYahlurV8= +k8s.io/apimachinery v0.0.0-20181127025237-2b1284ed4c93 h1:tT6oQBi0qwLbbZSfDkdIsb23EwaLY85hoAV4SpXfdao= +k8s.io/apimachinery v0.0.0-20181127025237-2b1284ed4c93/go.mod h1:ccL7Eh7zubPUSh9A3USN90/OzHNSVN6zxzde07TDCL0= +k8s.io/apimachinery v0.20.2 h1:hFx6Sbt1oG0n6DZ+g4bFt5f6BoMkOjKWsQFu077M3Vg= +k8s.io/apimachinery v0.20.2/go.mod h1:WlLqWAHZGg07AeltaI0MV5uk1Omp8xaN0JGLY6gkRpU= +k8s.io/client-go v1.5.1 h1:XaX/lo2/u3/pmFau8HN+sB5C/b4dc4Dmm2eXjBH4p1E= +k8s.io/client-go v10.0.0+incompatible h1:F1IqCqw7oMBzDkqlcBymRq1450wD0eNqLE9jzUrIi34= +k8s.io/client-go v10.0.0+incompatible/go.mod h1:7vJpHMYJwNQCWgzmNV+VYUl1zCObLyodBc8nIyt8L5s= 
+k8s.io/client-go v11.0.0+incompatible h1:LBbX2+lOwY9flffWlJM7f1Ct8V2SRNiMRDFeiwnJo9o= +k8s.io/client-go v11.0.0+incompatible/go.mod h1:7vJpHMYJwNQCWgzmNV+VYUl1zCObLyodBc8nIyt8L5s= +k8s.io/gengo v0.0.0-20200413195148-3a45101e95ac/go.mod h1:ezvh/TsK7cY6rbqRK0oQQ8IAqLxYwwyPxAX1Pzy0ii0= +k8s.io/klog v1.0.0 h1:Pt+yjF5aB1xDSVbau4VsWe+dQNzA0qv1LlXdC2dF6Q8= +k8s.io/klog v1.0.0/go.mod h1:4Bi6QPql/J/LkTDqv7R/cd3hPo4k2DG6Ptcz060Ez5I= +k8s.io/klog/v2 v2.0.0/go.mod h1:PBfzABfn139FHAV07az/IF9Wp1bkk3vpT2XSJ76fSDE= +k8s.io/klog/v2 v2.4.0 h1:7+X0fUguPyrKEC4WjH8iGDg3laWgMo5tMnRTIGTTxGQ= +k8s.io/klog/v2 v2.4.0/go.mod h1:Od+F08eJP+W3HUb4pSrPpgp9DGU4GzlpG/TmITuYh/Y= +k8s.io/kube-openapi v0.0.0-20201113171705-d219536bb9fd/go.mod h1:WOJ3KddDSol4tAGcJo0Tvi+dK12EcqSLqcWsryKMpfM= +k8s.io/utils v0.0.0-20210111153108-fddb29f9d009 h1:0T5IaWHO3sJTEmCP6mUlBvMukxPKUQWqiI/YuiBNMiQ= +k8s.io/utils v0.0.0-20210111153108-fddb29f9d009/go.mod h1:jPW/WVKK9YHAvNhRxK0md/EJ228hCsBRufyofKtW8HA= +rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8= +rsc.io/goversion v1.0.0/go.mod h1:Eih9y/uIBS3ulggl7KNJ09xGSLcuNaLgmvvqa07sgfo= +rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= +sigs.k8s.io/structured-merge-diff/v4 v4.0.2 h1:YHQV7Dajm86OuqnIR6zAelnDWBRjo+YhYV9PmGrh1s8= +sigs.k8s.io/structured-merge-diff/v4 v4.0.2/go.mod h1:bJZC9H9iH24zzfZ/41RGcq60oK1F7G282QMXDPYydCw= +sigs.k8s.io/yaml v1.1.0/go.mod h1:UJmg0vDUVViEyp3mgSv9WPwZCDxu4rQW1olrI1uml+o= +sigs.k8s.io/yaml v1.2.0 h1:kr/MCeFWJWTwyaHoR9c8EjH9OumOmoF9YGiZd7lFm/Q= +sigs.k8s.io/yaml v1.2.0/go.mod h1:yfXDCHCao9+ENCvLSE62v9VSji2MKu5jeNfTrofGhJc= +v.io v0.1.5 h1:FQY1JUhIs1wts/NVSqAr2I9O/qPhJMmteXyLvv4lLOE= +v.io v0.1.5/go.mod h1:Apu/AQfn7lq+o3m+ReLtlrKxkZTTo2p6mLXlioAUWA0= +v.io v0.1.8 h1:9pMw27Epqq4CinvudXX5fOlERasYzZYqoxDYxwXda2U= +v.io v0.1.8/go.mod h1:63LjtWsxMaRKYc9sMM0rXCYxkhZ1/1aNJOS6He4qkPU= +v.io v0.1.15 h1:bMmuOBU8ErCSyY3SZXLGla1SMMO1wBr18GwZGCBoSHQ= +v.io v0.1.15/go.mod 
h1:Ort0a9YYK5eDJZ1Bc5m5hroWkygXzkifJzmlRpaGbWk= +v.io/x/lib v0.1.4 h1:PCDfluqBeRbA7OgDIs9tIpT+z6ZNZ5VMeR+t7h/K2ig= +v.io/x/lib v0.1.4/go.mod h1:maU79RWqiiC9ARbvS+2Q8tqZUnQiHxeJDriXcW7cYg8= +v.io/x/lib v0.1.5 h1:Nv82WPqT0W9vHdc2CnLEOQepmSCsvEKUQHFwf/kXc6s= +v.io/x/lib v0.1.5/go.mod h1:aLm+mPXyXf4Vd/n+1f4LcSQFFgqNhNzwQvHYfXoOLlE= +v.io/x/lib v0.1.6/go.mod h1:aLm+mPXyXf4Vd/n+1f4LcSQFFgqNhNzwQvHYfXoOLlE= +v.io/x/lib v0.1.7 h1:FXaiEHSrk6Jduc9JCNCzCVFX6ps/NbUaS45ppufC8Do= +v.io/x/lib v0.1.7/go.mod h1:aLm+mPXyXf4Vd/n+1f4LcSQFFgqNhNzwQvHYfXoOLlE= +v.io/x/ref/internal/logger v0.1.1 h1:FZwqC6myQ4xMTgz6jIgZwQlIpuy/PtwTHY95qQFEzso= +v.io/x/ref/internal/logger v0.1.1/go.mod h1:00nuJdZEVCzMOn9y474jZ+e6B9R/ydLW7d6IQFl/NHU= +v.io/x/ref/lib/flags/sitedefaults v0.1.1 h1:jAeEpnfOK5ddeqjqJPi2X5WgbeYkCA+ogawkJyDgZMI= +v.io/x/ref/lib/flags/sitedefaults v0.1.1/go.mod h1:ew4Igo60KMBDYhnxH6l7P+qBCJiqR8PVp7fJJYGqILA= diff --git a/grail/biofs/biofseventlog/biofseventlog.go b/grail/biofs/biofseventlog/biofseventlog.go new file mode 100644 index 00000000..74369d49 --- /dev/null +++ b/grail/biofs/biofseventlog/biofseventlog.go @@ -0,0 +1,66 @@ +// biofseventlog creates usage events for biofs, a GRAIL-internal program. biofs has to be internal +// because it runs fsnodefuse with some fsnode.T's derived from other internal code, but it also +// uses github.com/grailbio packages like s3file. +package biofseventlog + +import ( + "strconv" + "sync" + "time" + + "github.com/grailbio/base/config" + "github.com/grailbio/base/eventlog" + "github.com/grailbio/base/log" + "github.com/grailbio/base/must" +) + +const configName = "biofs/eventer" + +func init() { + config.Default(configName, "eventer/nop") +} + +// UsedFeature creates an event for usage of the named feature. 
+func UsedFeature(featureName string) { + var eventer eventlog.Eventer + must.Nil(config.Instance(configName, &eventer)) + eventer.Event("usedFeature", + "name", featureName, + "buildTime", getCoarseBuildTimestamp()) +} + +// CoarseNow returns times with precision truncated by CoarseTime. +func CoarseNow() time.Time { return CoarseTime(time.Now()) } + +// CoarseTime truncates t's precision to a nearby week. It's used to improve event log anonymity. +func CoarseTime(t time.Time) time.Time { + weekMillis := 7 * 24 * time.Hour.Milliseconds() + now := t.UnixMilli() + now /= weekMillis + now *= weekMillis + return time.UnixMilli(now) +} + +var ( + buildTimestamp string + coarseBuildTimestamp = "unknown" + buildTimestampOnce sync.Once +) + +// getCoarseBuildTimestamp returns the (probably bazel-injected) build timestamp +// with precision truncated to CoarseTime, or a message if data is unavailable. +func getCoarseBuildTimestamp() string { + buildTimestampOnce.Do(func() { + if buildTimestamp == "" { + return + } + buildSecs, err := strconv.ParseInt(buildTimestamp, 10, 64) + if err != nil { + log.Error.Printf("biofseventlog: error parsing build timestamp: %v", err) + return + } + coarseBuildTime := CoarseTime(time.Unix(buildSecs, 0)) + coarseBuildTimestamp = coarseBuildTime.Format("20060102") + }) + return coarseBuildTimestamp +} diff --git a/grail/biofs/biofseventlog/biofseventlog_test.go b/grail/biofs/biofseventlog/biofseventlog_test.go new file mode 100644 index 00000000..d1c98902 --- /dev/null +++ b/grail/biofs/biofseventlog/biofseventlog_test.go @@ -0,0 +1,31 @@ +package biofseventlog + +import ( + "math/rand" + "testing" + "time" +) + +func TestCoarseTime(t *testing.T) { + const ( + weeks = 10 + N = 10000 + ) + minMillis := coarseMillis(1600000000000) // Arbitrary time in 2020. 
+ maxMillis := minMillis + 7*24*weeks*time.Hour.Milliseconds() + + gotCoarseMillis := map[int64]struct{}{} + rnd := rand.New(rand.NewSource(1)) + for i := 0; i < N; i++ { + fineMillis := minMillis + rnd.Int63n(maxMillis-minMillis) + gotCoarseMillis[coarseMillis(fineMillis)] = struct{}{} + } + + if got := len(gotCoarseMillis); got != weeks { + t.Errorf("got %d, want %d", got, weeks) + } +} + +func coarseMillis(millis int64) int64 { + return CoarseTime(time.UnixMilli(millis)).UnixMilli() +} diff --git a/grail/data/v23data/blessings.go b/grail/data/v23data/blessings.go deleted file mode 100644 index a7c0dd33..00000000 --- a/grail/data/v23data/blessings.go +++ /dev/null @@ -1,81 +0,0 @@ -// Copyright 2018 GRAIL, Inc. All rights reserved. -// Use of this source code is governed by the Apache-2.0 -// license that can be found in the LICENSE file. - -package v23data - -import ( - "bytes" - "encoding/base64" - "fmt" - - "v.io/v23" - "v.io/v23/context" - "v.io/v23/security" - "v.io/v23/vom" - "v.io/x/lib/vlog" -) - -const ( - // Under the hood these roots are public certs so having them in here poses no - // security risks. - // - // Pipeline is using a different root. 
- // - // Public key Pattern - // 59:20:19:97:dd:49:5e:3f:35:39:fe:d3:c7:d1:42:95 [pipeline] - // - // How the data was obtained: - // - // principal -v23.credentials ~/.v23-pipeline-razvan get default | principal dumproots - - pipelineRoot = "gV0cAgAUdi5pby92MjMvdW5pcXVlaWQuSWQBAgIQ4VsyBgAYdi5pby92MjMvc2VjdXJpdHkuQ2F2ZWF0AQIAAklkAS_hAAhQYXJhbVZvbQEn4eFZBAMBLuFhHAAAFnYuaW8vdjIzL3NlY3VyaXR5Lkhhc2gBA-FfQgYAG3YuaW8vdjIzL3NlY3VyaXR5LlNpZ25hdHVyZQEEAAdQdXJwb3NlASfhAARIYXNoATHhAAFSASfhAAFTASfh4VdZBgAddi5pby92MjMvc2VjdXJpdHkuQ2VydGlmaWNhdGUBBAAJRXh0ZW5zaW9uAQPhAAlQdWJsaWNLZXkBJ-EAB0NhdmVhdHMBLeEACVNpZ25hdHVyZQEw4eFVBAMBLOFTBAMBK-FROwYAH3YuaW8vdjIzL3NlY3VyaXR5LldpcmVCbGVzc2luZ3MBAQARQ2VydGlmaWNhdGVDaGFpbnMBKuHhUv--AAEBAAhwaXBlbGluZQFbMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEOP0MkPPMlEH4OiRm98-QotTsWWYgN229sjCJHq_TB9aolgRBFmpL0QI6qSa5E4SvAtMXuAuZqd9ackMBtWK1SAMAAkIxAQZTSEEyNTYCIPJJxI8quSubxV7xHQhoj_vfejvwUmDEu81pJzOSGbyoAyBboUYxo543p_sJOMf5NJoPngWHeDvGIAjBlWrt6RruQeHh4Q==" - - // Pipeline staging is using yet another different root. 
- // - // Public key Pattern - // b9:e1:c3:ef:26:b7:8a:88:86:8b:b7:2f:e8:d8:1b:7c [staging-pipeline] - pipelineStagingRoot = "gV0cAgAUdi5pby92MjMvdW5pcXVlaWQuSWQBAgIQ4VsyBgAYdi5pby92MjMvc2VjdXJpdHkuQ2F2ZWF0AQIAAklkAS_hAAhQYXJhbVZvbQEn4eFZBAMBLuFhHAAAFnYuaW8vdjIzL3NlY3VyaXR5Lkhhc2gBA-FfQgYAG3YuaW8vdjIzL3NlY3VyaXR5LlNpZ25hdHVyZQEEAAdQdXJwb3NlASfhAARIYXNoATHhAAFSASfhAAFTASfh4VdZBgAddi5pby92MjMvc2VjdXJpdHkuQ2VydGlmaWNhdGUBBAAJRXh0ZW5zaW9uAQPhAAlQdWJsaWNLZXkBJ-EAB0NhdmVhdHMBLeEACVNpZ25hdHVyZQEw4eFVBAMBLOFTBAMBK-FROwYAH3YuaW8vdjIzL3NlY3VyaXR5LldpcmVCbGVzc2luZ3MBAQARQ2VydGlmaWNhdGVDaGFpbnMBKuHhUv_GAAEBABBzdGFnaW5nLXBpcGVsaW5lAVswWTATBgcqhkjOPQIBBggqhkjOPQMBBwNCAARyGfli1xVYDZsy2puv2_cwERx_1JnRQxJ8HXmz2juBG3N61-U1gX1OazINr_MRTO5jBxs6ZNIQ7PxrZOjJ1RCHAwACQjEBBlNIQTI1NgIgpHN4wzd-h17Vps9k91N2rwrcQaQTs2pd2LvDhDnzzX8DIA6LRkRbhp7pGr2JBgGwqsbNgh9cfxdjmoETLpfBmR-h4eHh" -) - -// InjectPipelineBlessings injects the non-v23.grail.com roots used by the -// pipeline and pipeline-staging. The ticket-server hand outs blessings to users -// and server using a v23.grail.com prefix. In order to allow the v23.grail.com -// blessings to talk to the pipeline and pipeline-staging we need to add these -// roots explicitly. Currently this is done by grail-access and grail-role, the -// two client tools that retrieve blessings from the ticket-server. 
-func InjectPipelineBlessings(ctx *context.T) error { - principal := v23.GetPrincipal(ctx) - - pipeline, err := decodeBlessings(pipelineRoot) - if err != nil { - vlog.Error(err) - return fmt.Errorf("failed to decode the pipeline root blessings: %v", err) - } - if err := security.AddToRoots(principal, pipeline); err != nil { - vlog.Error(err) - return fmt.Errorf("failed to add the pipeline root") - } - - pipelineStaging, err := decodeBlessings(pipelineStagingRoot) - if err != nil { - vlog.Error(err) - return fmt.Errorf("failed to decode the pipeline staging root blessings: %v", err) - } - if err := security.AddToRoots(principal, pipelineStaging); err != nil { - vlog.Error(err) - return fmt.Errorf("failed to add the pipeline staging root") - } - - return nil -} - -func decodeBlessings(s string) (security.Blessings, error) { - b, err := base64.URLEncoding.DecodeString(s) - if err != nil { - return security.Blessings{}, err - } - - dec := vom.NewDecoder(bytes.NewBuffer(b)) - var blessings security.Blessings - return blessings, dec.Decode(&blessings) -} diff --git a/grail/init.go b/grail/init.go index 2f950b0f..015a13cd 100644 --- a/grail/init.go +++ b/grail/init.go @@ -9,11 +9,17 @@ package grail import ( "flag" "os" + "strings" "sync" "github.com/google/gops/agent" + "github.com/grailbio/base/config" "github.com/grailbio/base/log" "github.com/grailbio/base/pprof" + "github.com/grailbio/base/shutdown" + + // GRAIL applications require the AWS ticket provider. + _ "github.com/grailbio/base/config/awsticket" "v.io/x/lib/vlog" ) @@ -22,16 +28,19 @@ import ( type Shutdown func() var ( - initialized = false - mu = sync.Mutex{} - shutdownHandlers = []Shutdown{} - gopsFlag = flag.Bool("gops", false, "enable the gops listener") + initialized = false + mu = sync.Mutex{} + gopsFlag = flag.Bool("gops", false, "enable the gops listener") ) // Init should be called once at the beginning at each executable that doesn't // use the github.com/grailbio/base/cmdutil. 
The Shutdown function should be called to // perform the final cleanup (closing logs for example). // +// Init also applies a default configuration profile (see package +// github.com/grailbio/base/config), and adds profile flags to the +// default flag set. The default profile path used is $HOME/grail/profile. +// // Note that this function will call flag.Parse(). // // Suggested use: @@ -45,15 +54,27 @@ func Init() Shutdown { } initialized = true mu.Unlock() - flag.CommandLine.Init(os.Args[0], flag.ContinueOnError) - err := flag.CommandLine.Parse(os.Args[1:]) - if err == flag.ErrHelp { - os.Exit(0) - } else if err != nil { - os.Exit(2) + + profile := config.New() + config.NewDefault = func() *config.Profile { + if err := profile.Parse(strings.NewReader(defaultProfile)); err != nil { + panic("grail: error in default profile: " + err.Error()) + } + if err := profile.ProcessFlags(); err != nil { + log.Fatal(err) + } + return profile + } + profile.RegisterFlags(flag.CommandLine, "", os.ExpandEnv("$HOME/grail/profile")) + flag.Parse() + if err := vlog.ConfigureLibraryLoggerFromFlags(); err != nil { + vlog.Error(err) } - vlog.ConfigureLibraryLoggerFromFlags() - log.SetOutputter(vlogOutputter{}) + log.SetOutputter(VlogOutputter{}) + if profile.NeedProcessFlags() { + _ = config.Application() + } + pprof.Start() _, ok := os.LookupEnv("GOPS") if ok || *gopsFlag { @@ -62,51 +83,8 @@ func Init() Shutdown { } } return func() { - RunShutdownCallbacks() + shutdown.Run() pprof.Write(1) vlog.FlushLog() } } - -// RegisterShutdownCallback registers a function to be run in the Init shutdown -// callback. The callbacks will run in the reverse order of registration. -func RegisterShutdownCallback(cb Shutdown) { - mu.Lock() - shutdownHandlers = append(shutdownHandlers, cb) - mu.Unlock() -} - -// RunShutdownCallbacks run callbacks added in RegisterShutdownCallbacks. This -// function is not for general use. 
-func RunShutdownCallbacks() { - mu.Lock() - cbs := shutdownHandlers - shutdownHandlers = nil - mu.Unlock() - for i := len(cbs) - 1; i >= 0; i-- { - cbs[i]() - } -} - -type vlogOutputter struct{} - -func (vlogOutputter) Level() log.Level { - if vlog.V(1) { - return log.Debug - } else { - return log.Info - } -} - -func (vlogOutputter) Output(calldepth int, level log.Level, s string) error { - switch level { - case log.Off: - case log.Error: - vlog.ErrorDepth(calldepth+1, s) - case log.Info: - vlog.InfoDepth(calldepth+1, s) - default: - vlog.VI(vlog.Level(level)).InfoDepth(calldepth+1, s) - } - return nil -} diff --git a/grail/log.go b/grail/log.go new file mode 100644 index 00000000..d61c2730 --- /dev/null +++ b/grail/log.go @@ -0,0 +1,34 @@ +package grail + +import ( + "github.com/grailbio/base/log" + "v.io/x/lib/vlog" +) + +// VlogOutputter implements base/log.Outputter backed by vlog. +type VlogOutputter struct{} + +func (VlogOutputter) Level() log.Level { + if vlog.V(1) { + return log.Debug + } else { + return log.Info + } +} + +func (VlogOutputter) Output(calldepth int, level log.Level, s string) error { + // Notice that we do not add 1 to the call depth. In vlog, 0 depth means + // that the caller's file/line will be used. This is different from the log + // and github.com/grailbio/base/log packages, where that's the behavior you + // get with depth 1. + switch level { + case log.Off: + case log.Error: + vlog.ErrorDepth(calldepth, s) + case log.Info: + vlog.InfoDepth(calldepth, s) + default: + vlog.VI(vlog.Level(level)).InfoDepth(calldepth, s) + } + return nil +} diff --git a/grail/profile.go b/grail/profile.go new file mode 100644 index 00000000..f445c33c --- /dev/null +++ b/grail/profile.go @@ -0,0 +1,29 @@ +// Copyright 2019 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. + +package grail + +// defaultProfile contains default configuration for use within GRAIL. 
+// TODO(marius): replace this with an account-based profile +const defaultProfile = ` +// Use the ticket provider for AWS credentials by default. +// Our default region is us-west-2. +param aws/ticket ( + region = "us-west-2" + path = "tickets/eng/dev/aws" +) + +instance aws aws/ticket + +// Bigmachine defaults for GRAIL (eng/dev). +// This should eventually be replaced by profile auto loading. +param bigmachine/ec2system ( + aws = aws + instance-profile = "arn:aws:iam::619867110810:instance-profile/bigmachine" + security-group = "sg-7390e50c" +) + +param bigmachine/ec2tensorflow base = bigmachine/ec2system + +` diff --git a/gtl/README.md b/gtl/README.md index feb33016..5171e117 100644 --- a/gtl/README.md +++ b/gtl/README.md @@ -4,6 +4,9 @@ This directory contains algorithms written using a pidgin templates. # Directory contents +- rcu_map: concurrent hash map. Readers can access the map without memory + barriers. + - unsafe: unsafe, but efficient slice operations, including casting between string and []byte and uninitialized slice resizing. diff --git a/gtl/generate.py b/gtl/generate.py index cd7460ce..8977b227 100755 --- a/gtl/generate.py +++ b/gtl/generate.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python3.6 +#!/usr/bin/env python3 """ Simple template engine. @@ -9,45 +9,70 @@ Example: - generate.py --prefix= -DELEM=int32 --package=tests --output=unsafe.go ../unsafe.go.tpl + generate.py --prefix=int --PREFIX=Int -DELEM=int32 --package=tests --output=unsafe.go ../unsafe.go.tpl ---prefix=ARG replaces all occurrences of "ZZ" with "ARG". +--prefix=arg replaces all occurrences of "zz" with "arg". + +--PREFIX=Arg replaces all occurrences of "ZZ" with "Arg". If --Prefix is omitted, + it defaults to --prefix, with its first letter uppercased. --Dfrom=to replaces all occurrences of "from" with "to". This flag can be set multiple times. --output=path specifies the output file name. +--goimports=path locates goimports binary. If empty, search on PATH instead. 
+ +--header=str Comment that'll be the first line of the generated file. If empty, use a default. + """ import re import argparse +import subprocess import sys + def main() -> None: "Main application entry point" parser = argparse.ArgumentParser() parser.add_argument( - '--package', default='funkymonkeypackage', - help="Occurrences of 'PACKAGE' in the template are replaced with this string.") + "--package", + default="funkymonkeypackage", + help="Occurrences of 'PACKAGE' in the template are replaced with this string.", + ) + parser.add_argument( + "--prefix", + default="funkymonkey", + help="Occurrences of 'zz' in the template are replaced with this string", + ) + parser.add_argument( + "--PREFIX", + default="", + help="Occurrences of 'ZZ' in the template are replaced with this string", + ) parser.add_argument( - '--prefix', default='funkymonkey', - help="Occurrences of 'ZZ' in the template are replaced with this string") + "-o", + "--output", + default="", + help="Output destination. Defaults to standard output", + ) parser.add_argument( - '-o', '--output', default='', - help="Output destination. Defaults to standard output") + "-D", "--define", default=[], type=str, action="append", help="str=replacement" + ) parser.add_argument( - '-D', '--define', default=[], - type=str, action='append', - help="str=replacement") + "--goimports", default="", help="Path to goimports. Defaults to search PATH", + ) parser.add_argument( - 'template', help="*.go.tpl file to process") + "--header", + default="", + help="First-line comment of the generated file. 
If empty, generate one.", + ) + parser.add_argument("template", help="*.go.tpl file to process") opts = parser.parse_args() - - if opts.output == '': - out = sys.stdout - else: - out = open(opts.output, 'w') + if not opts.PREFIX: + if opts.prefix: + opts.PREFIX = opts.prefix[0].upper() + opts.prefix[1:] defines = [] for d in opts.define: @@ -56,12 +81,35 @@ def main() -> None: raise Exception("Invalid -D option: " + d) defines.append((m[1], m[2])) - print('// Code generated from \"', ' '.join(sys.argv), '\". DO NOT EDIT.', file=out) - for line in open(opts.template, 'r').readlines(): - line = line.replace('ZZ', opts.prefix) - line = line.replace('PACKAGE', opts.package) + out = sys.stdout + if opts.output != "": + out = open(opts.output, "w") + proc = subprocess.Popen( + [opts.goimports or "goimports"], + stdin=subprocess.PIPE, + stdout=out, + universal_newlines=True, + ) + + if opts.header: + header = opts.header + else: + header = '// Code generated by "' + " ".join(sys.argv) + '". DO NOT EDIT.\n' + print(header, file=proc.stdin) + + for line in open(opts.template, "r").readlines(): + line = line.replace("ZZ", opts.PREFIX) + line = line.replace("zz", opts.prefix) + line = line.replace("PACKAGE", opts.package) for def_from, def_to in defines: line = line.replace(def_from, def_to) - print(line, end='', file=out) + print(line, end="", file=proc.stdin) + proc.stdin.close() + status = proc.wait() + if status != 0: + raise Exception("goimports failed: {}".format(status)) + if out != sys.stdout: + out.close() + main() diff --git a/gtl/generate_randomized_freepool.py b/gtl/generate_randomized_freepool.py new file mode 100755 index 00000000..485b0f3a --- /dev/null +++ b/gtl/generate_randomized_freepool.py @@ -0,0 +1,99 @@ +#!/usr/bin/env python3 + +import argparse +import logging +import os +import re +import subprocess +import sys +from typing import List + + +def main() -> None: + logging.basicConfig(level=logging.DEBUG) + output = None + package = None + generate_argv: 
List[str] = [] + for arg in sys.argv[1:]: + m = re.match("^--package=(.*)", arg) + if m: + package = m[1] + m = re.match("^--output=(.*)", arg) + if m: + output = m[1] + else: + generate_argv.append(arg) + + if not output or not package: + raise Exception("--output and --package not set") + + gtl_dir = os.path.dirname(sys.argv[0]) + output_dir = os.path.dirname(output) + if not output_dir: + output_dir = "." + + cmdline = ( + [sys.executable, os.path.join(gtl_dir, "generate.py"), f"--output={output}.go"] + + generate_argv + + [os.path.join(gtl_dir, "randomized_freepool.go.tpl")] + ) + logging.debug("CMDLINE %s", cmdline) + subprocess.check_call( + [sys.executable, os.path.join(gtl_dir, "generate.py"), f"--output={output}.go"] + + generate_argv + + [os.path.join(gtl_dir, "randomized_freepool.go.tpl")] + ) + subprocess.check_call( + [ + sys.executable, + os.path.join(gtl_dir, "generate.py"), + f"--output={output}_race.go", + ] + + generate_argv + + [os.path.join(gtl_dir, "randomized_freepool_race.go.tpl")] + ) + + with open(f"{output_dir}/randomized_freepool_internal.s", "w") as fd: + fd.write( + """// Code generated by generate_randomized_freepool.py. DO NOT EDIT. + +// Dummy file to force the go compiler to honor go:linkname directives. See +// +// https://github.com/golang/go/issues/15006 +// http://www.alangpierce.com/blog/2016/03/17/adventures-in-go-accessing-unexported-functions/ +""" + ) + + with open(f"{output_dir}/randomized_freepool_internal.go", "w") as fd: + fd.write( + f"""// Code generated by generate_randomized_freepool.py. DO NOT EDIT. +package {package} + +// This import is needed to use go:linkname. +import _ "unsafe" + +// The following functions are defined in go runtime. To use them, we need to +// import "unsafe", and elsewhere in this package, import "C" to force compiler +// to recognize the "go:linktime" directive. Some of the details are explained +// in the below blog post. 
+// +// procPin() pins the caller to the current processor, and returns the processor +// id in range [0,GOMAXPROCS). procUnpin() undos the effect of procPin(). +// +// http://www.alangpierce.com/blog/2016/03/17/adventures-in-go-accessing-unexported-functions/ + +//go:linkname runtime_procPin sync.runtime_procPin +//go:nosplit +func runtime_procPin() int + +//go:linkname runtime_procUnpin sync.runtime_procUnpin +//go:nosplit +func runtime_procUnpin() + +//go:linkname fastrandn sync.fastrandn +func fastrandn(n uint32) uint32 +""" + ) + + +main() diff --git a/gtl/randomized_freepool.go.tpl b/gtl/randomized_freepool.go.tpl index 293c21b6..52149662 100644 --- a/gtl/randomized_freepool.go.tpl +++ b/gtl/randomized_freepool.go.tpl @@ -1,3 +1,5 @@ +// +build !race + // ZZFreePool is thread-safe pool that uses power-of-two loadbalancing across // CPUs. @@ -28,8 +30,8 @@ // //go:nosplit // func runtime_procUnpin() // -// //go:linkname fastrand sync.fastrand -// func fastrand() uint32 +// //go:linkname fastrandn sync.fastrandn +// func fastrandn(n uint32) uint32 // // 2. An empty .s file. @@ -54,17 +56,17 @@ import ( // if needed. type ZZFreePool struct { new func() ELEM - local []ZZpoolLocal + local []zzPoolLocal maxLocalSize int64 } const ( - ZZmaxPrivateElems = 4 - ZZcacheLineSize = 64 + zzMaxPrivateElems = 4 + zzCacheLineSize = 64 ) -type ZZpoolLocalInner struct { - private [ZZmaxPrivateElems]ELEM // Can be used only by the respective P. +type zzPoolLocalInner struct { + private [zzMaxPrivateElems]ELEM // Can be used only by the respective P. privateSize int shared []ELEM // Can be used by any P. @@ -72,10 +74,10 @@ type ZZpoolLocalInner struct { mu sync.Mutex // Protects shared. } -type ZZpoolLocal struct { - ZZpoolLocalInner +type zzPoolLocal struct { + zzPoolLocalInner // Pad prevents false sharing. 
- pad [ZZcacheLineSize - unsafe.Sizeof(ZZpoolLocalInner{})%ZZcacheLineSize]byte + pad [zzCacheLineSize - unsafe.Sizeof(zzPoolLocalInner{})%zzCacheLineSize]byte } // NewZZFreePool creates a new free object pool. new should create a new @@ -93,13 +95,13 @@ func NewZZFreePool(new func() ELEM, maxSize int) *ZZFreePool { } p := &ZZFreePool{ new: new, - local: make([]ZZpoolLocal, maxProcs), + local: make([]zzPoolLocal, maxProcs), maxLocalSize: int64(maxLocalSize), } return p } -func (p *ZZFreePool) pin() *ZZpoolLocal { +func (p *ZZFreePool) pin() *zzPoolLocal { pid := runtime_procPin() if int(pid) >= len(p.local) { panic(pid) @@ -112,7 +114,7 @@ func (p *ZZFreePool) pin() *ZZpoolLocal { func (p *ZZFreePool) Put(x ELEM) { done := false l := p.pin() - if l.privateSize < ZZmaxPrivateElems { + if l.privateSize < zzMaxPrivateElems { l.private[l.privateSize] = x l.privateSize++ done = true @@ -124,7 +126,7 @@ func (p *ZZFreePool) Put(x ELEM) { // queues to log(log(#queues)) . // // https://www.eecs.harvard.edu/~michaelm/postscripts/mythesis.pdf - l2 := &p.local[int(fastrand())%len(p.local)] + l2 := &p.local[int(fastrandn(uint32(len(p.local))))] lSize := atomic.LoadInt64(&l.sharedSize) l2Size := atomic.LoadInt64(&l2.sharedSize) if l2Size < lSize { diff --git a/gtl/randomized_freepool_race.go.tpl b/gtl/randomized_freepool_race.go.tpl new file mode 100644 index 00000000..dc0c5fa9 --- /dev/null +++ b/gtl/randomized_freepool_race.go.tpl @@ -0,0 +1,25 @@ +// +build race + +package PACKAGE + +import "sync/atomic" + +type ZZFreePool struct { + new func() ELEM + len int64 +} + +func NewZZFreePool(new func() ELEM, maxSize int) *ZZFreePool { + return &ZZFreePool{new: new} +} + +func (p *ZZFreePool) Put(x ELEM) { + atomic.AddInt64(&p.len, -1) +} + +func (p *ZZFreePool) Get() ELEM { + atomic.AddInt64(&p.len, 1) + return p.new() +} + +func (p *ZZFreePool) ApproxLen() int { return int(atomic.LoadInt64(&p.len)) } diff --git a/gtl/rcu_map.go.tpl b/gtl/rcu_map.go.tpl new file mode 100644 index 
00000000..9312fdf5 --- /dev/null +++ b/gtl/rcu_map.go.tpl @@ -0,0 +1,158 @@ +package PACKAGE + +import ( + "sync/atomic" + "unsafe" +) + +// ZZMap is a concurrent map. A reader can access the map without lock, +// regardless of background updates. The writer side must coordinate using an +// external mutex if there are multiple writers. This map is linearizable. +// +// Example: +// +// m := NewZZMap(10) +// go func() { // writer +// m.Store("foo", "bar") +// }() +// go func() { // reader +// val, ok := m.Load("foo") +// } +type ZZMap struct { + p unsafe.Pointer // *zzMapState +} + +// ZZMapState represents a fixed-size chained hash table. It can store up to +// maxCapacity key/value pairs. Beyond that, the caller must create a new +// ZZMapState with a larger capacity. +type zzMapState struct { + log2Len uint // ==log2(len(table)) + mask uint64 // == ^(log2Len-1) + table []unsafe.Pointer // *zzMapNode + n int // # of objects currently stored in the table + maxCapacity int // max # of object that can be stored +} + +// ZZMapNode represents a hash bucket. +type zzMapNode struct { + key KEY + value VALUE + + // next points to the next element in the same hash bucket + next unsafe.Pointer // *zzMapNode +} + +func newZZMapState(log2Len uint) *zzMapState { + len := int(1 << log2Len) + table := &zzMapState{ + log2Len: log2Len, + mask: uint64(log2Len - 1), + table: make([]unsafe.Pointer, 1< 31 { + // TODO(saito) We could make the table to grow larger than 32 bits, but + // doing so will break 32bit builds. + panic(initialLenHint) + } + log2Len++ + } + m := ZZMap{p: unsafe.Pointer(newZZMapState(log2Len))} + return &m +} + +// Load finds a value with the given key. Returns false if not found. 
+func (m *ZZMap) Load(key KEY) (VALUE, bool) { + hash := HASH(key) + table := (*zzMapState)(atomic.LoadPointer(&m.p)) + b := int(hash & table.mask) + node := (*zzMapNode)(atomic.LoadPointer(&table.table[b])) + for node != nil { + if node.key == key { + return node.value, true + } + node = (*zzMapNode)(atomic.LoadPointer(&node.next)) + } + var dummy VALUE + return dummy, false +} + +// store returns false iff the table needs resizing. +func (t *zzMapState) store(key KEY, value VALUE) bool { + var ( + hash = HASH(key) + b = int(hash & t.mask) + node = (*zzMapNode)(t.table[b]) + probeLen = 0 + prevNode *zzMapNode + ) + for node != nil { + if node.key == key { + newNode := *node + newNode.value = value + if prevNode == nil { + atomic.StorePointer(&t.table[b], unsafe.Pointer(&newNode)) + } else { + atomic.StorePointer(&prevNode.next, unsafe.Pointer(&newNode)) + } + return true + } + prevNode = node + node = (*zzMapNode)(node.next) + probeLen++ + if probeLen >= 4 && t.n >= t.maxCapacity { + return false + } + } + newNode := zzMapNode{key: key, value: value} + if prevNode == nil { + atomic.StorePointer(&t.table[b], unsafe.Pointer(&newNode)) + } else { + atomic.StorePointer(&prevNode.next, unsafe.Pointer(&newNode)) + } + t.n++ + return true +} + +// Store stores the value for the given key. If the key is already in the map, +// it updates the mapping to the given value. +// +// Caution: if Store() is going to be called concurrently, it must be serialized +// externally. +func (m *ZZMap) Store(key KEY, value VALUE) { + table := (*zzMapState)(atomic.LoadPointer(&m.p)) + if table.store(key, value) { + return + } + log2Len := table.log2Len + 1 + if log2Len > 31 { + panic(log2Len) + } + newTable := newZZMapState(log2Len) + // Copy the contents of the old table over to the new table. 
+ for _, p := range table.table { + node := (*zzMapNode)(p) + for node != nil { + if !newTable.store(node.key, node.value) { + panic(node) + } + node = (*zzMapNode)(node.next) + } + } + if !newTable.store(key, value) { + panic(key) + } + atomic.StorePointer(&m.p, unsafe.Pointer(newTable)) +} diff --git a/gtl/tests/dummy.s b/gtl/tests/dummy.s deleted file mode 100644 index 7daaa644..00000000 --- a/gtl/tests/dummy.s +++ /dev/null @@ -1,4 +0,0 @@ - // Dummy file to force the go compiler to honor go:linkname directives. See - // - // https://github.com/golang/go/issues/15006 - // http://www.alangpierce.com/blog/2016/03/17/adventures-in-go-accessing-unexported-functions/ diff --git a/gtl/tests/freepool.go b/gtl/tests/freepool.go index aa4b709e..8f020b7f 100644 --- a/gtl/tests/freepool.go +++ b/gtl/tests/freepool.go @@ -1,4 +1,5 @@ -// Code generated from " ../generate.py --prefix=byte -DMAXSIZE=128 -DELEM=[]byte --package=tests --output=freepool.go ../freepool.go.tpl ". DO NOT EDIT. +// Code generated by "../generate.py --prefix=byte --PREFIX=byte -DMAXSIZE=128 -DELEM=[]byte --package=tests --output=freepool.go ../freepool.go.tpl". DO NOT EDIT. + package tests // A freepool for a single thread. 
The interface is the same as sync.Pool, but diff --git a/gtl/tests/freepool_test.go b/gtl/tests/freepool_test.go index 16c5de7d..ad880250 100644 --- a/gtl/tests/freepool_test.go +++ b/gtl/tests/freepool_test.go @@ -1,6 +1,6 @@ package tests -//go:generate ../generate.py --prefix=byte -DMAXSIZE=128 -DELEM=[]byte --package=tests --output=freepool.go ../freepool.go.tpl +//go:generate ../generate.py --prefix=byte --PREFIX=byte -DMAXSIZE=128 -DELEM=[]byte --package=tests --output=freepool.go ../freepool.go.tpl import ( "testing" diff --git a/gtl/tests/int_freepool.go b/gtl/tests/int_freepool.go index 27f61d1d..f64395ee 100644 --- a/gtl/tests/int_freepool.go +++ b/gtl/tests/int_freepool.go @@ -1,4 +1,8 @@ -// Code generated from " ../generate.py --prefix=Ints -DELEM=[]int --package=tests --output=int_freepool.go ../randomized_freepool.go.tpl ". DO NOT EDIT. +// Code generated by "../generate.py --output=int_freepool.go --prefix=ints -DELEM=[]int --package=tests ../randomized_freepool.go.tpl". DO NOT EDIT. + +//go:build !race +// +build !race + // IntsFreePool is thread-safe pool that uses power-of-two loadbalancing across // CPUs. @@ -29,8 +33,8 @@ // //go:nosplit // func runtime_procUnpin() // -// //go:linkname fastrand sync.fastrand -// func fastrand() uint32 +// //go:linkname fastrandn sync.fastrandn +// func fastrandn(n uint32) uint32 // // 2. An empty .s file. @@ -45,27 +49,27 @@ import ( // IntsFreePool is a variation of sync.Pool, specialized for a concrete type. // -// - Put() performs power-of-two loadbalancing, and Get() looks only at the -// local queue. This improves the performance of Get() on many-core machines, -// at the cost of slightly more allocations. +// - Put() performs power-of-two loadbalancing, and Get() looks only at the +// local queue. This improves the performance of Get() on many-core machines, +// at the cost of slightly more allocations. // // - It assumes that GOMAXPROCS is fixed at boot. 
// -// - It never frees objects accumulated in the pool. We could add this feature -// if needed. +// - It never frees objects accumulated in the pool. We could add this feature +// if needed. type IntsFreePool struct { new func() []int - local []IntspoolLocal + local []intsPoolLocal maxLocalSize int64 } const ( - IntsmaxPrivateElems = 4 - IntscacheLineSize = 64 + intsMaxPrivateElems = 4 + intsCacheLineSize = 64 ) -type IntspoolLocalInner struct { - private [IntsmaxPrivateElems][]int // Can be used only by the respective P. +type intsPoolLocalInner struct { + private [intsMaxPrivateElems][]int // Can be used only by the respective P. privateSize int shared [][]int // Can be used by any P. @@ -73,10 +77,10 @@ type IntspoolLocalInner struct { mu sync.Mutex // Protects shared. } -type IntspoolLocal struct { - IntspoolLocalInner - // Prevents false sharing. - pad [IntscacheLineSize - unsafe.Sizeof(IntspoolLocalInner{})%IntscacheLineSize]byte +type intsPoolLocal struct { + intsPoolLocalInner + // Pad prevents false sharing. + pad [intsCacheLineSize - unsafe.Sizeof(intsPoolLocalInner{})%intsCacheLineSize]byte } // NewIntsFreePool creates a new free object pool. new should create a new @@ -94,13 +98,13 @@ func NewIntsFreePool(new func() []int, maxSize int) *IntsFreePool { } p := &IntsFreePool{ new: new, - local: make([]IntspoolLocal, maxProcs), + local: make([]intsPoolLocal, maxProcs), maxLocalSize: int64(maxLocalSize), } return p } -func (p *IntsFreePool) pin() *IntspoolLocal { +func (p *IntsFreePool) pin() *intsPoolLocal { pid := runtime_procPin() if int(pid) >= len(p.local) { panic(pid) @@ -113,7 +117,7 @@ func (p *IntsFreePool) pin() *IntspoolLocal { func (p *IntsFreePool) Put(x []int) { done := false l := p.pin() - if l.privateSize < IntsmaxPrivateElems { + if l.privateSize < intsMaxPrivateElems { l.private[l.privateSize] = x l.privateSize++ done = true @@ -125,7 +129,7 @@ func (p *IntsFreePool) Put(x []int) { // queues to log(log(#queues)) . 
// // https://www.eecs.harvard.edu/~michaelm/postscripts/mythesis.pdf - l2 := &p.local[int(fastrand())%len(p.local)] + l2 := &p.local[int(fastrandn(uint32(len(p.local))))] lSize := atomic.LoadInt64(&l.sharedSize) l2Size := atomic.LoadInt64(&l2.sharedSize) if l2Size < lSize { diff --git a/gtl/tests/int_freepool_race.go b/gtl/tests/int_freepool_race.go new file mode 100644 index 00000000..ff6f61a6 --- /dev/null +++ b/gtl/tests/int_freepool_race.go @@ -0,0 +1,28 @@ +// Code generated by "../generate.py --output=int_freepool_race.go --prefix=ints -DELEM=[]int --package=tests ../randomized_freepool_race.go.tpl". DO NOT EDIT. + +//go:build race +// +build race + +package tests + +import "sync/atomic" + +type IntsFreePool struct { + new func() []int + len int64 +} + +func NewIntsFreePool(new func() []int, maxSize int) *IntsFreePool { + return &IntsFreePool{new: new} +} + +func (p *IntsFreePool) Put(x []int) { + atomic.AddInt64(&p.len, -1) +} + +func (p *IntsFreePool) Get() []int { + atomic.AddInt64(&p.len, 1) + return p.new() +} + +func (p *IntsFreePool) ApproxLen() int { return int(atomic.LoadInt64(&p.len)) } diff --git a/gtl/tests/freepool_internal.go b/gtl/tests/randomized_freepool_internal.go similarity index 85% rename from gtl/tests/freepool_internal.go rename to gtl/tests/randomized_freepool_internal.go index e1997796..bea61f6b 100644 --- a/gtl/tests/freepool_internal.go +++ b/gtl/tests/randomized_freepool_internal.go @@ -1,3 +1,4 @@ +// Code generated by generate_randomized_freepool.py. DO NOT EDIT. package tests // This import is needed to use go:linkname. 
@@ -21,5 +22,5 @@ func runtime_procPin() int //go:nosplit func runtime_procUnpin() -//go:linkname fastrand sync.fastrand -func fastrand() uint32 +//go:linkname fastrandn sync.fastrandn +func fastrandn(n uint32) uint32 diff --git a/gtl/tests/randomized_freepool_internal.s b/gtl/tests/randomized_freepool_internal.s new file mode 100644 index 00000000..0c518de0 --- /dev/null +++ b/gtl/tests/randomized_freepool_internal.s @@ -0,0 +1,6 @@ +// Code generated by generate_randomized_freepool.py. DO NOT EDIT. + +// Dummy file to force the go compiler to honor go:linkname directives. See +// +// https://github.com/golang/go/issues/15006 +// http://www.alangpierce.com/blog/2016/03/17/adventures-in-go-accessing-unexported-functions/ diff --git a/gtl/tests/randomized_freepool_test.go b/gtl/tests/randomized_freepool_test.go index 17f73616..9cb2c13d 100644 --- a/gtl/tests/randomized_freepool_test.go +++ b/gtl/tests/randomized_freepool_test.go @@ -1,5 +1,5 @@ -//go:generate ../generate.py --prefix=Ints -DELEM=[]int --package=tests --output=int_freepool.go ../randomized_freepool.go.tpl -//go:generate ../generate.py --prefix=Strings -DELEM=[]string --package=tests --output=string_freepool.go ../randomized_freepool.go.tpl +//go:generate ../generate_randomized_freepool.py --prefix=ints -DELEM=[]int --package=tests --output=int_freepool +//go:generate ../generate_randomized_freepool.py --prefix=strings -DELEM=[]string --package=tests --output=string_freepool package tests diff --git a/gtl/tests/rcp_map_test.go b/gtl/tests/rcp_map_test.go new file mode 100644 index 00000000..dddf8ba3 --- /dev/null +++ b/gtl/tests/rcp_map_test.go @@ -0,0 +1,102 @@ +package tests + +//go:generate ../generate.py --prefix=rcuTest --PREFIX=RCUTest -DKEY=string -DVALUE=uint64 -DHASH=testhash --package=tests --output=rcu_map.go ../rcu_map.go.tpl + +import ( + "fmt" + "math/rand" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/grailbio/testutil/assert" +) + +type modelMap map[string]uint64 + +func 
testRCUMap(t *testing.T, r *rand.Rand) { + m := NewRCUTestMap(r.Intn(100)) + model := modelMap{} + + for i := 0; i < 1000; i++ { + op := r.Intn(3) + if op == 2 { + key := fmt.Sprintf("key%d", r.Intn(1000)) + val := r.Uint64() + m.Store(key, val) + model[key] = val + continue + } + key := fmt.Sprintf("key%d", r.Intn(1000)) + if val, ok := model[key]; ok { + val2, ok2 := m.Load(key) + assert.EQ(t, ok2, ok, key) + assert.EQ(t, val2, val, key) + } + } +} + +func TestRCUMap(t *testing.T) { + for seed := 0; seed < 1000; seed++ { + t.Run(fmt.Sprintf("seed=%d", seed), func(t *testing.T) { + testRCUMap(t, rand.New(rand.NewSource(int64(seed)))) + }) + } +} + +func TestConcurrentRCUMap(t *testing.T) { + var ( + seq uint64 + done uint64 + wg sync.WaitGroup + m = NewRCUTestMap(10) + ) + + // Producer + wg.Add(1) + go func() { + defer wg.Done() + start := time.Now() + for time.Since(start) < 5*time.Second { + val := seq + key := fmt.Sprintf("key%d", val) + m.Store(key, val) + atomic.StoreUint64(&seq, val+1) + time.Sleep(time.Millisecond) + } + atomic.StoreUint64(&done, 1) + }() + + // Consumer + for seed := 0; seed < 10; seed++ { + wg.Add(1) + go func(seed int) { + defer wg.Done() + r := rand.New(rand.NewSource(int64(seed))) + ops := 0 + for atomic.LoadUint64(&done) == 0 { + floor := atomic.LoadUint64(&seq) // A key < floor is guaranteed to be in the map + if floor == 0 { + time.Sleep(time.Millisecond) + continue + } + want := uint64(r.Intn(int(floor))) + key := fmt.Sprintf("key%d", want) + got, ok := m.Load(key) + ceil := atomic.LoadUint64(&seq) // A key > ceil is guaranteed not to be in the map + if ok { + assert.EQ(t, got, want) + assert.LE(t, want, ceil) + } else { + assert.GE(t, want, floor) + } + ops++ + } + t.Logf("Ops: %d", ops) + assert.GT(t, ops, 0) + }(seed) + } + + wg.Wait() +} diff --git a/gtl/tests/rcu_map.go b/gtl/tests/rcu_map.go new file mode 100644 index 00000000..e230b867 --- /dev/null +++ b/gtl/tests/rcu_map.go @@ -0,0 +1,160 @@ +// Code generated by 
"../generate.py --prefix=rcuTest --PREFIX=RCUTest -DKEY=string -DVALUE=uint64 -DHASH=testhash --package=tests --output=rcu_map.go ../rcu_map.go.tpl". DO NOT EDIT. + +package tests + +import ( + "sync/atomic" + "unsafe" +) + +// RCUTestMap is a concurrent map. A reader can access the map without lock, +// regardless of background updates. The writer side must coordinate using an +// external mutex if there are multiple writers. This map is linearizable. +// +// Example: +// +// m := NewRCUTestMap(10) +// go func() { // writer +// m.Store("foo", "bar") +// }() +// go func() { // reader +// val, ok := m.Load("foo") +// } +type RCUTestMap struct { + p unsafe.Pointer // *rcuTestMapState +} + +// RCUTestMapState represents a fixed-size chained hash table. It can store up to +// maxCapacity key/value pairs. Beyond that, the caller must create a new +// RCUTestMapState with a larger capacity. +type rcuTestMapState struct { + log2Len uint // ==log2(len(table)) + mask uint64 // == ^(log2Len-1) + table []unsafe.Pointer // *rcuTestMapNode + n int // # of objects currently stored in the table + maxCapacity int // max # of object that can be stored +} + +// RCUTestMapNode represents a hash bucket. +type rcuTestMapNode struct { + key string + value uint64 + + // next points to the next element in the same hash bucket + next unsafe.Pointer // *rcuTestMapNode +} + +func newRCUTestMapState(log2Len uint) *rcuTestMapState { + len := int(1 << log2Len) + table := &rcuTestMapState{ + log2Len: log2Len, + mask: uint64(log2Len - 1), + table: make([]unsafe.Pointer, 1< 31 { + // TODO(saito) We could make the table to grow larger than 32 bits, but + // doing so will break 32bit builds. + panic(initialLenHint) + } + log2Len++ + } + m := RCUTestMap{p: unsafe.Pointer(newRCUTestMapState(log2Len))} + return &m +} + +// Load finds a value with the given key. Returns false if not found. 
+func (m *RCUTestMap) Load(key string) (uint64, bool) { + hash := testhash(key) + table := (*rcuTestMapState)(atomic.LoadPointer(&m.p)) + b := int(hash & table.mask) + node := (*rcuTestMapNode)(atomic.LoadPointer(&table.table[b])) + for node != nil { + if node.key == key { + return node.value, true + } + node = (*rcuTestMapNode)(atomic.LoadPointer(&node.next)) + } + var dummy uint64 + return dummy, false +} + +// store returns false iff the table needs resizing. +func (t *rcuTestMapState) store(key string, value uint64) bool { + var ( + hash = testhash(key) + b = int(hash & t.mask) + node = (*rcuTestMapNode)(t.table[b]) + probeLen = 0 + prevNode *rcuTestMapNode + ) + for node != nil { + if node.key == key { + newNode := *node + newNode.value = value + if prevNode == nil { + atomic.StorePointer(&t.table[b], unsafe.Pointer(&newNode)) + } else { + atomic.StorePointer(&prevNode.next, unsafe.Pointer(&newNode)) + } + return true + } + prevNode = node + node = (*rcuTestMapNode)(node.next) + probeLen++ + if probeLen >= 4 && t.n >= t.maxCapacity { + return false + } + } + newNode := rcuTestMapNode{key: key, value: value} + if prevNode == nil { + atomic.StorePointer(&t.table[b], unsafe.Pointer(&newNode)) + } else { + atomic.StorePointer(&prevNode.next, unsafe.Pointer(&newNode)) + } + t.n++ + return true +} + +// Store stores the value for the given key. If the key is already in the map, +// it updates the mapping to the given value. +// +// Caution: if Store() is going to be called concurrently, it must be serialized +// externally. +func (m *RCUTestMap) Store(key string, value uint64) { + table := (*rcuTestMapState)(atomic.LoadPointer(&m.p)) + if table.store(key, value) { + return + } + log2Len := table.log2Len + 1 + if log2Len > 31 { + panic(log2Len) + } + newTable := newRCUTestMapState(log2Len) + // Copy the contents of the old table over to the new table. 
+ for _, p := range table.table { + node := (*rcuTestMapNode)(p) + for node != nil { + if !newTable.store(node.key, node.value) { + panic(node) + } + node = (*rcuTestMapNode)(node.next) + } + } + if !newTable.store(key, value) { + panic(key) + } + atomic.StorePointer(&m.p, unsafe.Pointer(newTable)) +} diff --git a/gtl/tests/rcu_map_hash.go b/gtl/tests/rcu_map_hash.go new file mode 100644 index 00000000..3bfc5f4e --- /dev/null +++ b/gtl/tests/rcu_map_hash.go @@ -0,0 +1,11 @@ +package tests + +import ( + "hash/fnv" +) + +func testhash(val string) uint64 { + h := fnv.New64() + h.Write([]byte(val)) + return h.Sum64() +} diff --git a/gtl/tests/string_freepool.go b/gtl/tests/string_freepool.go index 42ce23e5..9899bce9 100644 --- a/gtl/tests/string_freepool.go +++ b/gtl/tests/string_freepool.go @@ -1,4 +1,8 @@ -// Code generated from " ../generate.py --prefix=Strings -DELEM=[]string --package=tests --output=string_freepool.go ../randomized_freepool.go.tpl ". DO NOT EDIT. +// Code generated by "../generate.py --output=string_freepool.go --prefix=strings -DELEM=[]string --package=tests ../randomized_freepool.go.tpl". DO NOT EDIT. + +//go:build !race +// +build !race + // StringsFreePool is thread-safe pool that uses power-of-two loadbalancing across // CPUs. @@ -29,8 +33,8 @@ // //go:nosplit // func runtime_procUnpin() // -// //go:linkname fastrand sync.fastrand -// func fastrand() uint32 +// //go:linkname fastrandn sync.fastrandn +// func fastrandn(n uint32) uint32 // // 2. An empty .s file. @@ -45,27 +49,27 @@ import ( // StringsFreePool is a variation of sync.Pool, specialized for a concrete type. // -// - Put() performs power-of-two loadbalancing, and Get() looks only at the -// local queue. This improves the performance of Get() on many-core machines, -// at the cost of slightly more allocations. +// - Put() performs power-of-two loadbalancing, and Get() looks only at the +// local queue. 
This improves the performance of Get() on many-core machines, +// at the cost of slightly more allocations. // // - It assumes that GOMAXPROCS is fixed at boot. // -// - It never frees objects accumulated in the pool. We could add this feature -// if needed. +// - It never frees objects accumulated in the pool. We could add this feature +// if needed. type StringsFreePool struct { new func() []string - local []StringspoolLocal + local []stringsPoolLocal maxLocalSize int64 } const ( - StringsmaxPrivateElems = 4 - StringscacheLineSize = 64 + stringsMaxPrivateElems = 4 + stringsCacheLineSize = 64 ) -type StringspoolLocalInner struct { - private [StringsmaxPrivateElems][]string // Can be used only by the respective P. +type stringsPoolLocalInner struct { + private [stringsMaxPrivateElems][]string // Can be used only by the respective P. privateSize int shared [][]string // Can be used by any P. @@ -73,10 +77,10 @@ type StringspoolLocalInner struct { mu sync.Mutex // Protects shared. } -type StringspoolLocal struct { - StringspoolLocalInner - // Prevents false sharing. - pad [StringscacheLineSize - unsafe.Sizeof(StringspoolLocalInner{})%StringscacheLineSize]byte +type stringsPoolLocal struct { + stringsPoolLocalInner + // Pad prevents false sharing. + pad [stringsCacheLineSize - unsafe.Sizeof(stringsPoolLocalInner{})%stringsCacheLineSize]byte } // NewStringsFreePool creates a new free object pool. 
new should create a new @@ -94,13 +98,13 @@ func NewStringsFreePool(new func() []string, maxSize int) *StringsFreePool { } p := &StringsFreePool{ new: new, - local: make([]StringspoolLocal, maxProcs), + local: make([]stringsPoolLocal, maxProcs), maxLocalSize: int64(maxLocalSize), } return p } -func (p *StringsFreePool) pin() *StringspoolLocal { +func (p *StringsFreePool) pin() *stringsPoolLocal { pid := runtime_procPin() if int(pid) >= len(p.local) { panic(pid) @@ -113,7 +117,7 @@ func (p *StringsFreePool) pin() *StringspoolLocal { func (p *StringsFreePool) Put(x []string) { done := false l := p.pin() - if l.privateSize < StringsmaxPrivateElems { + if l.privateSize < stringsMaxPrivateElems { l.private[l.privateSize] = x l.privateSize++ done = true @@ -125,7 +129,7 @@ func (p *StringsFreePool) Put(x []string) { // queues to log(log(#queues)) . // // https://www.eecs.harvard.edu/~michaelm/postscripts/mythesis.pdf - l2 := &p.local[int(fastrand())%len(p.local)] + l2 := &p.local[int(fastrandn(uint32(len(p.local))))] lSize := atomic.LoadInt64(&l.sharedSize) l2Size := atomic.LoadInt64(&l2.sharedSize) if l2Size < lSize { diff --git a/gtl/tests/string_freepool_race.go b/gtl/tests/string_freepool_race.go new file mode 100644 index 00000000..d29c5021 --- /dev/null +++ b/gtl/tests/string_freepool_race.go @@ -0,0 +1,28 @@ +// Code generated by "../generate.py --output=string_freepool_race.go --prefix=strings -DELEM=[]string --package=tests ../randomized_freepool_race.go.tpl". DO NOT EDIT. 
+ +//go:build race +// +build race + +package tests + +import "sync/atomic" + +type StringsFreePool struct { + new func() []string + len int64 +} + +func NewStringsFreePool(new func() []string, maxSize int) *StringsFreePool { + return &StringsFreePool{new: new} +} + +func (p *StringsFreePool) Put(x []string) { + atomic.AddInt64(&p.len, -1) +} + +func (p *StringsFreePool) Get() []string { + atomic.AddInt64(&p.len, 1) + return p.new() +} + +func (p *StringsFreePool) ApproxLen() int { return int(atomic.LoadInt64(&p.len)) } diff --git a/gtl/tests/unsafe.go b/gtl/tests/unsafe.go index d62c2707..8513a99d 100644 --- a/gtl/tests/unsafe.go +++ b/gtl/tests/unsafe.go @@ -1,4 +1,5 @@ -// Code generated from " ../generate.py --prefix= -DELEM=int32 --package=tests --output=unsafe.go ../unsafe.go.tpl ". DO NOT EDIT. +// Code generated by "../generate.py --prefix= -DELEM=int32 --package=tests --output=unsafe.go ../unsafe.go.tpl". DO NOT EDIT. + package tests import ( diff --git a/intervalmap/intervalmap.go b/intervalmap/intervalmap.go index 07ae7c7c..5ba19f2e 100644 --- a/intervalmap/intervalmap.go +++ b/intervalmap/intervalmap.go @@ -6,15 +6,19 @@ // (http://www.sci.utah.edu/~wald/Publications/2007/ParallelBVHBuild/fastbuild.pdf). package intervalmap -//go:generate ../gtl/generate.py --prefix=searcher -DELEM=*searcher --package=intervalmap --output=search_freepool.go ../gtl/randomized_freepool.go.tpl +//go:generate ../gtl/generate_randomized_freepool.py --output=search_freepool --prefix=searcher --PREFIX=searcher -DELEM=*searcher --package=intervalmap import ( + "bytes" + "encoding/gob" + "fmt" "math" "math/rand" "runtime" "unsafe" "github.com/grailbio/base/log" + "github.com/grailbio/base/must" ) // Key is the type for interval boundaries. @@ -56,14 +60,6 @@ func (i Interval) Intersect(j Interval) Interval { return Interval{minKey, maxKey} } -// Width computes i.Limit-i.Start, or zero if i.Empty(). 
-func (i Interval) Width() Key { - if i.Empty() { - return 0 - } - return i.Limit - i.Start -} - // Empty checks if the interval is empty. func (i Interval) Empty() bool { return i.Start >= i.Limit } @@ -146,13 +142,17 @@ func New(ents []Entry) *T { t.stats.MaxDepth = -1 t.stats.MaxLeafNodeSize = -1 t.root.init("", ients, keyRange(ients), r, &t.stats) - t.pool = NewsearcherFreePool(func() *searcher { + t.pool = newSearcherFreePool(t, len(ents)) + return t +} + +func newSearcherFreePool(t *T, nEnt int) *searcherFreePool { + return NewsearcherFreePool(func() *searcher { return &searcher{ tree: t, - hits: make([]uint32, len(ents)), + hits: make([]uint32, nEnt), } }, runtime.NumCPU()*2) - return t } // searcher keeps state needed during one search episode. It is owned by one @@ -186,6 +186,17 @@ func (t *T) Get(interval Interval, ents *[]*Entry) { } } +// Any checks if any of the entries intersect the given interval. +func (t *T) Any(interval Interval) bool { + s := t.pool.Get() + s.searchID++ + found := t.root.any(interval, s) + if s.searchID < math.MaxUint32 { + t.pool.Put(s) + } + return found +} + func keyRange(ents []*entry) Interval { i := emptyInterval for _, e := range ents { @@ -194,7 +205,7 @@ func keyRange(ents []*entry) Interval { return i } -const maxSample = 16 +const maxSample = 8 // randomSample picks maxSample random elements from ents[]. It shuffles ents[] // in place. @@ -248,17 +259,17 @@ func split(label string, ents []*entry, bounds Interval, r *rand.Rand) (mid Key, } // splitAt splits ents[] into two subsets, assuming bounds is split at mid. 
- splitAt := func(ents []*entry, mid Key) ([]*entry, []*entry) { - left, right := []*entry{}, []*entry{} + splitAt := func(ents []*entry, mid Key, left, right *[]*entry) { + *left = (*left)[:0] + *right = (*right)[:0] for _, e := range ents { if e.Interval.Intersects(Interval{bounds.Start, mid}) { - left = append(left, e) + *left = append(*left, e) } if e.Interval.Intersects(Interval{mid, bounds.Limit}) { - right = append(right, e) + *right = append(*right, e) } } - return left, right } // Compute the cost of splitting at each of candidates[]. @@ -286,17 +297,19 @@ func split(label string, ents []*entry, bounds Interval, r *rand.Rand) (mid Key, minCost := math.MaxFloat64 var minMid Key var minLeft, minRight []*entry + var tmpLeft, tmpRight []*entry + for _, mid := range candidates[:nCandidate] { - left, right := splitAt(ents, mid) - if len(left) == 0 || len(right) == 0 { + splitAt(ents, mid, &tmpLeft, &tmpRight) + if len(tmpLeft) == 0 || len(tmpRight) == 0 { continue } - cost := float64(len(left))*float64(mid-sampleRange.Start) + - float64(len(right))*float64(sampleRange.Limit-mid) + cost := float64(len(tmpLeft))*float64(mid-sampleRange.Start) + + float64(len(tmpRight))*float64(sampleRange.Limit-mid) if cost < minCost { minMid = mid - minLeft = left - minRight = right + minLeft, tmpLeft = tmpLeft, minLeft + minRight, tmpRight = tmpRight, minRight minCost = cost } } @@ -372,13 +385,126 @@ func (n *node) get(interval Interval, ents *[]*Entry, s *searcher) { n.right.get(interval, ents, s) } -//go:linkname runtime_procPin sync.runtime_procPin -//go:nosplit -func runtime_procPin() int // nolint: golint +func (n *node) any(interval Interval, s *searcher) bool { + interval = interval.Intersect(n.bounds) + if interval.Empty() { + return false + } + if len(n.ents) > 0 { // Leaf node + for _, e := range n.ents { + if interval.Intersects(e.Interval) { + return true + } + } + return false + } + found := n.left.any(interval, s) + if !found { + found = n.right.any(interval, s) + 
} + return found +} + +// GOB support + +const gobFormatVersion = 1 + +// MarshalBinary implements encoding.BinaryMarshaler interface. It allows T to +// be encoded and decoded using Gob. +func (t *T) MarshalBinary() (data []byte, err error) { + buf := bytes.Buffer{} + e := gob.NewEncoder(&buf) + must.Nil(e.Encode(gobFormatVersion)) + marshalNode(e, &t.root) + must.Nil(e.Encode(t.stats)) + return buf.Bytes(), nil +} + +func marshalNode(e *gob.Encoder, n *node) { + if n == nil { + must.Nil(e.Encode(false)) + return + } + must.Nil(e.Encode(true)) + must.Nil(e.Encode(n.bounds)) + marshalNode(e, n.left) + marshalNode(e, n.right) + must.Nil(e.Encode(len(n.ents))) + for _, ent := range n.ents { + must.Nil(e.Encode(ent.Entry)) + must.Nil(e.Encode(ent.id)) + } + must.Nil(e.Encode(n.label)) +} -//go:linkname runtime_procUnpin sync.runtime_procUnpin -//go:nosplit -func runtime_procUnpin() // nolint: golint +// UnmarshalBinary implements encoding.BinaryUnmarshaler interface. +// It allows T to be encoded and decoded using Gob. 
+func (t *T) UnmarshalBinary(data []byte) error { + buf := bytes.NewReader(data) + d := gob.NewDecoder(buf) + var version int + if err := d.Decode(&version); err != nil { + return err + } + if version != gobFormatVersion { + return fmt.Errorf("gob decode: got version %d, want %d", version, gobFormatVersion) + } + var ( + maxid = -1 + err error + root *node + ) + if root, err = unmarshalNode(d, &maxid); err != nil { + return err + } + t.root = *root + if err := d.Decode(&t.stats); err != nil { + return err + } + t.pool = newSearcherFreePool(t, maxid+1) + return nil +} -//go:linkname fastrand sync.fastrand -func fastrand() uint32 +func unmarshalNode(d *gob.Decoder, maxid *int) (*node, error) { + var ( + exist bool + err error + ) + if err = d.Decode(&exist); err != nil { + return nil, err + } + if !exist { + return nil, nil + } + n := &node{} + if err := d.Decode(&n.bounds); err != nil { + return nil, err + } + if n.left, err = unmarshalNode(d, maxid); err != nil { + return nil, err + } + if n.right, err = unmarshalNode(d, maxid); err != nil { + return nil, err + } + var nEnt int + if err := d.Decode(&nEnt); err != nil { + return nil, err + } + n.ents = make([]*entry, nEnt) + for i := 0; i < nEnt; i++ { + n.ents[i] = &entry{} + if err := d.Decode(&n.ents[i].Entry); err != nil { + return nil, err + } + if err := d.Decode(&n.ents[i].id); err != nil { + return nil, err + } + if n.ents[i].id > *maxid { + *maxid = n.ents[i].id + } + } + if err := d.Decode(&n.label); err != nil { + return nil, err + } + return n, nil +} diff --git a/intervalmap/intervalmap_test.go b/intervalmap/intervalmap_test.go index e36f616d..519de503 100644 --- a/intervalmap/intervalmap_test.go +++ b/intervalmap/intervalmap_test.go @@ -1,6 +1,8 @@ package intervalmap import ( + "bytes" + "encoding/gob" "fmt" "math/rand" "sort" @@ -144,6 +146,7 @@ func testRandom(t *testing.T, seed int, nElem int, max Key, width float64) { entries = append(entries, newEntry(start, limit)) } tree := New(entries) + tree2 
:= gobEncodeAndDecode(t, tree) for i := 0; i < 1000; i++ { start, limit := randInterval(r, max, width) @@ -153,6 +156,7 @@ func testRandom(t *testing.T, seed int, nElem int, max Key, width float64) { result := sortIntervals(testGet(tree, start, limit)) assert.EQ(t, result, r0, "seed=%d, i=%d, search=[%d,%d)", seed, i, start, limit) assert.EQ(t, result, r1, "seed=%d, i=%d, search=[%d,%d)", seed, i, start, limit) + assert.EQ(t, result, sortIntervals(testGet(tree2, start, limit))) } } @@ -160,6 +164,31 @@ func TestRandom0(t *testing.T) { testRandom(t, 0, 128, 1024, 10) } func TestRandom1(t *testing.T) { testRandom(t, 1, 128, 1024, 100) } func TestRandom2(t *testing.T) { testRandom(t, 1, 1000, 8192, 1000) } +func gobEncodeAndDecode(t *testing.T, tree *T) *T { + buf := bytes.Buffer{} + e := gob.NewEncoder(&buf) + assert.NoError(t, e.Encode(tree)) + + d := gob.NewDecoder(&buf) + var tree2 *T + assert.NoError(t, d.Decode(&tree2)) + return tree2 +} + +func TestGobEmpty(t *testing.T) { + tree := New(nil) + tree2 := gobEncodeAndDecode(t, tree) + expect.EQ(t, testGet(tree2, 1, 2), []Interval{}) +} + +func TestGobSmall(t *testing.T) { + tree := gobEncodeAndDecode(t, New([]Entry{newEntry(1, 2), newEntry(10, 15)})) + expect.EQ(t, testGet(tree, -1, 0), []Interval{}) + expect.EQ(t, testGet(tree, 0, 2), []Interval{Interval{1, 2}}) + expect.EQ(t, testGet(tree, 0, 10), []Interval{Interval{1, 2}}) + expect.EQ(t, sortIntervals(testGet(tree, 0, 11)), []Interval{Interval{1, 2}, Interval{10, 15}}) +} + func benchmarkRandom(b *testing.B, seed int, nElem int, max Key, width float64) { b.StopTimer() r := rand.New(rand.NewSource(int64(seed))) @@ -204,7 +233,9 @@ func (i testInterval) Overlap(b interval.IntRange) bool { func (i testInterval) ID() uintptr { return i.id } // Range implements interval.IntInterface. 
-func (i testInterval) Range() interval.IntRange { return interval.IntRange{int(i.start), int(i.limit)} } +func (i testInterval) Range() interval.IntRange { + return interval.IntRange{Start: int(i.start), End: int(i.limit)} +} // String implements interval.IntInterface func (i testInterval) String() string { return fmt.Sprintf("[%d,%d)#%d", i.start, i.limit, i.id) } @@ -274,3 +305,47 @@ func Example() { // [3,5) // [3,5),[6,7) } + +// Example_gob is an example of serializing an intervalmap using Gob. +func Example_gob() { + newEntry := func(start, limit Key) Entry { + return Entry{ + Interval: Interval{start, limit}, + Data: fmt.Sprintf("[%d,%d)", start, limit), + } + } + + tree := New([]Entry{newEntry(1, 4), newEntry(3, 5), newEntry(6, 7)}) + + buf := bytes.Buffer{} + enc := gob.NewEncoder(&buf) + if err := enc.Encode(tree); err != nil { + panic(err) + } + dec := gob.NewDecoder(&buf) + var tree2 *T + if err := dec.Decode(&tree2); err != nil { + panic(err) + } + + doGet := func(tree *T, start, limit Key) string { + matches := []*Entry{} + tree.Get(Interval{start, limit}, &matches) + s := []string{} + for _, m := range matches { + s = append(s, m.Data.(string)) + } + sort.Strings(s) + return strings.Join(s, ",") + } + + fmt.Println(doGet(tree2, 0, 2)) + fmt.Println(doGet(tree2, 0, 4)) + fmt.Println(doGet(tree2, 4, 6)) + fmt.Println(doGet(tree2, 4, 7)) + // Output: + // [1,4) + // [1,4),[3,5) + // [3,5) + // [3,5),[6,7) +} diff --git a/intervalmap/randomized_freepool_internal.go b/intervalmap/randomized_freepool_internal.go new file mode 100644 index 00000000..c439b998 --- /dev/null +++ b/intervalmap/randomized_freepool_internal.go @@ -0,0 +1,26 @@ +// Code generated by generate_randomized_freepool.py. DO NOT EDIT. +package intervalmap + +// This import is needed to use go:linkname. +import _ "unsafe" + +// The following functions are defined in go runtime. 
To use them, we need to
+// import "unsafe", and elsewhere in this package, import "C" to force compiler
+// to recognize the "go:linkname" directive. Some of the details are explained
+// in the below blog post.
+//
+// procPin() pins the caller to the current processor, and returns the processor
+// id in range [0,GOMAXPROCS). procUnpin() undoes the effect of procPin().
+//
+// http://www.alangpierce.com/blog/2016/03/17/adventures-in-go-accessing-unexported-functions/
+
+//go:linkname runtime_procPin sync.runtime_procPin
+//go:nosplit
+func runtime_procPin() int
+
+//go:linkname runtime_procUnpin sync.runtime_procUnpin
+//go:nosplit
+func runtime_procUnpin()
+
+//go:linkname fastrandn sync.fastrandn
+func fastrandn(n uint32) uint32
diff --git a/intervalmap/randomized_freepool_internal.s b/intervalmap/randomized_freepool_internal.s
new file mode 100644
index 00000000..0c518de0
--- /dev/null
+++ b/intervalmap/randomized_freepool_internal.s
@@ -0,0 +1,6 @@
+// Code generated by generate_randomized_freepool.py. DO NOT EDIT.
+
+// Dummy file to force the go compiler to honor go:linkname directives. See
+//
+// https://github.com/golang/go/issues/15006
+// http://www.alangpierce.com/blog/2016/03/17/adventures-in-go-accessing-unexported-functions/
diff --git a/intervalmap/search_freepool.go b/intervalmap/search_freepool.go
index a79d4df0..90c703b5 100644
--- a/intervalmap/search_freepool.go
+++ b/intervalmap/search_freepool.go
@@ -1,4 +1,8 @@
-// Code generated from " ../gtl/generate.py --prefix=searcher -DELEM=*searcher --package=intervalmap --output=search_freepool.go ../gtl/randomized_freepool.go.tpl ". DO NOT EDIT.
+// Code generated by "../gtl/generate.py --output=search_freepool.go --prefix=searcher --PREFIX=searcher -DELEM=*searcher --package=intervalmap ../gtl/randomized_freepool.go.tpl". DO NOT EDIT.
+
+//go:build !race
+// +build !race
+
 // searcherFreePool is thread-safe pool that uses power-of-two loadbalancing across
 // CPUs.
@@ -29,8 +33,8 @@ // //go:nosplit // func runtime_procUnpin() // -// //go:linkname fastrand sync.fastrand -// func fastrand() uint32 +// //go:linkname fastrandn sync.fastrandn +// func fastrandn(n uint32) uint32 // // 2. An empty .s file. @@ -45,27 +49,27 @@ import ( // searcherFreePool is a variation of sync.Pool, specialized for a concrete type. // -// - Put() performs power-of-two loadbalancing, and Get() looks only at the -// local queue. This improves the performance of Get() on many-core machines, -// at the cost of slightly more allocations. +// - Put() performs power-of-two loadbalancing, and Get() looks only at the +// local queue. This improves the performance of Get() on many-core machines, +// at the cost of slightly more allocations. // // - It assumes that GOMAXPROCS is fixed at boot. // -// - It never frees objects accumulated in the pool. We could add this feature -// if needed. +// - It never frees objects accumulated in the pool. We could add this feature +// if needed. type searcherFreePool struct { new func() *searcher - local []searcherpoolLocal + local []searcherPoolLocal maxLocalSize int64 } const ( - searchermaxPrivateElems = 4 - searchercacheLineSize = 64 + searcherMaxPrivateElems = 4 + searcherCacheLineSize = 64 ) -type searcherpoolLocalInner struct { - private [searchermaxPrivateElems]*searcher // Can be used only by the respective P. +type searcherPoolLocalInner struct { + private [searcherMaxPrivateElems]*searcher // Can be used only by the respective P. privateSize int shared []*searcher // Can be used by any P. @@ -73,10 +77,10 @@ type searcherpoolLocalInner struct { mu sync.Mutex // Protects shared. } -type searcherpoolLocal struct { - searcherpoolLocalInner +type searcherPoolLocal struct { + searcherPoolLocalInner // Pad prevents false sharing. 
- pad [searchercacheLineSize - unsafe.Sizeof(searcherpoolLocalInner{})%searchercacheLineSize]byte + pad [searcherCacheLineSize - unsafe.Sizeof(searcherPoolLocalInner{})%searcherCacheLineSize]byte } // NewsearcherFreePool creates a new free object pool. new should create a new @@ -94,13 +98,13 @@ func NewsearcherFreePool(new func() *searcher, maxSize int) *searcherFreePool { } p := &searcherFreePool{ new: new, - local: make([]searcherpoolLocal, maxProcs), + local: make([]searcherPoolLocal, maxProcs), maxLocalSize: int64(maxLocalSize), } return p } -func (p *searcherFreePool) pin() *searcherpoolLocal { +func (p *searcherFreePool) pin() *searcherPoolLocal { pid := runtime_procPin() if int(pid) >= len(p.local) { panic(pid) @@ -113,7 +117,7 @@ func (p *searcherFreePool) pin() *searcherpoolLocal { func (p *searcherFreePool) Put(x *searcher) { done := false l := p.pin() - if l.privateSize < searchermaxPrivateElems { + if l.privateSize < searcherMaxPrivateElems { l.private[l.privateSize] = x l.privateSize++ done = true @@ -125,7 +129,7 @@ func (p *searcherFreePool) Put(x *searcher) { // queues to log(log(#queues)) . // // https://www.eecs.harvard.edu/~michaelm/postscripts/mythesis.pdf - l2 := &p.local[int(fastrand())%len(p.local)] + l2 := &p.local[int(fastrandn(uint32(len(p.local))))] lSize := atomic.LoadInt64(&l.sharedSize) l2Size := atomic.LoadInt64(&l2.sharedSize) if l2Size < lSize { diff --git a/intervalmap/search_freepool_race.go b/intervalmap/search_freepool_race.go new file mode 100644 index 00000000..3536100e --- /dev/null +++ b/intervalmap/search_freepool_race.go @@ -0,0 +1,28 @@ +// Code generated by "../gtl/generate.py --output=search_freepool_race.go --prefix=searcher --PREFIX=searcher -DELEM=*searcher --package=intervalmap ../gtl/randomized_freepool_race.go.tpl". DO NOT EDIT. 
+ +//go:build race +// +build race + +package intervalmap + +import "sync/atomic" + +type searcherFreePool struct { + new func() *searcher + len int64 +} + +func NewsearcherFreePool(new func() *searcher, maxSize int) *searcherFreePool { + return &searcherFreePool{new: new} +} + +func (p *searcherFreePool) Put(x *searcher) { + atomic.AddInt64(&p.len, -1) +} + +func (p *searcherFreePool) Get() *searcher { + atomic.AddInt64(&p.len, 1) + return p.new() +} + +func (p *searcherFreePool) ApproxLen() int { return int(atomic.LoadInt64(&p.len)) } diff --git a/intervalmap/unsafe.s b/intervalmap/unsafe.s deleted file mode 100644 index e69de29b..00000000 diff --git a/ioctx/fsctx/fs.go b/ioctx/fsctx/fs.go new file mode 100644 index 00000000..b8e464de --- /dev/null +++ b/ioctx/fsctx/fs.go @@ -0,0 +1,35 @@ +// fsctx adds context.Context to io/fs APIs. +// +// TODO: Specify policy for future additions to this package. See ioctx. +package fsctx + +import ( + "context" + "os" +) + +// FS is io/fs.FS with context added. +type FS interface { + Open(_ context.Context, name string) (File, error) +} + +// File is io/fs.File with context added. +type File interface { + Stat(context.Context) (os.FileInfo, error) + Read(context.Context, []byte) (int, error) + Close(context.Context) error +} + +// DirEntry is io/fs.DirEntry with context added. +type DirEntry interface { + Name() string + IsDir() bool + Type() os.FileMode + Info(context.Context) (os.FileInfo, error) +} + +// ReadDirFile is io/fs.ReadDirFile with context added. +type ReadDirFile interface { + File + ReadDir(_ context.Context, n int) ([]DirEntry, error) +} diff --git a/ioctx/fsctx/stat.go b/ioctx/fsctx/stat.go new file mode 100644 index 00000000..f8c23d5f --- /dev/null +++ b/ioctx/fsctx/stat.go @@ -0,0 +1,26 @@ +package fsctx + +import ( + "context" + "os" +) + +// StatFS is io/fs.StatFS with context added. 
+type StatFS interface { + FS + Stat(_ context.Context, name string) (os.FileInfo, error) +} + +// Stat is io/fs.Stat with context added. +func Stat(ctx context.Context, fsys FS, name string) (os.FileInfo, error) { + if fsys, ok := fsys.(StatFS); ok { + return fsys.Stat(ctx, name) + } + + file, err := fsys.Open(ctx, name) + if err != nil { + return nil, err + } + defer func() { _ = file.Close(ctx) }() + return file.Stat(ctx) +} diff --git a/ioctx/io.go b/ioctx/io.go new file mode 100644 index 00000000..7339475f --- /dev/null +++ b/ioctx/io.go @@ -0,0 +1,53 @@ +// ioctx adds context.Context to io APIs. +// +// TODO: Specify policy for future additions to this package. It may be best to only add analogues +// of the stdlib "io" so ioctx.* are easy for readers to understand. New functions and types (other +// than conversions to and from stdlib types) should be added elsewhere. +package ioctx + +import ( + "context" +) + +// Reader is io.Reader with context added. +type Reader interface { + Read(context.Context, []byte) (n int, err error) +} + +// Writer is io.Writer with context added. +type Writer interface { + Write(context.Context, []byte) (n int, err error) +} + +// Closer is io.Closer with context added. +type Closer interface { + Close(context.Context) error +} + +// Seeker is io.Seeker with context added. +type Seeker interface { + Seek(_ context.Context, offset int64, whence int) (int64, error) +} + +// ReadCloser is io.ReadCloser with context added. + +type ReadCloser interface { + Reader + Closer +} + +// ReadSeeker is io.ReadSeeker with context added. +type ReadSeeker interface { + Reader + Seeker +} + +// ReaderAt is io.ReaderAt with context added. +type ReaderAt interface { + ReadAt(_ context.Context, dst []byte, off int64) (n int, err error) +} + +// WriterAt is io.WriterAt with context added. 
+type WriterAt interface { + WriteAt(_ context.Context, p []byte, off int64) (n int, err error) +} diff --git a/ioctx/spliceio/spliceio.go b/ioctx/spliceio/spliceio.go new file mode 100644 index 00000000..dfae38eb --- /dev/null +++ b/ioctx/spliceio/spliceio.go @@ -0,0 +1,51 @@ +package spliceio + +import ( + "context" + "os" + + "github.com/grailbio/base/ioctx/fsctx" +) + +// ReaderAt reads data by giving the caller an OS file descriptor plus coordinates so the +// caller can directly splice (or just read) the file descriptor. Concurrent calls are allowed. +// +// It's possible for gotSize to be less than wantSize with nil error. This is different from +// io.ReaderAt. Callers should simply continue their read at off + gotSize. +type ReaderAt interface { + // SpliceReadAt returns a file descriptor and coordinates, or error. + // + // Note that fdOff is totally unrelated to off; fdOff must only be used for operations on fd. + // No guarantees are made about fd from different calls (no consistency, no uniqueness). + // No guarantees are made about fdOff from different calls (no ordering, no uniqueness). + SpliceReadAt(_ context.Context, wantSize int, off int64) (fd uintptr, gotSize int, fdOff int64, _ error) +} + +// OSFile is a ReaderAt wrapping os.File. It's also a fsctx.File and a +// fsnodefuse.Writable. +type OSFile os.File + +var ( + _ fsctx.File = (*OSFile)(nil) + _ ReaderAt = (*OSFile)(nil) +) + +func (f *OSFile) SpliceReadAt( + _ context.Context, wantSize int, off int64, +) ( + fd uintptr, gotSize int, fdOff int64, _ error, +) { + // TODO: Validation? Probably don't need to check file size. Maybe wantSize, off >= 0? 
+ return (*os.File)(f).Fd(), wantSize, off, nil +} + +func (f *OSFile) Stat(context.Context) (os.FileInfo, error) { return (*os.File)(f).Stat() } +func (f *OSFile) Read(_ context.Context, b []byte) (int, error) { return (*os.File)(f).Read(b) } +func (f *OSFile) Close(context.Context) error { return (*os.File)(f).Close() } + +func (f *OSFile) WriteAt(_ context.Context, b []byte, offset int64) (int, error) { + return (*os.File)(f).WriteAt(b, offset) +} +func (f *OSFile) Truncate(_ context.Context, size int64) error { return (*os.File)(f).Truncate(size) } +func (f *OSFile) Flush(_ context.Context) error { return nil } +func (f *OSFile) Fsync(_ context.Context) error { return (*os.File)(f).Sync() } diff --git a/ioctx/spliceio/spliceio_test.go b/ioctx/spliceio/spliceio_test.go new file mode 100644 index 00000000..d0206158 --- /dev/null +++ b/ioctx/spliceio/spliceio_test.go @@ -0,0 +1,9 @@ +package spliceio_test + +import ( + "github.com/grailbio/base/file/fsnodefuse" + "github.com/grailbio/base/ioctx/spliceio" +) + +// Check this here to avoid circular package dependency with fsnodefuse. 
+var _ fsnodefuse.Writable = (*spliceio.OSFile)(nil) diff --git a/ioctx/std.go b/ioctx/std.go new file mode 100644 index 00000000..7688917c --- /dev/null +++ b/ioctx/std.go @@ -0,0 +1,151 @@ +package ioctx + +import ( + "context" + "io" +) + +type ( + fromStdReader struct{ io.Reader } + fromStdWriter struct{ io.Writer } + fromStdCloser struct{ io.Closer } + fromStdSeeker struct{ io.Seeker } + fromStdReaderAt struct{ io.ReaderAt } + fromStdWriterAt struct{ io.WriterAt } + + StdReader struct { + Ctx context.Context + Reader + } + StdWriter struct { + Ctx context.Context + Writer + } + StdCloser struct { + Ctx context.Context + Closer + } + StdSeeker struct { + Ctx context.Context + Seeker + } + StdReadCloser struct { + Ctx context.Context + ReadCloser + } + StdReaderAt struct { + Ctx context.Context + ReaderAt + } + StdWriterAt struct { + Ctx context.Context + WriterAt + } +) + +// FromStdReader wraps io.Reader as Reader. +func FromStdReader(r io.Reader) Reader { return fromStdReader{r} } + +func (r fromStdReader) Read(_ context.Context, dst []byte) (n int, err error) { + return r.Reader.Read(dst) +} + +// FromStdWriter wraps io.Writer as Writer. +func FromStdWriter(w io.Writer) Writer { return fromStdWriter{w} } + +func (w fromStdWriter) Write(_ context.Context, p []byte) (n int, err error) { + return w.Writer.Write(p) +} + +// FromStdCloser wraps io.Closer as Closer. +func FromStdCloser(c io.Closer) Closer { return fromStdCloser{c} } + +func (c fromStdCloser) Close(context.Context) error { return c.Closer.Close() } + +// FromStdSeeker wraps io.Seeker as Seeker. +func FromStdSeeker(s io.Seeker) Seeker { return fromStdSeeker{s} } + +func (s fromStdSeeker) Seek(_ context.Context, offset int64, whence int) (int64, error) { + return s.Seeker.Seek(offset, whence) +} + +// FromStdReadCloser wraps io.ReadCloser as ReadCloser. 
+func FromStdReadCloser(rc io.ReadCloser) ReadCloser { + return struct { + Reader + Closer + }{FromStdReader(rc), FromStdCloser(rc)} +} + +// FromStdReadSeeker wraps io.ReadSeeker as ReadSeeker. +func FromStdReadSeeker(rs io.ReadSeeker) ReadSeeker { + return struct { + Reader + Seeker + }{FromStdReader(rs), FromStdSeeker(rs)} +} + +// FromStdReaderAt wraps io.ReaderAt as ReaderAt. +func FromStdReaderAt(r io.ReaderAt) ReaderAt { return fromStdReaderAt{r} } + +func (r fromStdReaderAt) ReadAt(_ context.Context, dst []byte, off int64) (n int, err error) { + return r.ReaderAt.ReadAt(dst, off) +} + +// ToStdReader wraps Reader as io.Reader. +func ToStdReader(ctx context.Context, r Reader) io.Reader { return StdReader{ctx, r} } + +func (r StdReader) Read(dst []byte) (n int, err error) { + return r.Reader.Read(r.Ctx, dst) +} + +// ToStdWriter wraps Writer as io.Writer. +func ToStdWriter(ctx context.Context, w Writer) io.Writer { return StdWriter{ctx, w} } + +func (w StdWriter) Write(p []byte) (n int, err error) { + return w.Writer.Write(w.Ctx, p) +} + +// ToStdCloser wraps Closer as io.Closer. +func ToStdCloser(ctx context.Context, c Closer) io.Closer { return StdCloser{ctx, c} } + +func (c StdCloser) Close() error { + return c.Closer.Close(c.Ctx) +} + +// ToStdSeeker wraps Seeker as io.Seeker. +func ToStdSeeker(ctx context.Context, s Seeker) io.Seeker { return StdSeeker{ctx, s} } + +func (r StdSeeker) Seek(offset int64, whence int) (int64, error) { + return r.Seeker.Seek(r.Ctx, offset, whence) +} + +// ToStdReadCloser wraps ReadCloser as io.ReadCloser. +func ToStdReadCloser(ctx context.Context, rc ReadCloser) io.ReadCloser { + return struct { + io.Reader + io.Closer + }{ToStdReader(ctx, rc), ToStdCloser(ctx, rc)} +} + +// ToStdReadSeeker wraps ReadSeeker as io.ReadSeeker. 
+func ToStdReadSeeker(ctx context.Context, rs ReadSeeker) io.ReadSeeker { + return struct { + io.Reader + io.Seeker + }{ToStdReader(ctx, rs), ToStdSeeker(ctx, rs)} +} + +// ToStdReaderAt wraps ReaderAt as io.ReaderAt. +func ToStdReaderAt(ctx context.Context, r ReaderAt) io.ReaderAt { return StdReaderAt{ctx, r} } + +func (r StdReaderAt) ReadAt(dst []byte, off int64) (n int, err error) { + return r.ReaderAt.ReadAt(r.Ctx, dst, off) +} + +// ToStdWriterAt wraps WriterAt as io.WriterAt. +func ToStdWriterAt(ctx context.Context, w WriterAt) io.WriterAt { return StdWriterAt{ctx, w} } + +func (w StdWriterAt) WriteAt(dst []byte, off int64) (n int, err error) { + return w.WriterAt.WriteAt(w.Ctx, dst, off) +} diff --git a/iofmt/linewriter.go b/iofmt/linewriter.go new file mode 100644 index 00000000..c80e9b80 --- /dev/null +++ b/iofmt/linewriter.go @@ -0,0 +1,70 @@ +// Copyright 2021 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. + +package iofmt + +import ( + "bytes" + "io" +) + +type lineWriter struct { + w io.Writer + buf []byte +} + +// LineWriter returns an io.WriteCloser that only calls w.Write with +// complete lines. This can be used to make it less likely (without +// locks) for lines to interleave, for example if you are concurrently +// writing lines of text to os.Stdout. This is particularly useful when +// composed with PrefixWriter. +// +// // Full lines will be written to os.Stdout, so they will be less likely to +// // be interleaved with other output. +// linew := LineWriter(os.Stdout) +// defer func() { +// _ = linew.Close() // Handle the possible error. +// }() +// w := PrefixWriter(linew, "my-prefix: ") +// +// Close will write any remaining partial line to the underlying writer. 
+func LineWriter(w io.Writer) io.WriteCloser { + return &lineWriter{w: w} +} + +func (w *lineWriter) Write(p []byte) (int, error) { + var n int + for { + i := bytes.Index(p, newline) + // TODO(jcharumilind): Limit buffer size. + switch i { + case -1: + w.buf = append(w.buf, p...) + return n + len(p), nil + default: + var err error + if len(w.buf) > 0 { + w.buf = append(w.buf, p[:i+1]...) + _, err = w.w.Write(w.buf) + w.buf = w.buf[:0] + } else { + _, err = w.w.Write(p[:i+1]) + } + n += i + 1 + if err != nil { + return n, err + } + p = p[i+1:] + } + } +} + +func (w *lineWriter) Close() error { + if len(w.buf) == 0 { + return nil + } + _, err := w.w.Write(w.buf) + w.buf = nil + return err +} diff --git a/iofmt/linewriter_test.go b/iofmt/linewriter_test.go new file mode 100644 index 00000000..1c2f3cf1 --- /dev/null +++ b/iofmt/linewriter_test.go @@ -0,0 +1,144 @@ +// Copyright 2021 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. + +package iofmt_test + +import ( + "bytes" + "fmt" + "math/rand" + "strings" + "testing" + + "github.com/grailbio/base/iofmt" + "github.com/grailbio/testutil/assert" +) + +// saveWriter saves the calls made to Write for later comparison. +type saveWriter struct { + writes [][]byte +} + +func (w *saveWriter) Write(p []byte) (int, error) { + pCopy := make([]byte, len(p)) + copy(pCopy, p) + w.writes = append(w.writes, pCopy) + return len(p), nil +} + +// TestLineWriter verifies that a LineWriter calls Write on its underlying +// writer with complete lines. 
+func TestLineWriter(t *testing.T) { + for _, c := range []struct { + name string + makeLines func() []string + }{ + { + name: "Empty", + makeLines: func() []string { return []string{} }, + }, + { + name: "Numbered", + makeLines: func() []string { + const Nlines = 1000 + lines := make([]string, Nlines) + for i := range lines { + lines[i] = fmt.Sprintf("line %04d", i) + } + return lines + }, + }, + { + name: "SomeEmpty", + makeLines: func() []string { + const Nlines = 1000 + lines := make([]string, Nlines) + for i := range lines { + if rand.Intn(2) == 0 { + continue + } + lines[i] = fmt.Sprintf("line %04d", i) + } + return lines + }, + }, + { + name: "SomeLong", + makeLines: func() []string { + const Nlines = 1000 + lines := make([]string, Nlines) + for i := range lines { + var b strings.Builder + fmt.Fprintf(&b, "line %04d:", i) + for j := 0; j < rand.Intn(100); j++ { + b.WriteString(" lorem ipsum") + } + } + return lines + }, + }, + } { + t.Run(c.name, func(t *testing.T) { + lines := c.makeLines() + // bs is a concatenation of all the lines. We write this to a + // LineWriter in random segments. + var bs []byte + for _, line := range lines { + bs = append(bs, []byte(fmt.Sprintf("%s\n", line))...) + } + s := &saveWriter{} + w := iofmt.LineWriter(s) + defer func() { + assert.Nil(t, w.Close()) + }() + for len(bs) > 0 { + // Write in random segments. + n := rand.Intn(20) + if len(bs) < n { + n = len(bs) + } + m, err := w.Write(bs[:n]) + assert.Nil(t, err) + assert.EQ(t, m, n) + bs = bs[n:] + } + want := make([][]byte, len(lines)) + for i, line := range lines { + want[i] = []byte(fmt.Sprintf("%s\n", line)) + } + assert.EQ(t, s.writes, want) + }) + } +} + +// TestLineWriterClose verifies that (*LineWriter).Close writes any remaining +// partial line to the underlying writer. 
+func TestLineWriterClose(t *testing.T) { + for _, c := range []struct { + name string + bs []byte + }{ + { + name: "Empty", + bs: []byte{}, + }, + { + name: "PartialOnly", + bs: []byte("no terminal newline"), + }, + { + name: "Partial", + bs: []byte("line 0\nline 1\nline 2\nno terminal newline"), + }, + } { + t.Run(c.name, func(t *testing.T) { + var b bytes.Buffer + w := iofmt.LineWriter(&b) + _, err := w.Write(c.bs) + assert.Nil(t, err) + assert.Nil(t, w.Close()) + assert.EQ(t, b.Bytes(), c.bs) + }) + } +} diff --git a/iofmt/prefixwriter.go b/iofmt/prefixwriter.go new file mode 100644 index 00000000..1982b09a --- /dev/null +++ b/iofmt/prefixwriter.go @@ -0,0 +1,59 @@ +// Copyright 2018 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. + +// Package iofmt contains various IO utilities used for formatting +// output. +package iofmt + +import ( + "bytes" + "io" +) + +var newline = []byte{'\n'} + +// prefixWriter is an io.Writer that outputs a prefix before each line. +type prefixWriter struct { + w io.Writer + prefix string + needPrefix bool +} + +// PrefixWriter returns a new io.Writer that copies its writes +// to the provided io.Writer, adding a prefix at the beginning +// of each line. 
+func PrefixWriter(w io.Writer, prefix string) io.Writer { + return &prefixWriter{w: w, prefix: prefix, needPrefix: true} +} + +func (w *prefixWriter) Write(p []byte) (n int, err error) { + if w.needPrefix { + if _, err := io.WriteString(w.w, w.prefix); err != nil { + return 0, err + } + w.needPrefix = false + } + for { + i := bytes.Index(p, newline) + switch i { + case len(p) - 1: + w.needPrefix = true + fallthrough + case -1: + m, err := w.w.Write(p) + return n + m, err + default: + m, err := w.w.Write(p[:i+1]) + n += m + if err != nil { + return n, err + } + _, err = io.WriteString(w.w, w.prefix) + if err != nil { + return n, err + } + p = p[i+1:] + } + } +} diff --git a/iofmt/prefixwriter_test.go b/iofmt/prefixwriter_test.go new file mode 100644 index 00000000..81b873e8 --- /dev/null +++ b/iofmt/prefixwriter_test.go @@ -0,0 +1,28 @@ +// Copyright 2018 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. + +package iofmt + +import ( + "bytes" + "io" + "testing" +) + +func TestPrefixWriter(t *testing.T) { + var b bytes.Buffer + w := PrefixWriter(&b, "prefix: ") + io.WriteString(w, "hello") + io.WriteString(w, "\nworld\n\n") + io.WriteString(w, "another\ntest\nthere\n") + if got, want := b.String(), `prefix: hello +prefix: world +prefix: +prefix: another +prefix: test +prefix: there +`; got != want { + t.Errorf("got %q, want %q", got, want) + } +} diff --git a/limitbuf/limitbuf.go b/limitbuf/limitbuf.go new file mode 100644 index 00000000..3a5ea101 --- /dev/null +++ b/limitbuf/limitbuf.go @@ -0,0 +1,87 @@ +package limitbuf + +import ( + "fmt" + "strings" + + "github.com/grailbio/base/log" +) + +type ( + // Logger is like strings.Builder, but with maximum length. If the caller tries + // to add data beyond the capacity, they will be dropped, and Logger.String() + // will append "(truncated)" at the end. 
+ // + // TODO: Consider renaming to Builder or Buffer since this type's behavior + // is analogous to those. + Logger struct { + maxLen int + b strings.Builder + + // seen counts the total number of bytes passed to Write. + seen int64 + logIfTruncatingMaxMultiple float64 + } + + // LoggerOption is passed to NewLogger to configure a Logger. + LoggerOption func(*Logger) +) + +// LogIfTruncatingMaxMultiple controls informative logging about how much data +// passed to Write has been truncated. +// +// If zero, this logging is disabled. Otherwise, if the sum of len(data) +// passed to prior Write calls is greater than LogIfTruncatingMaxMultiple * +// maxLen (passed to NewLogger), a log message is written in the next call +// to String(). After logging, LogIfTruncatingMaxMultiple is set to zero +// to avoid repeating the same message. +// +// This can be a useful diagnostic for both CPU and memory usage if a huge +// amount of data is written and only a tiny fraction is used. For example, +// if a caller writes to the log with fmt.Fprint(logger, ...) they may +// not realize that fmt.Fprint* actually buffers the *entire* formatted +// string in memory first, then writes it to logger. +// TODO: Consider serving the fmt use case better for e.g. bigslice. +// +// Note that the log message is written to log.Error, not the Logger itself +// (it's not part of String's return). +func LogIfTruncatingMaxMultiple(m float64) LoggerOption { + return func(l *Logger) { l.logIfTruncatingMaxMultiple = m } +} + +// NewLogger creates a new Logger object with the given capacity. +func NewLogger(maxLen int, opts ...LoggerOption) *Logger { + l := Logger{maxLen: maxLen} + for _, opt := range opts { + opt(&l) + } + return &l +} + +// Write implements io.Writer interface. 
+func (b *Logger) Write(data []byte) (int, error) { + n := b.maxLen - b.b.Len() + if n > len(data) { + n = len(data) + } + if n > 0 { + b.b.Write(data[:n]) + } + b.seen += int64(len(data)) + return len(data), nil +} + +// String reports the data written so far. If the length of the data exceeds the +// buffer capacity, the prefix of the data, plus "(truncated)" will be reported. +func (b *Logger) String() string { + if b.seen <= int64(b.maxLen) { + return b.b.String() + } + // Truncated. + if b.logIfTruncatingMaxMultiple > 0 && + b.seen > int64(float64(b.maxLen)*b.logIfTruncatingMaxMultiple) { + b.logIfTruncatingMaxMultiple = 0 + log.Errorf("limitbuf: extreme truncation: %d -> %d bytes", b.seen, b.maxLen) + } + return b.b.String() + fmt.Sprintf("(truncated %d bytes)", b.seen-int64(b.maxLen)) +} diff --git a/limitbuf/limitbuf_test.go b/limitbuf/limitbuf_test.go new file mode 100644 index 00000000..244e5a60 --- /dev/null +++ b/limitbuf/limitbuf_test.go @@ -0,0 +1,43 @@ +package limitbuf_test + +import ( + "bytes" + "testing" + + "github.com/grailbio/base/limitbuf" + "github.com/grailbio/base/log" + "github.com/grailbio/testutil/expect" +) + +func TestLogger(t *testing.T) { + l := limitbuf.NewLogger(10) + l.Write([]byte("blah")) + expect.EQ(t, l.String(), "blah") + l.Write([]byte("abcdefgh")) + expect.EQ(t, l.String(), "blahabcdef(truncated 2 bytes)") + expect.EQ(t, l.String(), "blahabcdef(truncated 2 bytes)") +} + +func TestLoggerExtremeTruncation(t *testing.T) { + oldOutputter := log.GetOutputter() + t.Cleanup(func() { log.SetOutputter(oldOutputter) }) + var outputter testOutputter + log.SetOutputter(&outputter) + + logger := limitbuf.NewLogger(2, limitbuf.LogIfTruncatingMaxMultiple(3)) + _, err := logger.Write([]byte("abcdefg")) + expect.NoError(t, err) + + expect.EQ(t, logger.String(), "ab(truncated 5 bytes)") + expect.HasSubstr(t, outputter.String(), "extreme truncation") +} + +type testOutputter struct{ bytes.Buffer } + +func (o *testOutputter) Level() log.Level 
{ + return log.Error +} +func (o *testOutputter) Output(_ int, _ log.Level, s string) error { + _, err := o.Buffer.WriteString(s) + return err +} diff --git a/limiter/batch.go b/limiter/batch.go new file mode 100644 index 00000000..a1caf7da --- /dev/null +++ b/limiter/batch.go @@ -0,0 +1,207 @@ +// Copyright 2021 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. + +package limiter + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/grailbio/base/sync/ctxsync" + "golang.org/x/time/rate" +) + +// BatchLimiter provides the ability to batch calls and apply a rate limit (on the batches). +// Users have to provide an implementation of BatchApi and a rate.Limiter. +// Thereafter callers can concurrently Do calls for each individual ID and the BatchLimiter will +// batch calls (whenever appropriate) while respecting the rate limit. +// Individual requests are serviced in the order of submission. +type BatchLimiter struct { + api BatchApi + limiter *rate.Limiter + wait time.Duration + + mu sync.Mutex + // pending is the list of pending ids in the order of submission + pending []ID + // results maps each submitted ID to its result. + results map[ID]*Result +} + +// BatchApi needs to be implemented in order to use BatchLimiter. +type BatchApi interface { + // MaxPerBatch is the max number of ids to call per `Do` (zero implies no limit). + MaxPerBatch() int + + // Do the batch call with the given map of IDs to Results. + // The implementation must call Result.Set to provide the Value or Err (as applicable) for the every ID. + // At the end of this call, if Result.Set was not called on the result of a particular ID, + // the corresponding ID's `Do` call will get ErrNoResult. + Do(map[ID]*Result) +} + +// ID is the identifier of each call. +type ID interface{} + +// Result is the result of an API call for a given id. 
+type Result struct { + mu sync.Mutex + cond *ctxsync.Cond + id ID + value interface{} + err error + done bool + nWaiters int +} + +// Set sets the result of a given id with the given value v and error err. +func (r *Result) Set(v interface{}, err error) { + r.mu.Lock() + defer r.mu.Unlock() + r.done = true + r.value = v + r.err = err + r.cond.Broadcast() +} + +func (r *Result) doneC() <-chan struct{} { + r.mu.Lock() + return r.cond.Done() +} + +// NewBatchLimiter returns a new BatchLimiter which will call the given batch API +// as per the limits set by the given rate limiter. +func NewBatchLimiter(api BatchApi, limiter *rate.Limiter) *BatchLimiter { + eventsPerSecond := limiter.Limit() + if eventsPerSecond == 0 { + panic("limiter does not allow any events") + } + d := float64(time.Second) / float64(eventsPerSecond) + wait := time.Duration(d) + return &BatchLimiter{api: api, limiter: limiter, wait: wait, results: make(map[ID]*Result)} +} + +var ErrNoResult = fmt.Errorf("no result") + +// Do submits the given ID to the batch limiter and returns the result or an error. +// If the returned error is ErrNoResult, it indicates that the batch call did not produce any result for the given ID. +// Callers may then apply their own retry strategy if necessary. +// Do merges duplicate calls if the IDs are of a comparable type (and if the result is still pending) +// However, de-duplication is not guaranteed. +// Callers can avoid de-duplication by using a pointer type instead. +func (l *BatchLimiter) Do(ctx context.Context, id ID) (interface{}, error) { + var t *time.Timer + defer func() { + if t != nil { + t.Stop() + } + }() + r := l.register(id) + defer l.unregister(r) + for { + if done, v, err := l.get(r); done { + return v, err + } + if l.limiter.Allow() { + m := l.claim() + if len(m) > 0 { + l.api.Do(m) + l.update(m) + continue + } + } + // Wait half the interval to increase chances of making the next call as early as possible. 
+ d := l.wait / 2 + if t == nil { + t = time.NewTimer(d) + } else { + t.Reset(d) + } + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-r.doneC(): + case <-t.C: + } + } +} + +// register registers the given id. +func (l *BatchLimiter) register(id ID) *Result { + l.mu.Lock() + defer l.mu.Unlock() + if _, ok := l.results[id]; !ok { + l.pending = append(l.pending, id) + r := &Result{id: id} + r.cond = ctxsync.NewCond(&r.mu) + l.results[id] = r + } + r := l.results[id] + r.mu.Lock() + r.nWaiters += 1 + r.mu.Unlock() + return r +} + +// unregister indicates that the calling goroutine is no longer interested in the given result. +func (l *BatchLimiter) unregister(r *Result) { + var remove bool + r.mu.Lock() + r.nWaiters -= 1 + remove = r.nWaiters == 0 + r.mu.Unlock() + if remove { + l.mu.Lock() + delete(l.results, r.id) + l.mu.Unlock() + } +} + +// get returns whether the result is done and the value and error. +func (l *BatchLimiter) get(r *Result) (bool, interface{}, error) { + r.mu.Lock() + defer r.mu.Unlock() + return r.done, r.value, r.err +} + +// update updates the internal results using the given ones. +// update also sets ErrNoResult as the error result for IDs for which `Result.Set` was not called. +func (l *BatchLimiter) update(results map[ID]*Result) { + for _, r := range results { + r.mu.Lock() + if !r.done { + r.done, r.err = true, ErrNoResult + } + r.mu.Unlock() + } +} + +// claim claims pending ids and returns a mapping of those ids to their results. +func (l *BatchLimiter) claim() map[ID]*Result { + l.mu.Lock() + defer l.mu.Unlock() + max := l.api.MaxPerBatch() + if max == 0 { + max = len(l.pending) + } + claimed := make(map[ID]*Result) + i := 0 + for ; i < len(l.pending) && len(claimed) < max; i++ { + id := l.pending[i] + r := l.results[id] + if r == nil { + continue + } + r.mu.Lock() + if !r.done { + claimed[id] = r + } + r.mu.Unlock() + } + // Remove the claimed ids from the pending list. 
+ l.pending = l.pending[i:] + return claimed +} diff --git a/limiter/batch_test.go b/limiter/batch_test.go new file mode 100644 index 00000000..fd2abea3 --- /dev/null +++ b/limiter/batch_test.go @@ -0,0 +1,234 @@ +// Copyright 2021 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. + +package limiter + +import ( + "context" + "fmt" + "strconv" + "sync" + "testing" + "time" + + "github.com/grailbio/base/traverse" + "golang.org/x/time/rate" +) + +type testBatchApi struct { + mu sync.Mutex + usePtr bool + maxPerBatch int + last time.Time + perBatchIds [][]string + durs []time.Duration + idSeenCount map[string]int +} + +func (a *testBatchApi) MaxPerBatch() int { return a.maxPerBatch } +func (a *testBatchApi) Do(results map[ID]*Result) { + a.mu.Lock() + defer a.mu.Unlock() + now := time.Now() + if a.last.IsZero() { + a.last = now + } + ids := make([]string, 0, len(results)) + for k, r := range results { + var id string + if a.usePtr { + id = *k.(*string) + } else { + id = k.(string) + } + ids = append(ids, id) + idSeenCount := a.idSeenCount[id] + i, err := strconv.Atoi(id) + if err != nil { + i = -1 + } + switch { + case shouldErr(i): + case i%2 == 0: + r.Set(nil, fmt.Errorf("failed_%s_count_%d", id, idSeenCount)) + default: + r.Set(fmt.Sprintf("value-%s", id), nil) + } + a.idSeenCount[id] = idSeenCount + 1 + } + a.perBatchIds = append(a.perBatchIds, ids) + a.durs = append(a.durs, now.Sub(a.last)) + a.last = now + return +} + +func TestSimple(t *testing.T) { + a := &testBatchApi{idSeenCount: make(map[string]int)} + l := NewBatchLimiter(a, rate.NewLimiter(rate.Every(time.Millisecond), 1)) + id := "test" + _, _ = l.Do(context.Background(), id) + if got, want := a.idSeenCount[id], 1; got != want { + t.Errorf("got %d, want %d", got, want) + } + _, _ = l.Do(context.Background(), id) + if got, want := a.idSeenCount[id], 2; got != want { + t.Errorf("got %d, want %d", got, want) + } +} + 
+func TestCtxCanceled(t *testing.T) { + a := &testBatchApi{idSeenCount: make(map[string]int)} + l := NewBatchLimiter(a, rate.NewLimiter(rate.Every(time.Second), 1)) + id1, id2 := "test1", "test2" + _, _ = l.Do(context.Background(), id1) + if got, want := a.idSeenCount[id1], 1; got != want { + t.Errorf("got %d, want %d", got, want) + } + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancel() + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + _, _ = l.Do(ctx, id1) + }() + wg.Add(1) + go func() { + defer wg.Done() + _, _ = l.Do(context.Background(), id2) + }() + wg.Wait() + if got, want := a.idSeenCount[id1], 1; got != want { + t.Errorf("got %d, want %d", got, want) + } + if got, want := a.idSeenCount[id2], 1; got != want { + t.Errorf("got %d, want %d", got, want) + } +} + +func TestSometimesDedup(t *testing.T) { + const num = 5 + a := &testBatchApi{idSeenCount: make(map[string]int)} + l := NewBatchLimiter(a, rate.NewLimiter(rate.Every(10*time.Millisecond), num)) + id := "test" + a.mu.Lock() // Locks the batch API. + var done sync.WaitGroup + done.Add(num) + for i := 0; i < num; i++ { + go func() { + defer done.Done() + _, _ = l.Do(context.Background(), id) + }() + } + var allWaiting bool + for !allWaiting { + l.mu.Lock() + r := l.results[id] + l.mu.Unlock() + if r == nil { + time.Sleep(time.Millisecond) + continue + } + r.mu.Lock() + allWaiting = r.nWaiters == num + r.mu.Unlock() + } + a.mu.Unlock() // Unlock the batch API. + done.Wait() // Wait for all the goroutines on the same ID to complete + if got, want := a.idSeenCount[id], 1; got != want { + t.Errorf("got %d, want %d", got, want) + } +} + +func TestNoDedup(t *testing.T) { + a := &testBatchApi{usePtr: true, idSeenCount: make(map[string]int)} + l := NewBatchLimiter(a, rate.NewLimiter(rate.Every(10*time.Millisecond), 1)) + id := "test" + a.mu.Lock() // Locks the batch API. 
+ var started, done sync.WaitGroup + started.Add(5) + done.Add(5) + for i := 0; i < 5; i++ { + go func() { + started.Done() + id := id + _, _ = l.Do(context.Background(), &id) + done.Done() + }() + } + started.Wait() // Wait for all the goroutines on the same ID to start + a.mu.Unlock() // Unlock the batch API. + done.Wait() // Wait for all the goroutines on the same ID to complete + if got, want := a.idSeenCount[id], 5; got != want { + t.Errorf("got %d, want %d", got, want) + } +} + +func TestDo(t *testing.T) { + testApi(t, &testBatchApi{idSeenCount: make(map[string]int)}, time.Second) +} + +func TestDoWithMax5(t *testing.T) { + testApi(t, &testBatchApi{maxPerBatch: 5, idSeenCount: make(map[string]int)}, 3*time.Second) +} + +func TestDoWithMax8(t *testing.T) { + testApi(t, &testBatchApi{maxPerBatch: 8, idSeenCount: make(map[string]int)}, 2*time.Second) +} + +type result struct { + v string + err error +} + +func shouldErr(i int) bool { + return i%5 == 0 && i%2 != 0 +} + +func testApi(t *testing.T, a *testBatchApi, timeout time.Duration) { + const numIds = 100 + var interval = 100 * time.Millisecond + l := NewBatchLimiter(a, rate.NewLimiter(rate.Every(interval), 1)) + var mu sync.Mutex + results := make(map[string]result) + _ = traverse.Each(numIds, func(i int) error { + time.Sleep(time.Duration(i*10) * time.Millisecond) + id := fmt.Sprintf("%d", i) + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + v, err := l.Do(ctx, id) + mu.Lock() + r := result{err: err} + if r.err == nil { + r.v = v.(string) + } + results[id] = r + mu.Unlock() + return nil + }) + for i := 0; i < numIds; i++ { + id := fmt.Sprintf("%d", i) + if got, want := a.idSeenCount[id], 1; got != want { + t.Errorf("[%v] got %d, want %d", id, got, want) + } + if shouldErr(i) { + if got, want := results[id].err, ErrNoResult; got != want { + t.Errorf("[%d] got %v, want %v", i, got, want) + } + } + } + for _, dur := range a.durs[1:] { + if got, want, diff := dur, interval, 
(dur - interval).Round(5*time.Millisecond); diff < 0 { + t.Errorf("got %v, want %v, diff %v", got, want, diff) + } + } + for i, batchIds := range a.perBatchIds { + if want := a.maxPerBatch; want > 0 { + if got := len(batchIds); got > want { + t.Errorf("got %v, want <=%v", got, want) + } + } + t.Logf("batch %d (after %s): %v", i, a.durs[i].Round(time.Millisecond), batchIds) + } +} diff --git a/limiter/limiter.go b/limiter/limiter.go index b2c3584f..a3db91e6 100644 --- a/limiter/limiter.go +++ b/limiter/limiter.go @@ -14,6 +14,8 @@ import "context" // goroutine before proceeding. A limiter is not fair: tokens are not // granted in FIFO order; rather, waiters are picked randomly to be // granted new tokens. +// +// A nil limiter issues an infinite number of tokens. type Limiter struct { c chan int waiter chan struct{} @@ -29,6 +31,9 @@ func New() *Limiter { // Acquire blocks until the goroutine is granted the desired number // of tokens, or until the context is done. func (l *Limiter) Acquire(ctx context.Context, need int) error { + if l == nil { + return ctx.Err() + } select { case <-l.waiter: case <-ctx.Done(): @@ -56,6 +61,9 @@ func (l *Limiter) Acquire(ctx context.Context, need int) error { // Release adds a number of tokens back into the limiter. 
func (l *Limiter) Release(n int) { + if l == nil { + return + } if n == 0 { return } @@ -68,3 +76,8 @@ func (l *Limiter) Release(n int) { } } } + +type LimiterIfc interface { + Release(n int) + Acquire(ctx context.Context, need int) error +} diff --git a/limiter/limiter_test.go b/limiter/limiter_test.go index 503be801..04e1045a 100644 --- a/limiter/limiter_test.go +++ b/limiter/limiter_test.go @@ -23,7 +23,8 @@ func TestLimiter(t *testing.T) { if err := l.Acquire(context.Background(), 5); err != nil { t.Fatal(err) } - ctx, _ := context.WithTimeout(context.Background(), time.Second) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() if want, got := context.DeadlineExceeded, l.Acquire(ctx, 10); got != want { t.Fatalf("got %v, want %v", got, want) } @@ -43,7 +44,7 @@ func TestLimiterConcurrently(t *testing.T) { l.Release(T) var begin sync.WaitGroup begin.Add(N) - err := traverse.Each(N).Do(func(i int) error { + err := traverse.Each(N, func(i int) error { begin.Done() begin.Wait() n := rand.Intn(T) + 1 diff --git a/log/golog.go b/log/golog.go index 23496bb4..26004f33 100644 --- a/log/golog.go +++ b/log/golog.go @@ -9,13 +9,22 @@ import ( "fmt" "io" golog "log" + "runtime/debug" + "sync/atomic" ) var golevel = Info +var called int32 = 0 + // AddFlags adds a standard log level flags to the flag.CommandLine // flag set. func AddFlags() { + if atomic.AddInt32(&called, 1) != 1 { + Error.Printf("log.AddFlags: called twice!") + debug.PrintStack() + return + } flag.Var(new(logFlag), "log", "set log level (off, error, info, debug)") } @@ -47,6 +56,12 @@ func SetPrefix(prefix string) { golog.SetPrefix(prefix) } +// SetLevel sets the log level for the Go standard logger. +// It should be called once at the beginning of a program's main. 
+func SetLevel(level Level) { + golevel = level +} + type logFlag string func (f logFlag) String() string { diff --git a/log/log.go b/log/log.go index cc518146..cbe1b94c 100644 --- a/log/log.go +++ b/log/log.go @@ -50,6 +50,11 @@ func SetOutputter(newOut Outputter) Outputter { return old } +// GetOutputter returns the current outputter used by the log package. +func GetOutputter() Outputter { + return out +} + // At returns whether the logger is currently logging at the provided level. func At(level Level) bool { return level <= out.Level() @@ -103,7 +108,7 @@ func (l Level) String() string { // at level l to the current outputter. func (l Level) Print(v ...interface{}) { if At(l) { - out.Output(2, l, fmt.Sprint(v...)) + _ = out.Output(2, l, fmt.Sprint(v...)) } } @@ -111,7 +116,7 @@ func (l Level) Print(v ...interface{}) { // it at level l to the current outputter. func (l Level) Println(v ...interface{}) { if At(l) { - out.Output(2, l, fmt.Sprintln(v...)) + _ = out.Output(2, l, fmt.Sprintln(v...)) } } @@ -119,7 +124,7 @@ func (l Level) Println(v ...interface{}) { // it at level l to the current outputter. func (l Level) Printf(format string, v ...interface{}) { if At(l) { - out.Output(2, l, fmt.Sprintf(format, v...)) + _ = out.Output(2, l, fmt.Sprintf(format, v...)) } } @@ -127,7 +132,7 @@ func (l Level) Printf(format string, v ...interface{}) { // and outputs it at the Info level to the current outputter. func Print(v ...interface{}) { if At(Info) { - out.Output(2, Info, fmt.Sprint(v...)) + _ = out.Output(2, Info, fmt.Sprint(v...)) } } @@ -135,15 +140,21 @@ func Print(v ...interface{}) { // and outputs it at the Info level to the current outputter. func Printf(format string, v ...interface{}) { if At(Info) { - out.Output(2, Info, fmt.Sprintf(format, v...)) + _ = out.Output(2, Info, fmt.Sprintf(format, v...)) } } +// Errorf formats a message in the manner of fmt.Sprintf +// and outputs it at the Error level to the current outputter. 
+func Errorf(format string, v ...interface{}) { + _ = out.Output(2, Error, fmt.Sprintf(format, v...)) +} + // Fatal formats a message in the manner of fmt.Sprint, outputs it at // the error level to the current outputter and then calls // os.Exit(1). func Fatal(v ...interface{}) { - out.Output(2, Error, fmt.Sprint(v...)) + _ = out.Output(2, Error, fmt.Sprint(v...)) os.Exit(1) } @@ -151,7 +162,7 @@ func Fatal(v ...interface{}) { // the error level to the current outputter and then calls // os.Exit(1). func Fatalf(format string, v ...interface{}) { - out.Output(2, Error, fmt.Sprintf(format, v...)) + _ = out.Output(2, Error, fmt.Sprintf(format, v...)) os.Exit(1) } @@ -159,7 +170,7 @@ func Fatalf(format string, v ...interface{}) { // at the error level to the current outputter and then panics. func Panic(v ...interface{}) { s := fmt.Sprint(v...) - out.Output(2, Error, s) + _ = out.Output(2, Error, s) panic(s) } @@ -167,6 +178,12 @@ func Panic(v ...interface{}) { // at the error level to the current outputter and then panics. func Panicf(format string, v ...interface{}) { s := fmt.Sprintf(format, v...) - out.Output(2, Error, s) + _ = out.Output(2, Error, s) panic(s) } + +// Outputf is formats a message using fmt.Sprintf and outputs it +// to the provided logger at the provided level. +func Outputf(out Outputter, level Level, format string, v ...interface{}) { + _ = out.Output(2, level, fmt.Sprintf(format, v...)) +} diff --git a/log/log_test.go b/log/log_test.go index cdc1e34c..6bf58e34 100644 --- a/log/log_test.go +++ b/log/log_test.go @@ -67,7 +67,7 @@ func TestLog(t *testing.T) { } } -func ExampleDefault() { +func Example() { log.SetOutput(os.Stdout) log.SetFlags(0) log.Print("hello, world!") diff --git a/logio/logio.go b/logio/logio.go new file mode 100644 index 00000000..562d7c85 --- /dev/null +++ b/logio/logio.go @@ -0,0 +1,75 @@ +// Copyright 2019 GRAIL, Inc. All rights reserved. 
+// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. + +// Package logio implements a failure-tolerant log, typically used as +// a write-ahead log. Logs are "history oblivious": new log entries +// do not depend on previous entries; and logs may be concatenated on +// block boundaries while preserving integrity. Likewise, logs may be +// read from a stream without seeking. +// +// Data layout +// +// Logio follows the leveldb log format [1] with some modifications +// to permit efficient re-syncing from the end of a log, as well as +// to use a modern checksum algorithm (xxhash). +// +// A log file is a sequence of 32kB blocks, each containing a sequence +// of records and possibly followed by padding. Records may not span +// blocks; log entries that would straddle block boundaries are broken +// up into multiple records, to be reassembled at read time. +// +// block := record* padding? +// +// record := +// checksum uint32 // xxhash[2] checksum of the remainder of the record +// type uint8 // the record type, detailed below +// length uint16 // the length of the record data, below +// offset uint64 // the offset (in bytes) of this record from the record that begins the entry +// data [length]uint8 // the record data +// +// The record types are as follows: +// +// FULL=1 // the record contains the full entry +// FIRST=2 // the record is the first in an assembly +// MIDDLE=3 // the record is in the middle of an assembly +// LAST=4 // the record concludes an assembly +// +// Thus, entries are assembled by reading a sequence of records: +// +// entry := +// FULL +// | FIRST MIDDLE* LAST +// +// Failure tolerance +// +// Logio recovers from record corruption (e.g., checksum errors) and truncated +// writes by re-syncing at read time. 
If a corrupt record is encountered, the +// reader skips to the next block boundary (which always begins a record) and +// finds the first FULL or FIRST record to re-commence reading. +// +// [1] https://github.com/google/leveldb/blob/master/doc/log_format.md +// [2] http://cyan4973.github.io/xxHash/ +package logio + +import ( + "encoding/binary" +) + +// Blocksz is the size of the blocks written to the log files +// produced by this package. See package docs for a detailed +// description. +const Blocksz = 32 << 10 + +const headersz = 4 + 1 + 2 + 8 + +var byteOrder = binary.LittleEndian + +var zeros = make([]byte, Blocksz) + +const ( + recordFull uint8 = 1 + iota + recordFirst + recordMiddle + recordLast +) diff --git a/logio/logio_test.go b/logio/logio_test.go new file mode 100644 index 00000000..a8b9661b --- /dev/null +++ b/logio/logio_test.go @@ -0,0 +1,193 @@ +// Copyright 2019 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. + +package logio + +import ( + "bytes" + "fmt" + "io" + "math/rand" + "testing" +) + +func TestLogIO(t *testing.T) { + sizes := records(100, 128<<10) + // Make sure that there are some "tricky" sizes in here, to exercise + // all of the code paths. 
+ sizes[10] = Blocksz // doesn't fit by a small margin + sizes[11] = Blocksz - headersz // exact fit + sizes[12] = Blocksz - headersz - 2 // underfit by less than headersz; next entry requires padding + + var ( + buf bytes.Buffer + scratch []byte + ) + w := NewWriter(&buf, 0) + for _, sz := range sizes { + scratch = data(scratch, sz) + must(t, w.Append(scratch)) + } + + r := NewReader(&buf, 0) + for i, sz := range sizes { + t.Logf("record %d size %d", i, sz) + rec, err := r.Read() + mustf(t, err, "record %d (size %d)", i, sz) + if got, want := len(rec), sz; got != want { + t.Errorf("got %v, want %v", got, want) + } + mustData(t, rec) + } + mustEOF(t, r) +} + +func TestResync(t *testing.T) { + var ( + sizes = records(20, 128<<10) + scratch []byte + buf bytes.Buffer + ) + w := NewWriter(&buf, 0) + for _, sz := range sizes { + scratch = data(scratch, sz) + must(t, w.Append(scratch)) + } + buf.Bytes()[1]++ + r := NewReader(&buf, 0) + var i int + for i = range sizes { + rec, err := r.Read() + if err == ErrCorrupted { + break + } + must(t, err) + if got, want := len(rec), sizes[i]; got != want { + t.Errorf("got %v, want %v", got, want) + } + mustData(t, rec) + } + if i == len(sizes) { + t.Fatal("corrupted record not detected") + } + rec, err := r.Read() + mustf(t, err, "failed to recover from corrupted record") + mustData(t, rec) + j := i + for ; i < len(sizes); i++ { + if len(rec) == sizes[i] { + i++ + break + } + } + if i == len(sizes) { + t.Fatal("failed to resync") + } + t.Logf("skipped %d records", i-j) + for ; i < len(sizes); i++ { + rec, err := r.Read() + mustf(t, err, "record %d/%d", i, len(sizes)) + mustData(t, rec) + } + mustEOF(t, r) +} + +func TestRewind(t *testing.T) { + sizes := records(50, 128<<10) + var ( + buf bytes.Buffer + scratch []byte + ) + w := NewWriter(&buf, 0) + for _, sz := range sizes { + scratch = data(scratch, sz) + must(t, w.Append(scratch)) + } + var ( + rd = bytes.NewReader(buf.Bytes()) + off = int64(rd.Len()) + ) + for n := 1; n <= 10; 
n++ { + var err error + off, err = Rewind(rd, off) + must(t, err) + // Check that Rewind also seeked rd to the correct offset. + seekPos, err := rd.Seek(0, io.SeekCurrent) + must(t, err) + if got, want := off, seekPos; got != want { + t.Fatalf("got %v, want %v", got, want) + } + r := NewReader(rd, off) + for i, sz := range sizes[len(sizes)-n:] { + rec, err := r.Read() + must(t, err) + if got, want := len(rec), sz; got != want { + t.Fatalf("%d,%d: got %v, want %v", n, i, got, want) + } + mustData(t, rec) + } + mustEOF(t, r) + } +} + +func records(n, max int) []int { + if n > max { + panic("n > max") + } + var ( + recs = make([]int, n) + stride = max / n + ) + for i := range recs { + recs[i] = 1 + stride*i + } + r := rand.New(rand.NewSource(int64(n + max))) + r.Shuffle(n, func(i, j int) { recs[i], recs[j] = recs[j], recs[i] }) + return recs +} + +func data(scratch []byte, n int) []byte { + if n <= cap(scratch) { + scratch = scratch[:n] + } else { + scratch = make([]byte, n) + } + r := rand.New(rand.NewSource(int64(n))) + for i := range scratch { + scratch[i] = byte(r.Intn(256)) + } + return scratch +} + +func mustData(t *testing.T, b []byte) { + t.Helper() + r := rand.New(rand.NewSource(int64(len(b)))) + for i := range b { + if got, want := int(b[i]), r.Intn(256); got != want { + t.Fatalf("byte %d: got %v, want %v", i, got, want) + } + } +} + +func mustEOF(t *testing.T, r *Reader) { + t.Helper() + _, err := r.Read() + if got, want := err, io.EOF; got != want { + t.Errorf("got %v, want %v", got, want) + } +} + +func must(t *testing.T, err error) { + t.Helper() + if err != nil { + t.Fatal(err) + } +} + +func mustf(t *testing.T, err error, format string, v ...interface{}) { + t.Helper() + if err != nil { + t.Fatalf("%s: %v", fmt.Sprintf(format, v...), err) + } +} diff --git a/logio/reader.go b/logio/reader.go new file mode 100644 index 00000000..df37f16b --- /dev/null +++ b/logio/reader.go @@ -0,0 +1,306 @@ +// Copyright 2019 GRAIL, Inc. All rights reserved. 
+// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. + +package logio + +import ( + "errors" + "fmt" + "io" +) + +// ErrCorrupted is returned when log file corruption is detected. +var ErrCorrupted = errors.New("corrupted log file") + +// Reader reads entries from a log file. +type Reader struct { + rd io.Reader + off int64 + + needResync bool + + block block +} + +// NewReader returns a log file reader that reads log entries from +// the provider io.Reader. The offset must be the current offset of +// the io.Reader into the IO stream from which records are read. +func NewReader(r io.Reader, offset int64) *Reader { + return &Reader{rd: r, off: offset} +} + +// Read returns the next log entry. It returns ErrCorrupted if a +// corrupted log entry was encountered, in which case the next call +// to Read will re-sync the log file, potentially skipping entries. +// The returned slice should not be modified and is only valid until +// the next call to Read or Rewind. +func (r *Reader) Read() (data []byte, err error) { + if r.needResync { + if err := r.resync(); err != nil { + return nil, err + } + r.needResync = false + } + for first := true; ; first = false { + if r.block.eof() { + err := r.block.read(r.rd, &r.off) + if err == io.EOF && !first { + return nil, io.ErrUnexpectedEOF + } else if err != nil { + return nil, err + } + } + record, ok := r.block.next() + switch record.typ { + case recordFull, recordFirst: + ok = ok && first + case recordMiddle, recordLast: + ok = ok && !first + } + if !ok { + r.needResync = true + return nil, ErrCorrupted + } + switch record.typ { + case recordFull: + return record.data, nil + case recordFirst: + data = append([]byte{}, record.data...) + case recordMiddle: + data = append(data, record.data...) 
+ case recordLast: + return append(data, record.data...), nil + } + } +} + +// Reset resets the reader's state; subsequent entries are +// read from the provided reader at the provided offset. +func (r *Reader) Reset(rd io.Reader, offset int64) { + *r = Reader{rd: rd, off: offset} +} + +func (r *Reader) resync() error { + for { + if err := r.block.read(r.rd, &r.off); err != nil { + return err + } + for { + record, ok := r.block.peek() + if !ok { + break + } + if record.typ == recordFirst || record.typ == recordFull { + return nil + } + r.block.next() + } + } +} + +// Rewind finds and returns the offset of the last log entry in the +// log file represented by the reader r. The provided limit is the +// offset of the end of the log stream; thus Rewind may be used to +// traverse a log file in the backwards direction (error handling is +// left as an exercise to the reader): +// +// file, err := os.Open(...) +// info, err := file.Stat() +// off := info.Size() +// for { +// off, err = logio.Rewind(file, off) +// if err == io.EOF { +// break +// } +// file.Seek(off, io.SeekStart) +// record, err := logio.NewReader(file, off).Read() +// } +// +// Rewind returns io.EOF when no records can be located in the +// reader limited by the provided limit. +// +// If the passed reader is also an io.Seeker, then Rewind will seek +// to the returned offset. +func Rewind(r io.ReaderAt, limit int64) (off int64, err error) { + if s, ok := r.(io.Seeker); ok { + defer func() { + if err != nil { + return + } + off, err = s.Seek(off, io.SeekStart) + }() + } + + if limit <= headersz { + return 0, io.EOF + } + off = limit - limit%Blocksz + // Special case: if the limit is on a block boundary, we begin by rewinding + // to the previous block. + if off == limit { + off -= Blocksz + } + for ; off >= 0; off -= Blocksz { + var b block + off -= off % Blocksz + if err = b.readLimit(r, off, limit); err != nil { + return + } + + // Find the last valid record in the block. 
+ var last record + for { + r, ok := b.next() + if !ok { + break + } + last = r + } + if last.isEmpty() { + // First record was invalid; try previous block. + continue + } + + off += int64(last.blockOff) - int64(last.offset) + err = b.readLimit(r, off, limit) + if err != nil { + return + } + if r, ok := b.next(); ok && r.offset == 0 { + return + } + } + err = io.EOF + return +} + +type record struct { + blockOff int + + typ uint8 + offset uint64 + data []byte +} + +func (r record) String() string { + return fmt.Sprintf("record blockOff:%d typ:%d offset:%d data:%d", r.blockOff, r.typ, r.offset, len(r.data)) +} + +func (r record) isEmpty() bool { + return r.blockOff == 0 && r.typ == 0 && r.offset == 0 && r.data == nil +} + +type block struct { + buf [Blocksz]byte + off, limit int + parsed record + ok bool +} + +func (b *block) String() string { + return fmt.Sprintf("block off:%d limit:%d", b.off, b.limit) +} + +func (b *block) eof() bool { + return b.off >= b.limit-headersz && b.parsed.isEmpty() +} + +func (b *block) next() (record, bool) { + rec, ok := b.peek() + b.parsed = record{} + return rec, ok +} + +func (b *block) peek() (record, bool) { + if b.parsed.isEmpty() { + b.parsed, b.ok = b.parse() + } + return b.parsed, b.ok +} + +func (b *block) parse() (record, bool) { + if b.off >= b.limit-headersz { + return record{}, false + } + var r record + r.blockOff = b.off + chk := b.uint32() + r.typ = b.uint8() + length := b.uint16() + r.offset = b.uint64() + if int(length) > b.limit-b.off || checksum(b.buf[r.blockOff+4:r.blockOff+headersz+int(length)]) != chk { + return record{}, false + } + r.data = b.bytes(int(length)) + var ok bool + switch r.typ { + case recordFirst, recordFull: + ok = r.offset == 0 + default: + ok = r.offset != 0 + } + return r, ok +} + +func (b *block) read(r io.Reader, off *int64) error { + b.reset(Blocksz - int(*off%Blocksz)) + n, err := io.ReadFull(r, b.buf[:b.limit]) + if err == io.ErrUnexpectedEOF { + b.limit = n + err = nil + } + *off += 
int64(n) + return err +} + +func (b *block) readLimit(r io.ReaderAt, off, limit int64) error { + b.reset(Blocksz - int(off%Blocksz)) + if n := limit - off; n < int64(b.limit) { + b.limit = int(n) + } + if b.limit > len(b.buf) { + panic(off) + } + n, err := r.ReadAt(b.buf[:b.limit], off) + if err == io.EOF && n == b.limit && n < Blocksz { + err = nil + } + return err +} + +func (b *block) reset(limit int) { + b.parsed = record{} + b.off = 0 + b.limit = limit +} + +func (b *block) uint8() uint8 { + v := b.buf[b.off] + b.off++ + return uint8(v) +} + +func (b *block) uint16() uint16 { + v := byteOrder.Uint16(b.buf[b.off:]) + b.off += 2 + return v +} + +func (b *block) uint32() uint32 { + v := byteOrder.Uint32(b.buf[b.off:]) + b.off += 4 + return v +} + +func (b *block) uint64() uint64 { + v := byteOrder.Uint64(b.buf[b.off:]) + b.off += 8 + return v +} + +func (b *block) bytes(n int) []byte { + p := b.buf[b.off : b.off+n] + b.off += n + return p +} diff --git a/logio/writer.go b/logio/writer.go new file mode 100644 index 00000000..43d9523b --- /dev/null +++ b/logio/writer.go @@ -0,0 +1,127 @@ +// Copyright 2019 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. + +package logio + +import ( + "io" + + xxhash "github.com/cespare/xxhash/v2" +) + +// Append writes an entry to the io.Writer w. The writer must be +// positioned at the provided offset. If non-nil, Append will use the +// scratch buffer for working space, avoiding additional allocation. +// The scratch buffer must be at least Blocksz. +func Append(w io.Writer, off int64, data, scratch []byte) (nwrite int, err error) { + if n := off % Blocksz; n > 0 && n < headersz { + // Corrupted file: skip to the next block boundary. + // + // TODO(marius): make sure that in the case of append failure + // that occurs at the end of the file, that we can recover without + // exposing the error to the user. 
+ n, err := w.Write(zeros[:Blocksz-n]) + nwrite += n + if err != nil { + return nwrite, err + } + } else if left := Blocksz - n; left <= headersz { + // Need padding. + n, err := w.Write(zeros[:left]) + nwrite += n + if err != nil { + return nwrite, err + } + } + for base := nwrite; len(data) > 0; { + n := Blocksz - int(off+int64(nwrite))%Blocksz + n -= headersz + var typ uint8 + switch { + case len(data) <= n && nwrite == base: + typ, n = recordFull, len(data) + case len(data) <= n: + typ, n = recordLast, len(data) + case nwrite == base: + typ = recordFirst + default: + typ = recordMiddle + } + scratch = appendRecord(scratch[:0], typ, uint64(nwrite-base), data[:n]) + data = data[n:] + n, err = w.Write(scratch) + nwrite += n + if err != nil { + return nwrite, err + } + } + return nwrite, nil +} + +// Aligned aligns the provided offset for the next write: it returns +// the offset at which the next record will be written, if a writer +// with the provided offset is provided to Append. This can be used +// to index into logio files. +func Aligned(off int64) int64 { + if n := int64(Blocksz - off%Blocksz); n <= headersz { + return off + n + } + return off +} + +// A Writer appends to a log file. Writers are thin stateful wrappers +// around Append. +type Writer struct { + wr io.Writer + off int64 + scratch []byte +} + +// NewWriter returns a new writer that appends log entries to the +// provided io.Writer. The offset given must be the offset into the +// underlying IO stream represented by wr. +func NewWriter(wr io.Writer, offset int64) *Writer { + return &Writer{wr: wr, off: offset, scratch: make([]byte, Blocksz)} +} + +// Append appends a new entry to the log file. Appending an empty +// record is a no-op. Note that the writer appends only appends to +// the underlying stream. It is the responsibility of the caller to +// ensure that the writes are committed to stable storage (e.g., by +// calling file.Sync). 
+func (w *Writer) Append(data []byte) error { + n, err := Append(w.wr, w.off, data, w.scratch) + w.off += int64(n) + return err +} + +// Tell returns the offset of the next record to be appended. +// This may be used to index into the log file. +func (w *Writer) Tell() int64 { + return Aligned(w.off) +} + +// appendRecord appends a record, specified by typ, offset, and data, to p. p +// must have enough capacity for the record. +func appendRecord(p []byte, typ uint8, offset uint64, data []byte) []byte { + off := len(p) + p = p[:off+headersz+len(data)] + p[off+4] = typ + byteOrder.PutUint16(p[off+5:], uint16(len(data))) + byteOrder.PutUint64(p[off+7:], offset) + copy(p[off+15:], data) + byteOrder.PutUint32(p[off:], checksum(p[off+4:])) + return p +} + +func (w *Writer) write(p []byte) error { + n, err := w.wr.Write(p) + w.off += int64(n) + return err +} + +func checksum(data []byte) uint32 { + h := xxhash.Sum64(data) + return uint32(h<<32) ^ uint32(h) +} diff --git a/mapio/block.go b/mapio/block.go new file mode 100644 index 00000000..2825cf23 --- /dev/null +++ b/mapio/block.go @@ -0,0 +1,212 @@ +// Copyright 2018 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. + +package mapio + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "hash/crc32" + "sort" +) + +const ( + maxBlockHeaderSize = binary.MaxVarintLen32 + // sharedSize + binary.MaxVarintLen32 + // unsharedSize + binary.MaxVarintLen32 // valueSize + + blockMinTrailerSize = 4 + // restart count + 1 + // block type + 4 // crc32 (IEEE) checksum of contents +) + +var order = binary.LittleEndian + +// A blockBuffer is a writable block buffer. +type blockBuffer struct { + bytes.Buffer + + lastKey []byte + + restartInterval int + restarts []int + restartCount int +} + +// Append appends the provided entry to the block. Must be called +// in lexicographic order of keys, or else Append panics. 
+func (b *blockBuffer) Append(key, value []byte) { + if bytes.Compare(key, b.lastKey) < 0 { + panic("keys added out of order") + } + var shared int + if b.restartCount < b.restartInterval { + n := len(b.lastKey) + if len(key) < n { + n = len(key) + } + for shared = 0; shared < n; shared++ { + if key[shared] != b.lastKey[shared] { + break + } + } + b.restartCount++ + } else { + b.restartCount = 0 + b.restarts = append(b.restarts, b.Len()) + } + + if b.lastKey == nil || cap(b.lastKey) < len(key) { + b.lastKey = make([]byte, len(key)) + } else { + b.lastKey = b.lastKey[:len(key)] + } + copy(b.lastKey[shared:], key[shared:]) + + var hd [maxBlockHeaderSize]byte + var pos int + pos += binary.PutUvarint(hd[pos:], uint64(shared)) + pos += binary.PutUvarint(hd[pos:], uint64(len(key)-shared)) + pos += binary.PutUvarint(hd[pos:], uint64(len(value))) + + b.Write(hd[:pos]) + b.Write(key[shared:]) + b.Write(value) +} + +// Finish completes the block by adding the block trailer. +func (b *blockBuffer) Finish() { + b.Grow(4*(len(b.restarts)+1) + 1 + 4 + 4) + var ( + pback [4]byte + p = pback[:] + ) + if b.Buffer.Len() > 0 { + // Add restart points. Zero is always a restart point (if block is nonempty). + order.PutUint32(p, 0) + b.Write(p) + for _, off := range b.restarts { + order.PutUint32(p, uint32(off)) + b.Write(p) + } + order.PutUint32(p, uint32(len(b.restarts)+1)) + } else { + order.PutUint32(p, 0) + } + b.Write(p) + b.WriteByte(0) // zero type. reserved. + order.PutUint32(p, crc32.ChecksumIEEE(b.Bytes())) + b.Write(p) +} + +// Reset resets the contents of this block. After a call to reset, +// the blockBuffer instance may be used to write a new block. +func (b *blockBuffer) Reset() { + b.lastKey = nil + b.restarts = nil + b.restartCount = 0 + b.Buffer.Reset() +} + +// A block is an in-memory representation of a single block. Blocks +// maintain a current offset from which entries are scanned. 
+type block struct { + p []byte + nrestart int + restarts []byte + + key, value []byte + off, prevOff int +} + +// Init initializes the block from the block contents stored at b.p. +// Init returns an error if the block is malformed or corrupted. +func (b *block) init() error { + if len(b.p) < blockMinTrailerSize { + return errors.New("invalid block: too small") + } + if got, want := crc32.ChecksumIEEE(b.p[:len(b.p)-4]), order.Uint32(b.p[len(b.p)-4:]); got != want { + return fmt.Errorf("invalid checksum: expected %x, got %v", want, got) + } + off := len(b.p) - blockMinTrailerSize + b.nrestart = int(order.Uint32(b.p[off:])) + if b.nrestart*4 > off { + return errors.New("corrupt block") + } + b.restarts = b.p[off-4*b.nrestart : off] + if btype := b.p[off+4]; btype != 0 { + return fmt.Errorf("invalid block type %d", btype) + } + b.p = b.p[:off-4*b.nrestart] + b.key = nil + b.value = nil + b.off = 0 + b.prevOff = 0 + return nil +} + +// Seek sets the block to the first position for which key <= b.Key(). +func (b *block) Seek(key []byte) { + restart := sort.Search(b.nrestart, func(i int) bool { + b.off = int(order.Uint32(b.restarts[i*4:])) + if !b.Scan() { + panic("corrupt block") + } + return bytes.Compare(key, b.Key()) <= 0 + }) + if restart == 0 { + // No more work needed. key <= the first key in the block. + b.off = 0 + return + } + b.off = int(order.Uint32(b.restarts[(restart-1)*4:])) + for b.Scan() { + if bytes.Compare(key, b.Key()) <= 0 { + b.unscan() + break + } + } +} + +// Scan reads the entry at the current position and then advanced the +// block's position to the next entry. Scan returns false when the +// position is at or beyond the end of the block. 
+func (b *block) Scan() bool { + if b.off >= len(b.p) { + return false + } + b.prevOff = b.off + nshared, n := binary.Uvarint(b.p[b.off:]) + b.off += n + nunshared, n := binary.Uvarint(b.p[b.off:]) + b.off += n + nvalue, n := binary.Uvarint(b.p[b.off:]) + b.off += n + b.key = append(b.key[:nshared], b.p[b.off:b.off+int(nunshared)]...) + b.off += int(nunshared) + b.value = b.p[b.off : b.off+int(nvalue)] + b.off += int(nvalue) + return true +} + +func (b *block) unscan() { + b.off = b.prevOff +} + +// Key returns the key for the last scanned entry of the block. +func (b *block) Key() []byte { + return b.key +} + +// Value returns the value for the last scanned entry of the block. +func (b *block) Value() []byte { + return b.value +} + +func readBlock(p []byte) (*block, error) { + b := &block{p: p} + return b, b.init() +} diff --git a/mapio/block_test.go b/mapio/block_test.go new file mode 100644 index 00000000..ff5a452e --- /dev/null +++ b/mapio/block_test.go @@ -0,0 +1,136 @@ +// Copyright 2018 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. + +package mapio + +import ( + "bytes" + "math/rand" + "sort" + "testing" + + fuzz "github.com/google/gofuzz" +) + +type entry struct{ Key, Value []byte } + +func (e entry) Equal(f entry) bool { + return bytes.Compare(e.Key, f.Key) == 0 && + bytes.Compare(e.Value, f.Value) == 0 +} + +func makeEntries(n int) []entry { + fz := fuzz.New() + // fz.NumElements(1, 100) + entries := make([]entry, n) + // Fuzz manually so we can control the number of entries. 
+ for i := range entries { + fz.Fuzz(&entries[i].Key) + fz.Fuzz(&entries[i].Value) + } + sortEntries(entries) + return entries +} + +func sortEntries(entries []entry) { + sort.Slice(entries, func(i, j int) bool { + x := bytes.Compare(entries[i].Key, entries[j].Key) + return x < 0 || x == 0 && bytes.Compare(entries[i].Value, entries[j].Value) < 0 + }) +} + +type seeker interface { + Seek(key []byte) Scanner +} + +func testSeeker(t *testing.T, entries []entry, seeker seeker) { + t.Helper() + + s := seeker.Seek(nil) + var scanned []entry + for s.Scan() { + var ( + key = make([]byte, len(s.Key())) + value = make([]byte, len(s.Value())) + ) + copy(key, s.Key()) + copy(value, s.Value()) + scanned = append(scanned, entry{key, value}) + } + if got, want := len(scanned), len(entries); got != want { + t.Fatalf("got %v, want %v", got, want) + } + isSorted := sort.SliceIsSorted(scanned, func(i, j int) bool { + return bytes.Compare(scanned[i].Key, scanned[j].Key) < 0 + }) + if !isSorted { + t.Error("scan returned non-sorted entries") + } + sortEntries(scanned) + sortEntries(entries) + for i := range scanned { + if !entries[i].Equal(scanned[i]) { + t.Errorf("scan: entry %d does not match", i) + } + } + + // Look up keys but in a random order. + for n, i := range rand.Perm(len(entries)) { + s := seeker.Seek(entries[i].Key) + for s.Scan() { + if bytes.Compare(entries[i].Key, s.Key()) != 0 { + t.Errorf("%d: did not find key for %d", n, i) + } + + // Since we may have multiple keys with the same value, + // we have to scan until we see our expected value. 
+ if bytes.Compare(entries[i].Key, s.Key()) != 0 || + bytes.Compare(entries[i].Value, s.Value()) == 0 { + break + } + } + if !entries[i].Equal(entry{s.Key(), s.Value()}) { + t.Errorf("%d: seek: entry %d does not match %d", n, i, bytes.Compare(entries[i].Key, s.Key())) + } + } + + lastKey := entries[len(entries)-1].Key + bigKey := make([]byte, len(lastKey)+1) + copy(bigKey, lastKey) + s = seeker.Seek(bigKey) + if s.Scan() { + t.Error("scanned bigger key") + } +} + +type blockSeeker struct{ *block } + +func (b *blockSeeker) Seek(key []byte) Scanner { + b.block.Seek(key) + return b +} + +func (b *blockSeeker) Err() error { return nil } + +func TestBlock(t *testing.T) { + const N = 10000 + entries := makeEntries(N) + + buf := blockBuffer{restartInterval: 100} + for i := range entries { + buf.Append(entries[i].Key, entries[i].Value) + } + buf.Finish() + + block, err := readBlock(buf.Bytes()) + if err != nil { + t.Fatal(err) + } + + if got, want := block.nrestart, 100; got != want { + t.Errorf("got %v, want %v", got, want) + } + + testSeeker(t, entries, &blockSeeker{block}) +} diff --git a/mapio/buf.go b/mapio/buf.go new file mode 100644 index 00000000..01456a61 --- /dev/null +++ b/mapio/buf.go @@ -0,0 +1,56 @@ +// Copyright 2018 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. + +package mapio + +import ( + "bytes" + "sort" +) + +// A Buf is an unordered write buffer for maps. It holds entries +// in memory; these are then sorted and written to a map. +type Buf struct { + keys, values [][]byte + keySize, valueSize int +} + +// Append append the given entry to the buffer. 
+func (b *Buf) Append(key, value []byte) {
+	keyCopy := make([]byte, len(key)) // copy so the caller may reuse key
+	copy(keyCopy, key)
+	valueCopy := make([]byte, len(value)) // copy so the caller may reuse value
+	copy(valueCopy, value)
+	b.keys = append(b.keys, keyCopy)
+	b.values = append(b.values, valueCopy)
+	b.keySize += len(keyCopy)
+	b.valueSize += len(valueCopy)
+}
+
+// Size returns the size of this buffer in bytes.
+func (b *Buf) Size() int { return b.keySize + b.valueSize }
+
+// Len implements sort.Interface.
+func (b *Buf) Len() int { return len(b.keys) }
+
+// Less implements sort.Interface.
+func (b *Buf) Less(i, j int) bool { return bytes.Compare(b.keys[i], b.keys[j]) < 0 }
+
+// Swap implements sort.Interface.
+func (b *Buf) Swap(i, j int) {
+	b.keys[i], b.keys[j] = b.keys[j], b.keys[i]
+	b.values[i], b.values[j] = b.values[j], b.values[i]
+}
+
+// WriteTo sorts and then writes all of the entries in this buffer to
+// the provided writer.
+func (b *Buf) WriteTo(w *Writer) error {
+	sort.Sort(b)
+	for i := range b.keys {
+		if err := w.Append(b.keys[i], b.values[i]); err != nil {
+			return err
+		}
+	}
+	return nil
+}
diff --git a/mapio/buf_test.go b/mapio/buf_test.go
new file mode 100644
index 00000000..98be0436
--- /dev/null
+++ b/mapio/buf_test.go
@@ -0,0 +1,38 @@
+// Copyright 2018 GRAIL, Inc. All rights reserved.
+// Use of this source code is governed by the Apache 2.0
+// license that can be found in the LICENSE file.
+
+package mapio
+
+import (
+	"bytes"
+	"math/rand"
+	"testing"
+)
+
+func TestBuf(t *testing.T) {
+	const N = 1000
+	entries := makeEntries(N)
+	// Shuffle to make sure the buffer sorts properly.
+ rand.Shuffle(len(entries), func(i, j int) { + entries[i], entries[j] = entries[j], entries[i] + }) + var buf Buf + for _, e := range entries { + buf.Append(e.Key, e.Value) + } + var b bytes.Buffer + w := NewWriter(&b, BlockSize(1<<10), RestartInterval(10)) + if err := buf.WriteTo(w); err != nil { + t.Fatal(err) + } + if err := w.Close(); err != nil { + t.Fatal(err) + } + m, err := New(bytes.NewReader(b.Bytes())) + if err != nil { + t.Fatal(err) + } + + testSeeker(t, entries, mapSeeker{m}) +} diff --git a/mapio/doc.go b/mapio/doc.go new file mode 100644 index 00000000..9cb38f95 --- /dev/null +++ b/mapio/doc.go @@ -0,0 +1,58 @@ +// Copyright 2018 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. + +/* + Package mapio implements a sorted, on-disk map, similar to the + SSTable data structure used in Bigtable [1], Cassandra [2], and + others. Maps are read-only, and are produced by a Writer. Each + Writer expects keys to be appended in lexicographic order. Buf + provides a means of buffering writes to be sorted before appended to + a Writer. + + Mapio's on-disk layout loosely follows that of LevelDB [3]. Each Map + is a sequence of blocks; each block comprises a sequence of entries, + followed by a trailer: + + block := blockEntry* blockTrailer + blockEntry := + nshared: uvarint // number of bytes shared with previous key + nunshared: uvarint // number of new bytes in this entry's key + nvalue: uvarint // number of bytes in value + key: uint8[nunshared] // the (prefix compressed) key + value: uint8[nvalue] // the entry's value + blockTrailer := + restarts: uint32[nrestart] // array of key restarts + nrestart: uint32 // size of restart array + type: uint8 // block type (should be 0; reserved for future use) + crc32: uint32 // IEEE crc32 of contents and trailer + + Maps prefix compress each key by storing the number of bytes shared + with the previous key. 
Maps contain a number of restart points: + points at which the full key is specified (and nshared = 0). The + restart point are stored in an array in the block trailer. This + array can be used to perform binary search for keys. + + A Map is a sequence of data blocks, followed by an index block, + followed by a trailer. + + map := block(data)* block(meta)* block(index) mapTrailer + mapTrailer := + meta: blockAddr[20] // zero-padded address of the meta block index (tbd) + index: blockAddr[20] // zero-padded address of index + magic: uint64 // magic (0xa8b2374e8558bc76) + blockAddr := + offset: uvarint // offset of block in map + len: uvarint // length of block + + The index block contains one entry for each block in the map: each + entry's key is the last key in that block; the entry's value is a + blockAddr containing the position of that block. This arrangement + allows the reader to binary search the index block then search the + found block. + + [1] https://static.googleusercontent.com/media/research.google.com/en//archive/bigtable-osdi06.pdf + [2] https://www.cs.cornell.edu/projects/ladis2009/papers/lakshman-ladis2009.pdf + [3] https://github.com/google/leveldb +*/ +package mapio diff --git a/mapio/map.go b/mapio/map.go new file mode 100644 index 00000000..6a266047 --- /dev/null +++ b/mapio/map.go @@ -0,0 +1,120 @@ +// Copyright 2018 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. + +package mapio + +import ( + "errors" + "io" + "sync" +) + +// Map is a read-only, sorted map backed by an io.ReadSeeker. The +// on-disk layout of maps are described by the package documentation. +// Maps support both lookup and (ordered) iteration. A Map instance +// maintains a current position, starting out at the first entry. +type Map struct { + mu sync.Mutex + r io.ReadSeeker + index block +} + +// New opens the map at the provided io.ReadSeeker (usually a file). 
+func New(r io.ReadSeeker) (*Map, error) { + m := &Map{r: r} + return m, m.init() +} + +func (m *Map) init() error { + if _, err := m.r.Seek(-mapTrailerSize, io.SeekEnd); err != nil { + return err + } + trailer := make([]byte, mapTrailerSize) + if _, err := io.ReadFull(m.r, trailer); err != nil { + return err + } + metaAddr, _ := getBlockAddr(trailer) + if metaAddr != (blockAddr{}) { + return errors.New("non-empty meta block index") + } + indexAddr, _ := getBlockAddr(trailer[maxBlockAddrSize:]) + magic := order.Uint64(trailer[len(trailer)-8:]) + if magic != mapTrailerMagic { + return errors.New("wrong magic") + } + if err := m.readBlock(indexAddr, &m.index); err != nil { + return err + } + if !m.index.Scan() { + return errors.New("empty index") + } + return nil +} + +func (m *Map) readBlock(addr blockAddr, block *block) error { + if block.p != nil && cap(block.p) >= int(addr.len) { + block.p = block.p[:addr.len] + } else { + block.p = make([]byte, addr.len) + } + m.mu.Lock() + defer m.mu.Unlock() + if _, err := m.r.Seek(int64(addr.off), io.SeekStart); err != nil { + return err + } + if _, err := io.ReadFull(m.r, block.p); err != nil { + return err + } + return block.init() +} + +// Seek returns a map scanner beginning at the first key in the map +// >= the provided key. +func (m *Map) Seek(key []byte) *MapScanner { + s := &MapScanner{parent: m, index: m.index} + s.index.Seek(key) + if s.index.Scan() { + addr, _ := getBlockAddr(s.index.Value()) + if s.err = m.readBlock(addr, &s.data); s.err == nil { + s.data.Seek(key) + } + } + return s +} + +// MapScanner implements ordered iteration over a map. +type MapScanner struct { + parent *Map + err error + data, index block +} + +// Scan scans the next entry, returning true on success. When Scan +// returns false, the caller should inspect Err to distinguish +// between scan completion and scan error. 
+func (m *MapScanner) Scan() bool { + for m.err == nil && !m.data.Scan() { + if !m.index.Scan() { + return false + } + addr, _ := getBlockAddr(m.index.Value()) + m.err = m.parent.readBlock(addr, &m.data) + } + return m.err == nil +} + +// Err returns the last error encountered while scanning. +func (m *MapScanner) Err() error { + return m.err +} + +// Key returns the key that was last scanned. +func (m *MapScanner) Key() []byte { + return m.data.Key() +} + +// Value returns the value that was last scanned. +func (m *MapScanner) Value() []byte { + return m.data.Value() +} diff --git a/mapio/map_test.go b/mapio/map_test.go new file mode 100644 index 00000000..a720b998 --- /dev/null +++ b/mapio/map_test.go @@ -0,0 +1,62 @@ +// Copyright 2018 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. + +package mapio + +import ( + "bytes" + "testing" +) + +type mapSeeker struct{ *Map } + +func (m mapSeeker) Seek(key []byte) Scanner { return m.Map.Seek(key) } + +func TestMap(t *testing.T) { + const N = 15000 + entries := makeEntries(N) + + var b bytes.Buffer + w := NewWriter(&b, BlockSize(1024)) + for i := range entries { + if err := w.Append(entries[i].Key, entries[i].Value); err != nil { + t.Fatal(err) + } + } + if err := w.Close(); err != nil { + t.Fatal(err) + } + + m, err := New(bytes.NewReader(b.Bytes())) + if err != nil { + t.Fatal(err) + } + + testSeeker(t, entries, mapSeeker{m}) +} + +func TestEmptyMap(t *testing.T) { + var b bytes.Buffer + w := NewWriter(&b, BlockSize(1024)) + // Flush to get an extra (empty) block. 
+ if err := w.Flush(); err != nil { + t.Fatal(err) + } + if err := w.Close(); err != nil { + t.Fatal(err) + } + + m, err := New(bytes.NewReader(b.Bytes())) + if err != nil { + t.Fatal(err) + } + + scan := m.Seek(nil) + if scan.Scan() { + t.Error("expected EOF") + } + if err := scan.Err(); err != nil { + t.Error(err) + } +} diff --git a/mapio/merged.go b/mapio/merged.go new file mode 100644 index 00000000..4bdbdda9 --- /dev/null +++ b/mapio/merged.go @@ -0,0 +1,105 @@ +// Copyright 2018 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. + +package mapio + +import ( + "bytes" + "container/heap" +) + +var scanSentinel = new(MapScanner) + +// Merged represents the merged contents of multiple underlying maps. +// Like Map, Merged presents a sorted, scannable map, but it does not +// guarantee that the order of traversal is stable. +type Merged []*Map + +// Seek returns a scanner for the merged map that starts at the first +// entry where entryKey <= key. +func (m Merged) Seek(key []byte) *MergedScanner { + merged := make(MergedScanner, 0, len(m)+1) + for i := range m { + s := m[i].Seek(key) + if !s.Scan() { + if err := s.Err(); err != nil { + return &MergedScanner{s} + } + // Otherwise it's just empty and we can skip it. + continue + } + merged = append(merged, s) + } + if len(merged) == 0 { + return &MergedScanner{} + } + heap.Init(&merged) + merged = append(merged, scanSentinel) + return &merged +} + +// MergedScanner is a scanner for merged maps. 
+type MergedScanner []*MapScanner
+
+// Len implements heap.Interface
+func (m MergedScanner) Len() int { return len(m) }
+
+// Less implements heap.Interface
+func (m MergedScanner) Less(i, j int) bool { return bytes.Compare(m[i].Key(), m[j].Key()) < 0 }
+
+// Swap implements heap.Interface
+func (m MergedScanner) Swap(i, j int) { m[i], m[j] = m[j], m[i] }
+
+// Push implements heap.Interface
+func (m *MergedScanner) Push(x interface{}) {
+	*m = append(*m, x.(*MapScanner))
+}
+
+// Pop implements heap.Interface
+func (m *MergedScanner) Pop() interface{} {
+	n := len(*m)
+	elem := (*m)[n-1]
+	*m = (*m)[:n-1]
+	return elem
+}
+
+// Scan scans the next entry in the merged map, returning true on
+// success. If Scan returns false, the caller should check Err to
+// distinguish between scan completion and scan error.
+func (m *MergedScanner) Scan() bool {
+	if len(*m) == 0 || (*m)[0].err != nil {
+		return false
+	}
+	if len(*m) > 0 && (*m)[len(*m)-1] == scanSentinel {
+		*m = (*m)[:len(*m)-1]
+		return true
+	}
+
+	if (*m)[0].Scan() {
+		heap.Fix(m, 0)
+	} else if (*m)[0].err == nil {
+		heap.Remove(m, 0)
+	}
+	ok := len(*m) > 0 && (*m)[0].err == nil
+	return ok
+
+}
+
+// Err returns the last error encountered while scanning, if any.
+func (m MergedScanner) Err() error {
+	if len(m) == 0 {
+		return nil
+	}
+	return m[0].err
+}
+
+// Key returns the last key scanned.
+func (m MergedScanner) Key() []byte {
+	return m[0].Key()
+}
+
+// Value returns the last value scanned.
+func (m MergedScanner) Value() []byte {
+	return m[0].Value()
+}
diff --git a/mapio/merged_test.go b/mapio/merged_test.go
new file mode 100644
index 00000000..6e3912ff
--- /dev/null
+++ b/mapio/merged_test.go
@@ -0,0 +1,49 @@
+// Copyright 2018 GRAIL, Inc. All rights reserved.
+// Use of this source code is governed by the Apache 2.0
+// license that can be found in the LICENSE file.
+
+package mapio
+
+import (
+	"bytes"
+	"testing"
+)
+
+type mergedSeeker struct{ Merged }
+
+func (m mergedSeeker) Seek(key []byte) Scanner { return m.Merged.Seek(key) }
+
+func TestMerged(t *testing.T) {
+	const (
+		N = 10000
+		M = 10
+	)
+	var (
+		entries = makeEntries(N)
+		buffers = make([]bytes.Buffer, M)
+		writers = make([]*Writer, M)
+	)
+	for i := range writers {
+		writers[i] = NewWriter(&buffers[i], BlockSize(1024))
+	}
+	for i := range entries {
+		if err := writers[i%M].Append(entries[i].Key, entries[i].Value); err != nil {
+			t.Fatal(i, err)
+		}
+	}
+	for i := range writers {
+		if err := writers[i].Close(); err != nil {
+			t.Fatal(i, err)
+		}
+	}
+	merged := make(Merged, M)
+	for i := range buffers {
+		var err error
+		merged[i], err = New(bytes.NewReader(buffers[i].Bytes()))
+		if err != nil {
+			t.Fatal(i, err)
+		}
+	}
+
+	testSeeker(t, entries, mergedSeeker{merged})
+}
diff --git a/mapio/scanner.go b/mapio/scanner.go
new file mode 100644
index 00000000..db3acc86
--- /dev/null
+++ b/mapio/scanner.go
@@ -0,0 +1,19 @@
+// Copyright 2018 GRAIL, Inc. All rights reserved.
+// Use of this source code is governed by the Apache 2.0
+// license that can be found in the LICENSE file.
+
+package mapio
+
+// Scanner is an ordered iterator over map entries.
+type Scanner interface {
+	// Scan scans the next entry, returning true on success, after which
+	// time the entry is available to inspect using the Key and Value
+	// methods.
+	Scan() bool
+	// Err returns the last error encountered while scanning, if any.
+	Err() error
+	// Key returns the key of the last scanned entry.
+	Key() []byte
+	// Value returns the value of the last scanned entry.
+	Value() []byte
+}
diff --git a/mapio/writer.go b/mapio/writer.go
new file mode 100644
index 00000000..f4a998bd
--- /dev/null
+++ b/mapio/writer.go
@@ -0,0 +1,159 @@
+// Copyright 2018 GRAIL, Inc. All rights reserved.
+// Use of this source code is governed by the Apache 2.0
+// license that can be found in the LICENSE file.
+ +package mapio + +import ( + "encoding/binary" + "io" +) + +const ( + maxBlockAddrSize = binary.MaxVarintLen64 + // offset + binary.MaxVarintLen64 // len + + mapTrailerSize = maxBlockAddrSize + // meta block index (padded) + maxBlockAddrSize + // index address (padded) + 8 // magic + + mapTrailerMagic = 0xa8b2374e8558bc76 +) + +type blockAddr struct { + off uint64 + len uint64 +} + +func putBlockAddr(p []byte, b blockAddr) int { + off := binary.PutUvarint(p, b.off) + return off + binary.PutUvarint(p[off:], b.len) +} + +func getBlockAddr(p []byte) (b blockAddr, n int) { + var m int + b.off, n = binary.Uvarint(p) + b.len, m = binary.Uvarint(p[n:]) + n += m + return +} + +// A Writer appends key-value pairs to a map. Keys must be appended +// in lexicographic order. +type Writer struct { + data, index blockBuffer + w io.Writer + + lastKey []byte + + blockSize int + off int +} + +const ( + defaultBlockSize = 1 << 12 + defaultRestartInterval = 16 +) + +// WriteOption represents a tunable writer parameter. +type WriteOption func(*Writer) + +// BlockSize sets the writer's target block size to sz (in bytes). +// Note that keys and values cannot straddle blocks, so that if large +// data are added to a map, block sizes can grow large. The default +// target block size is 4KB. +func BlockSize(sz int) WriteOption { + return func(w *Writer) { + w.blockSize = sz + } +} + +// RestartInterval sets the writer's restart interval to +// provided value. The default restart interval is 16. +func RestartInterval(iv int) WriteOption { + return func(w *Writer) { + w.data.restartInterval = iv + w.index.restartInterval = iv + } +} + +// NewWriter returns a new Writer that writes a map to the provided +// io.Writer. BlockSize specifies the target block size, while +// restartInterval determines the frequency of key restart points, +// which trades off lookup performance with size. See package docs +// for more details. 
+func NewWriter(w io.Writer, opts ...WriteOption) *Writer { + wr := &Writer{ + w: w, + blockSize: defaultBlockSize, + } + wr.data.restartInterval = defaultRestartInterval + wr.index.restartInterval = defaultRestartInterval + for _, opt := range opts { + opt(wr) + } + return wr +} + +// Append appends an entry to the maps. Keys must be provided +// in lexicographic order. +func (w *Writer) Append(key, value []byte) error { + w.data.Append(key, value) + if w.lastKey == nil || cap(w.lastKey) < len(key) { + w.lastKey = make([]byte, len(key)) + } else { + w.lastKey = w.lastKey[:len(key)] + } + copy(w.lastKey, key) + if w.data.Len() > w.blockSize { + return w.Flush() + } + return nil +} + +// Flush creates a new block with the current contents. It forces the +// creation of a new block, and overrides the Writer's block size +// parameter. +func (w *Writer) Flush() error { + w.data.Finish() + n, err := w.w.Write(w.data.Bytes()) + if err != nil { + return err + } + w.data.Reset() + off := w.off + w.off += n + + // TODO(marius): we can get more clever about key compression here: + // We need to guarantee that the lastKey <= indexKey < firstKey, + // where firstKey is the first key in the next block. We can thus + // construct a more minimal key to store in the index. + b := make([]byte, maxBlockAddrSize) + n = putBlockAddr(b, blockAddr{uint64(off), uint64(n)}) + w.index.Append(w.lastKey, b[:n]) + + return nil +} + +// Close flushes the last block of the writer and writes the map's +// trailer. After successful close, the map is ready to be opened. +func (w *Writer) Close() error { + if err := w.Flush(); err != nil { + return err + } + w.index.Finish() + n, err := w.w.Write(w.index.Bytes()) + if err != nil { + return err + } + w.index.Reset() + indexAddr := blockAddr{uint64(w.off), uint64(n)} + w.off += n + + trailer := make([]byte, mapTrailerSize) + putBlockAddr(trailer, blockAddr{}) // address of meta block index. tbd. 
+ putBlockAddr(trailer[maxBlockAddrSize:], indexAddr) + order.PutUint64(trailer[len(trailer)-8:], mapTrailerMagic) + _, err = w.w.Write(trailer) + return err +} diff --git a/morebufio/peekback.go b/morebufio/peekback.go new file mode 100644 index 00000000..5480887a --- /dev/null +++ b/morebufio/peekback.go @@ -0,0 +1,56 @@ +package morebufio + +import ( + "context" + + "github.com/grailbio/base/ioctx" +) + +// PeekBackReader is a Reader augmented with a function to "peek" backwards at the data that +// was already passed. Peeking does not change the current stream position (that is, PeekBack has +// no effect on the next Read). +type PeekBackReader interface { + ioctx.Reader + // PeekBack returns a fixed "window" of the data that Read has already returned, ending at + // the current Read position. It may be smaller at the start until enough data has been read, + // but after that it's constant size. PeekBack is allowed after Read returns EOF. + // The returned slice aliases the internal buffer and is invalidated by the next Read. + PeekBack() []byte +} + +type peekBackReader struct { + r ioctx.Reader + buf []byte +} + +// NewPeekBackReader returns a PeekBackReader. It doesn't have a "forward" buffer so small +// PeekBackReader.Read operations cause small r.Read operations. +func NewPeekBackReader(r ioctx.Reader, peekBackSize int) PeekBackReader { + return &peekBackReader{r, make([]byte, 0, peekBackSize)} +} + +func (p *peekBackReader) Read(ctx context.Context, dst []byte) (int, error) { + nRead, err := p.r.Read(ctx, dst) + dst = dst[:nRead] + // First, grow the peek buf until cap, since it starts empty. + if grow := cap(p.buf) - len(p.buf); grow > 0 { + if len(dst) < grow { + grow = len(dst) + } + p.buf = append(p.buf, dst[:grow]...) + dst = dst[grow:] + } + if len(dst) == 0 { + return nRead, err + } + // Shift data if any part of the peek buf is still valid. 
+ updateTail := p.buf + if len(dst) < len(p.buf) { + n := copy(p.buf, p.buf[len(dst):]) + updateTail = p.buf[n:] + } + _ = copy(updateTail, dst[len(dst)-len(updateTail):]) + return nRead, err +} + +func (p *peekBackReader) PeekBack() []byte { return p.buf } diff --git a/morebufio/peekback_test.go b/morebufio/peekback_test.go new file mode 100644 index 00000000..d68af7d9 --- /dev/null +++ b/morebufio/peekback_test.go @@ -0,0 +1,131 @@ +package morebufio + +import ( + "context" + "io" + "strings" + "testing" + + "github.com/grailbio/base/ioctx" + "github.com/stretchr/testify/require" +) + +const digits = "0123456789" + +func TestPeekBack(t *testing.T) { + ctx := context.Background() + r := NewPeekBackReader(ioctx.FromStdReader(strings.NewReader(digits)), 4) + + // Initial read, smaller than peek buf. + b := make([]byte, 2) + n, err := r.Read(ctx, b) + require.NoError(t, err) + require.Equal(t, len(b), n) + require.Equal(t, "01", string(b)) + require.Equal(t, "01", string(r.PeekBack())) + + // Read enough to shift buf. + b = make([]byte, 3) + n, err = r.Read(ctx, b) + require.NoError(t, err) + require.Equal(t, len(b), n) + require.Equal(t, "234", string(b)) + require.Equal(t, "1234", string(r.PeekBack())) + + // Read nothing. + b = nil + n, err = r.Read(ctx, b) + require.NoError(t, err) + require.Equal(t, len(b), n) + require.Equal(t, "1234", string(r.PeekBack())) + + // Read past EOF. + b = make([]byte, 8) + n, err = r.Read(ctx, b) + if err != io.EOF { + require.NoError(t, err) + } + require.Equal(t, 5, n) + require.Equal(t, "56789", string(b[:n])) + require.Equal(t, "6789", string(r.PeekBack())) + + n, err = r.Read(ctx, b) + require.ErrorIs(t, err, io.EOF) + require.Equal(t, 0, n) + require.Equal(t, "6789", string(r.PeekBack())) +} + +func TestPeekBackLargeInitial(t *testing.T) { + ctx := context.Background() + r := NewPeekBackReader(ioctx.FromStdReader(strings.NewReader(digits)), 3) + + // Initial read, larger than peek buf. 
+ b := make([]byte, 6) + n, err := r.Read(ctx, b) + require.NoError(t, err) + require.Equal(t, len(b), n) + require.Equal(t, "012345", string(b)) + require.Equal(t, "345", string(r.PeekBack())) + + // Shift. + b = make([]byte, 1) + n, err = r.Read(ctx, b) + require.NoError(t, err) + require.Equal(t, len(b), n) + require.Equal(t, "6", string(b)) + require.Equal(t, "456", string(r.PeekBack())) +} + +func TestPeekBackNeverFull(t *testing.T) { + ctx := context.Background() + r := NewPeekBackReader(ioctx.FromStdReader(strings.NewReader(digits)), 20) + + b := make([]byte, 6) + n, err := r.Read(ctx, b) + require.NoError(t, err) + require.Equal(t, len(b), n) + require.Equal(t, "012345", string(b)) + require.Equal(t, "012345", string(r.PeekBack())) + + b = make([]byte, 20) + n, err = r.Read(ctx, b) + if err != io.EOF { + require.NoError(t, err) + } + require.Equal(t, 4, n) + require.Equal(t, "6789", string(b[:n])) + require.Equal(t, "0123456789", string(r.PeekBack())) + + n, err = r.Read(ctx, b) + require.ErrorIs(t, err, io.EOF) + require.Equal(t, 0, n) + require.Equal(t, "0123456789", string(r.PeekBack())) +} + +func TestPeekBackZero(t *testing.T) { + ctx := context.Background() + r := NewPeekBackReader(ioctx.FromStdReader(strings.NewReader(digits)), 0) + + b := make([]byte, 6) + n, err := r.Read(ctx, b) + require.NoError(t, err) + require.Equal(t, len(b), n) + require.Equal(t, "012345", string(b)) + require.Equal(t, "", string(r.PeekBack())) + + b = make([]byte, 20) + n, err = r.Read(ctx, b) + if err != io.EOF { + require.NoError(t, err) + } + require.Equal(t, 4, n) + require.Equal(t, "6789", string(b[:n])) + require.Equal(t, "", string(r.PeekBack())) + + n, err = r.Read(ctx, b) + require.ErrorIs(t, err, io.EOF) + require.Equal(t, 0, n) + require.Equal(t, "", string(r.PeekBack())) +} + +// TODO: Randomized/fuzz tests. 
diff --git a/morebufio/readerat.go b/morebufio/readerat.go new file mode 100644 index 00000000..40ef9215 --- /dev/null +++ b/morebufio/readerat.go @@ -0,0 +1,86 @@ +package morebufio + +import ( + "context" + "io" + "sync" + + "github.com/grailbio/base/errors" + "github.com/grailbio/base/ioctx" +) + +type readerAt struct { + r ioctx.ReaderAt + // mu guards updates to all subsequent fields. It's not held during reads. + mu sync.Mutex + // seekerInUse is true while waiting for operations (like reads) on seeker. + seekerInUse bool + seeker *readSeeker +} + +// NewReaderAt constructs a buffered ReaderAt. While ReaderAt allows arbitrary concurrent reads, +// this implementation has only one buffer (for simplicity), so reads will only be usefully buffered +// if reads are serial and generally contiguous (just like a plain ioctx.Reader). Concurrent reads +// will be "passed through" to the underlying ReaderAt, just without buffering. +func NewReaderAtSize(r ioctx.ReaderAt, size int) ioctx.ReaderAt { + return &readerAt{ + r: r, + seeker: NewReadSeekerSize(&readerAtSeeker{r: r}, size), + } +} + +// ReadAt implements ioctx.ReaderAt. +func (r *readerAt) ReadAt(ctx context.Context, dst []byte, off int64) (int, error) { + acquired := r.tryAcquireSeeker() + if !acquired { + return r.r.ReadAt(ctx, dst, off) + } + defer r.releaseSeeker() + if _, err := r.seeker.Seek(ctx, off, io.SeekStart); err != nil { + return 0, errors.E(err, "seeking for ReadAt") + } + return r.seeker.Read(ctx, dst) +} + +func (r *readerAt) tryAcquireSeeker() bool { + r.mu.Lock() + defer r.mu.Unlock() + if r.seekerInUse { + return false + } + r.seekerInUse = true + return true +} + +func (r *readerAt) releaseSeeker() { + r.mu.Lock() + defer r.mu.Unlock() + if !r.seekerInUse { + panic("release of unacquired seeker") + } + r.seekerInUse = false +} + +// readerAtSeeker is a simple ioctx.ReadSeeker that only supports seeking by io.SeekCurrent, +// which is all readSeeker requires. 
+type readerAtSeeker struct { + r ioctx.ReaderAt + pos int64 +} + +func (r *readerAtSeeker) Read(ctx context.Context, p []byte) (int, error) { + n, err := r.r.ReadAt(ctx, p, r.pos) + r.pos += int64(n) + return n, err +} + +func (r *readerAtSeeker) Seek(ctx context.Context, request int64, whence int) (int64, error) { + if whence == io.SeekCurrent { + r.pos += request + return r.pos, nil + } + // Pretend the end position is zero. readSeeker requests this at initialization but we + // won't use it after that. + r.pos = 0 + return r.pos, nil +} diff --git a/morebufio/readerat_test.go b/morebufio/readerat_test.go new file mode 100644 index 00000000..fe80f1b3 --- /dev/null +++ b/morebufio/readerat_test.go @@ -0,0 +1,61 @@ +package morebufio + +import ( + "context" + "io" + "math/rand" + "strconv" + "strings" + "testing" + + "github.com/grailbio/base/ioctx" + "github.com/grailbio/base/traverse" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestReaderAt(t *testing.T) { + const want = "abcdefghijklmnopqrstuvwxyz" + ctx := context.Background() + rawReaderAt := ioctx.FromStdReaderAt(strings.NewReader(want)) + + t.Run("sequential", func(t *testing.T) { + bufAt := NewReaderAtSize(rawReaderAt, 5) + got := make([]byte, 0, len(want)) + for len(got) < len(want) { + n, err := bufAt.ReadAt(ctx, got[len(got):cap(got)], int64(len(got))) + got = got[:len(got)+n] + if err == io.EOF { + break + } + require.NoError(t, err) + } + assert.Equal(t, want, string(got)) + }) + + t.Run("random", func(t *testing.T) { + rnd := rand.New(rand.NewSource(1)) + for _, parallelism := range []int{1, len(want) / 2} { + t.Run(strconv.Itoa(parallelism), func(t *testing.T) { + bufAt := NewReaderAtSize(rawReaderAt, 5) + got := make([]byte, len(want)) + perm := rnd.Perm(len(want) / 2) + _ = traverse.T{Limit: parallelism}.Each(len(perm), func(permIdx int) error { + i := perm[permIdx] + start := i * 2 + limit := start + 2 + if limit > len(got) { + limit -= 1 + } + n, err 
:= bufAt.ReadAt(ctx, got[start:limit], int64(start))
+ assert.Equal(t, limit-start, n)
+ if limit < len(got) || err != io.EOF {
+ require.NoError(t, err)
+ }
+ return nil
+ })
+ assert.Equal(t, want, string(got))
+ })
+ }
+ })
+}
diff --git a/morebufio/readseeker.go b/morebufio/readseeker.go
new file mode 100644
index 00000000..84113b68
--- /dev/null
+++ b/morebufio/readseeker.go
@@ -0,0 +1,114 @@
+package morebufio
+
+import (
+ "context"
+ "io"
+
+ "github.com/grailbio/base/ioctx"
+)
+
+type readSeeker struct {
+ r ioctx.ReadSeeker
+ // buf is the buffer, resized as necessary after reading from r.
+ buf []byte
+ // off is the caller's current offset into the buffer. buf[off:] is unread.
+ off int
+
+ // filePos is the caller's current position in r's stream. This can be different from r's position,
+ // for example when there's unread data in buf. Equals -1 when uninitialized.
+ filePos int64
+ // fileEnd is the offset of the end of r, used for efficiently seeking within r. Equals -1 when
+ // uninitialized.
+ fileEnd int64
+}
+
+var _ ioctx.ReadSeeker = (*readSeeker)(nil)
+
+// minBufferSize equals bufio.minBufferSize.
+const minBufferSize = 16
+
+// NewReadSeekerSize returns a buffered ioctx.ReadSeeker whose buffer has at least the specified size.
+// If r is already a readSeeker with sufficient size, returns r.
+func NewReadSeekerSize(r ioctx.ReadSeeker, size int) *readSeeker {
+ if b, ok := r.(*readSeeker); ok && len(b.buf) >= size {
+ return b
+ }
+ if size < minBufferSize {
+ size = minBufferSize
+ }
+ return &readSeeker{r, make([]byte, 0, size), 0, -1, -1}
+}
+
+// Read implements ioctx.Reader. 
+func (b *readSeeker) Read(ctx context.Context, p []byte) (int, error) { + if len(p) == 0 { + return 0, nil + } + if err := b.initFilePos(ctx); err != nil { + return 0, err + } + var err error + if b.off == len(b.buf) { + b.buf = b.buf[:cap(b.buf)] + var n int + n, err = b.r.Read(ctx, b.buf) + b.buf, b.off = b.buf[:n], 0 + } + n := copy(p, b.buf[b.off:]) + b.off += n + if b.off < len(b.buf) && err == io.EOF { + // We've reached EOF from filling the buffer but the caller hasn't reached the end yet. + // Clear EOF for now; we'll find it again after the caller reaches the end of the buffer. + err = nil + } + b.filePos += int64(n) + return n, err +} + +// Seek implements ioctx.Seeker. +func (b *readSeeker) Seek(ctx context.Context, request int64, whence int) (int64, error) { + if err := b.initFilePos(ctx); err != nil { + return 0, err + } + var diff int64 + switch whence { + case io.SeekStart: + diff = request - b.filePos + case io.SeekCurrent: + diff = request + case io.SeekEnd: + diff = b.fileEnd + request - b.filePos + default: + panic(whence) + } + if -int64(b.off) <= diff && diff <= int64(len(b.buf)-b.off) { + // Seek within buffer without changing file position. + b.off += int(diff) + b.filePos += diff + return b.filePos, nil + } + // Discard the buffer and seek the underlying reader. + diff -= int64(len(b.buf) - b.off) + b.buf, b.off = b.buf[:0], 0 + var err error + b.filePos, err = b.r.Seek(ctx, diff, io.SeekCurrent) + return b.filePos, err +} + +// initFilePos idempotently initializes filePos and fileEnd. 
+func (b *readSeeker) initFilePos(ctx context.Context) error { + if b.filePos >= 0 && b.fileEnd >= 0 { + return nil + } + var err error + b.filePos, err = b.r.Seek(ctx, 0, io.SeekCurrent) + if err != nil { + return err + } + b.fileEnd, err = b.r.Seek(ctx, 0, io.SeekEnd) + if err != nil { + return err + } + _, err = b.r.Seek(ctx, b.filePos, io.SeekStart) + return err +} diff --git a/morebufio/readseeker_test.go b/morebufio/readseeker_test.go new file mode 100644 index 00000000..94564751 --- /dev/null +++ b/morebufio/readseeker_test.go @@ -0,0 +1,170 @@ +package morebufio + +import ( + "bytes" + "context" + "fmt" + "io" + "math/rand" + "strings" + "testing" + + "github.com/grailbio/base/ioctx" + "github.com/grailbio/testutil/assert" +) + +func TestReadSeeker(t *testing.T) { + ctx := context.Background() + const file = "0123456789" + t.Run("read_zero", func(t *testing.T) { + r := NewReadSeekerSize(ioctx.FromStdReadSeeker(bytes.NewReader([]byte(file))), 4) + var b []byte + n, err := r.Read(ctx, b) + assert.NoError(t, err) + assert.EQ(t, n, 0) + }) + t.Run("read", func(t *testing.T) { + r := NewReadSeekerSize(ioctx.FromStdReadSeeker(bytes.NewReader([]byte(file))), 4) + b := make([]byte, 4) + + n, err := r.Read(ctx, b) + assert.NoError(t, err) + assert.GE(t, n, 0) + assert.LE(t, n, len(b)) + assert.EQ(t, b[:n], []byte(file[:n])) + remaining := file[n:] + + n, err = r.Read(ctx, b) + assert.NoError(t, err) + assert.GE(t, n, 0) + assert.LE(t, n, len(b)) + assert.EQ(t, b[:n], []byte(remaining[:n])) + }) + t.Run("seek", func(t *testing.T) { + r := NewReadSeekerSize(ioctx.FromStdReadSeeker(bytes.NewReader([]byte(file))), 4) + b := make([]byte, 4) + + n, err := io.ReadFull(ioctx.ToStdReadSeeker(ctx, r), b) + assert.NoError(t, err) + assert.EQ(t, n, len(b)) + assert.EQ(t, b, []byte(file[:4])) + + n64, err := r.Seek(ctx, -2, io.SeekCurrent) + assert.NoError(t, err) + assert.EQ(t, int(n64), 2) + + n, err = io.ReadFull(ioctx.ToStdReadSeeker(ctx, r), b) + assert.NoError(t, err) + 
assert.EQ(t, n, len(b)) + assert.EQ(t, b, []byte(file[2:6])) + }) + t.Run("regression_early_eof", func(t *testing.T) { + // Regression test for an issue discovered during NewReaderAt development. + // We construct a read seeker that returns EOF after filling the internal buffer. + // In this case we get that behavior from the string reader's ReadAt method, adapted + // into a seeker. This exposed a bug where readSeeker returned the EOF from filling its + // internal buffer even if the client hadn't read that far yet. + rawRS := &readerAtSeeker{r: ioctx.FromStdReaderAt(strings.NewReader(file))} + r := NewReadSeekerSize(rawRS, len(file)+1) + b := make([]byte, 4) + + n, err := r.Read(ctx, b) + assert.NoError(t, err) + assert.EQ(t, n, len(b)) + assert.EQ(t, b, []byte(file[:4])) + }) +} + +func TestReadSeekerRandom(t *testing.T) { + const ( + fileSize = 10000 + testOps = 100000 + ) + ctx := context.Background() + rnd := rand.New(rand.NewSource(1)) + file := func() string { + b := make([]byte, fileSize) + _, _ = rnd.Read(b) + return string(b) + }() + for _, bufSize := range []int{1, 16, 1024, fileSize * 2} { + t.Run(fmt.Sprint(bufSize), func(t *testing.T) { + var ( + gold, test ioctx.ReadSeeker + pos int + ) + reinit := func() { + pos = 0 + gold = ioctx.FromStdReadSeeker(bytes.NewReader([]byte(file))) + testBase := bytes.NewReader([]byte(file)) + if rnd.Intn(2) == 1 { + // Exercise initializing in the middle of a file. 
+ pos = rnd.Intn(fileSize) + nGold, err := gold.Seek(ctx, int64(pos), io.SeekStart) + assert.NoError(t, err) + assert.EQ(t, nGold, int64(pos)) + nTest, err := testBase.Seek(int64(pos), io.SeekStart) + assert.NoError(t, err) + assert.EQ(t, nTest, int64(pos)) + } + test = NewReadSeekerSize(ioctx.FromStdReadSeeker(testBase), bufSize) + } + reinit() + ops := []func(){ + reinit, + func() { // read + n := len(file) - pos + if n > 0 { + n = rnd.Intn(n) + } + bGold := make([]byte, n) + bTest := make([]byte, len(bGold)) + + nGold, errGold := io.ReadFull(ioctx.ToStdReadSeeker(ctx, gold), bGold) + nTest, errTest := io.ReadFull(ioctx.ToStdReadSeeker(ctx, test), bTest) + pos += nGold + + assert.EQ(t, nTest, nGold) + assert.NoError(t, errGold) + assert.NoError(t, errTest) + }, + func() { // seek current + off := rnd.Intn(len(file)) - pos + + nGold, errGold := gold.Seek(ctx, int64(off), io.SeekCurrent) + nTest, errTest := test.Seek(ctx, int64(off), io.SeekCurrent) + pos = int(nGold) + + assert.EQ(t, nTest, nGold) + assert.NoError(t, errGold) + assert.NoError(t, errTest) + }, + func() { // seek start + off := rnd.Intn(len(file)) + + nGold, errGold := gold.Seek(ctx, int64(off), io.SeekStart) + nTest, errTest := test.Seek(ctx, int64(off), io.SeekStart) + pos = int(nGold) + + assert.EQ(t, nTest, nGold) + assert.NoError(t, errGold) + assert.NoError(t, errTest) + }, + func() { // seek end + off := -rnd.Intn(len(file)) + + nGold, errGold := gold.Seek(ctx, int64(off), io.SeekEnd) + nTest, errTest := test.Seek(ctx, int64(off), io.SeekEnd) + pos = int(nGold) + + assert.EQ(t, nTest, nGold) + assert.NoError(t, errGold) + assert.NoError(t, errTest) + }, + } + for i := 0; i < testOps; i++ { + ops[rnd.Intn(len(ops))]() + } + }) + } +} diff --git a/must/must.go b/must/must.go new file mode 100644 index 00000000..9824ddb2 --- /dev/null +++ b/must/must.go @@ -0,0 +1,92 @@ +// Copyright 2019 GRAIL, Inc. All rights reserved. 
+// Use of this source code is governed by the Apache 2.0
+// license that can be found in the LICENSE file.
+
+// Package must provides a handful of functions to express fatal
+// assertions in Go programs. It is meant to alleviate cumbersome
+// error handling and reporting when the only course of action is to
+// fail the program. Package must is intended to be used by top-level
+// binaries (i.e., in main packages); it should rarely be used
+// elsewhere.
+package must
+
+import (
+ "fmt"
+
+ "github.com/grailbio/base/log"
+)
+
+// Func is the function called to report an error and interrupt execution. Func
+// is typically set to a function that logs the message and halts execution,
+// e.g. by panicking. It should be set before any potential calls to functions
+// in the must package. Func is passed the call depth of the caller of the must
+// function, e.g. the caller of Nil. This can be used to annotate messages.
+//
+// The default implementation logs the message with
+// github.com/grailbio/base/log at the Error level and then panics.
+var Func func(int, ...interface{}) = func(depth int, v ...interface{}) {
+ s := fmt.Sprint(v...)
+ // Nothing to do if output fails.
+ _ = log.Output(depth+1, log.Error, s)
+ panic(s)
+}
+
+// Nil asserts that v is nil; v is typically a value of type error.
+// If v is not nil, Nil formats a message in the manner of fmt.Sprint
+// and calls must.Func. Nil also suffixes the message with the
+// fmt.Sprint-formatted value of v.
+func Nil(v interface{}, args ...interface{}) {
+ if v == nil {
+ return
+ }
+ if len(args) == 0 {
+ Func(2, v)
+ return
+ }
+ Func(2, fmt.Sprint(args...), ": ", v)
+}
+
+// Nilf asserts that v is nil; v is typically a value of type error.
+// If v is not nil, Nilf formats a message in the manner of
+// fmt.Sprintf and calls must.Func. Nilf also suffixes the message
+// with the fmt.Sprint-formatted value of v. 
+func Nilf(v interface{}, format string, args ...interface{}) { + if v == nil { + return + } + Func(2, fmt.Sprintf(format, args...), ": ", v) +} + +// True is a no-op if the value b is true. If it is false, True +// formats a message in the manner of fmt.Sprint and calls Func. +func True(b bool, v ...interface{}) { + if b { + return + } + if len(v) == 0 { + Func(2, "must: assertion failed") + return + } + Func(2, v...) +} + +// Truef is a no-op if the value b is true. If it is false, True +// formats a message in the manner of fmt.Sprintf and calls Func. +func Truef(x bool, format string, v ...interface{}) { + if x { + return + } + Func(2, fmt.Sprintf(format, v...)) +} + +// Never asserts that it is never called. If it is, it formats a message +// in the manner of fmt.Sprint and calls Func. +func Never(v ...interface{}) { + Func(2, v...) +} + +// Neverf asserts that it is never called. If it is, it formats a message +// in the manner of fmt.Sprintf and calls Func. +func Neverf(format string, v ...interface{}) { + Func(2, fmt.Sprintf(format, v...)) +} diff --git a/must/must_test.go b/must/must_test.go new file mode 100644 index 00000000..4212f13e --- /dev/null +++ b/must/must_test.go @@ -0,0 +1,61 @@ +// Copyright 2019 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. + +package must_test + +import ( + "errors" + "fmt" + "runtime" + "testing" + + "github.com/grailbio/base/must" +) + +// TestDepth verifies that the depth passed to Func correctly locates the +// caller of the must function. 
+func TestDepth(t *testing.T) { + _, thisFile, _, ok := runtime.Caller(0) + if !ok { + t.Fatal("could not determine current file") + } + must.Func = func(depth int, v ...interface{}) { + _, file, _, ok := runtime.Caller(depth) + if !ok { + t.Fatal("could not determine caller of Func") + } + if file != thisFile { + t.Errorf("caller at depth %d is '%s'; should be '%s'", depth, file, thisFile) + } + } + must.True(false) + must.Truef(false, "") + must.Nil(struct{}{}) + must.Nilf(struct{}{}, "") + must.Never() + must.Neverf("") +} + +func Example() { + must.Func = func(depth int, v ...interface{}) { + fmt.Print(v...) + fmt.Print("\n") + } + + must.Nil(errors.New("unexpected condition")) + must.Nil(nil) + must.Nil(errors.New("some error")) + must.Nil(errors.New("i/o error"), "reading file") + + must.True(false) + must.True(true, "something happened") + must.True(false, "a condition failed") + + // Output: + // unexpected condition + // some error + // reading file: i/o error + // must: assertion failed + // a condition failed +} diff --git a/pprof/pprof.go b/pprof/pprof.go index f027f45e..a653fc20 100644 --- a/pprof/pprof.go +++ b/pprof/pprof.go @@ -62,8 +62,8 @@ func newProfiling() *profiling { {&pr.threadName, "thread-create-profile", "", "filename prefix for thread create profiles"}, {&pr.blockName, "block-profile", "", "filename prefix for block profiles"}, {&pr.mutexName, "mutex-profile", "", "filename prefix for mutex profiles"}, - {&pr.mutexRate, "mutex-profile-rate", 1, "rate for runtime.SetMutexProfileFraction"}, - {&pr.blockRate, "block-profile-rate", 1, "rate for runtime. SetBlockProfileRate"}, + {&pr.mutexRate, "mutex-profile-rate", 200, "rate for runtime.SetMutexProfileFraction"}, + {&pr.blockRate, "block-profile-rate", 200, "rate for runtime.SetBlockProfileRate"}, {&pr.profileInterval, "profile-interval-s", 0.0, "If >0, output new profiles at this interval (seconds). 
If <=0, profiles are written only when Write() is called"}, } { fn := p.n @@ -135,10 +135,10 @@ func (p *profiling) Start() { if atomic.AddInt32(&p.started, 1) > 1 { return } - if len(p.blockName) > 0 { + if len(p.blockName) > 0 || len(p.httpAddr) > 0 { runtime.SetBlockProfileRate(p.blockRate) } - if len(p.mutexName) > 0 { + if len(p.mutexName) > 0 || len(p.httpAddr) > 0 { runtime.SetMutexProfileFraction(p.mutexRate) } if len(p.cpuName) > 0 { @@ -152,7 +152,7 @@ func (p *profiling) Start() { } }() } - if p.httpAddr != "" { + if len(p.httpAddr) > 0 { mux := http.NewServeMux() mux.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/html; charset=utf-8") diff --git a/psort/doc.go b/psort/doc.go new file mode 100644 index 00000000..2a084aa7 --- /dev/null +++ b/psort/doc.go @@ -0,0 +1,2 @@ +// Package psort includes functions for parallel sorting. +package psort diff --git a/psort/mergesort.go b/psort/mergesort.go new file mode 100644 index 00000000..11c100aa --- /dev/null +++ b/psort/mergesort.go @@ -0,0 +1,136 @@ +package psort + +import ( + "reflect" + "sort" + "sync" + + "github.com/grailbio/base/traverse" +) + +const ( + serialThreshold = 128 +) + +// Slice sorts the given slice according to the ordering induced by the provided +// less function. Parallel computation will be attempted, up to the limit imposed by +// parallelism. This function can be much faster than the standard library's sort.Slice() +// when sorting large slices on multicore machines. +func Slice(slice interface{}, less func(i, j int) bool, parallelism int) { + if parallelism < 1 { + panic("parallelism must be at least 1") + } + if reflect.TypeOf(slice).Kind() != reflect.Slice { + panic("input interface was not of slice type") + } + rv := reflect.ValueOf(slice) + if rv.Len() < 2 { + return + } + // For clarity, we will sort a slice containing indices from the input slice. 
Then, + // we will set the elements of the input slice according to this permutation. This + // avoids difficult-to-understand reflection types and calls in most of the code. + perm := make([]int, rv.Len()) + for i := range perm { + perm[i] = i + } + scratch := make([]int, len(perm)) + mergeSort(perm, less, parallelism, scratch) + result := reflect.MakeSlice(rv.Type(), rv.Len(), rv.Len()) + _ = traverse.Limit(parallelism).Range(rv.Len(), func(start, end int) error { + for i := start; i < end; i++ { + result.Index(i).Set(rv.Index(perm[i])) + } + return nil + }) + _ = traverse.Limit(parallelism).Range(rv.Len(), func(start, end int) error { + reflect.Copy(rv.Slice(start, end), result.Slice(start, end)) + return nil + }) +} + +func mergeSort(perm []int, less func(i, j int) bool, parallelism int, scratch []int) { + if parallelism == 1 || len(perm) < serialThreshold { + sortSerial(perm, less) + return + } + + // Sort two halves of the slice in parallel, allocating half of our parallelism to + // each subroutine. 
+ left := perm[:len(perm)/2] + right := perm[len(perm)/2:] + var waitGroup sync.WaitGroup + waitGroup.Add(1) + go func() { + mergeSort(left, less, (parallelism+1)/2, scratch[:len(perm)/2]) + waitGroup.Done() + }() + mergeSort(right, less, parallelism/2, scratch[len(perm)/2:]) + waitGroup.Wait() + + merge(left, right, less, parallelism, scratch) + parallelCopy(perm, scratch, parallelism) +} + +func parallelCopy(dst, src []int, parallelism int) { + _ = traverse.Limit(parallelism).Range(len(dst), func(start, end int) error { + copy(dst[start:end], src[start:end]) + return nil + }) +} + +func sortSerial(perm []int, less func(i, j int) bool) { + sort.Slice(perm, func(i, j int) bool { + return less(perm[i], perm[j]) + }) +} + +func merge(perm1, perm2 []int, less func(i, j int) bool, parallelism int, out []int) { + if parallelism == 1 || len(perm1)+len(perm2) < serialThreshold { + mergeSerial(perm1, perm2, less, out) + return + } + + if len(perm1) < len(perm2) { + perm1, perm2 = perm2, perm1 + } + // Find the index in perm2 such that all elements to the left are smaller than + // the midpoint element of perm1. + r := len(perm1) / 2 + s := sort.Search(len(perm2), func(i int) bool { + return !less(perm2[i], perm1[r]) + }) + // Merge in parallel, allocating half of our parallelism to each subroutine. 
+ var waitGroup sync.WaitGroup + waitGroup.Add(1) + go func() { + merge(perm1[:r], perm2[:s], less, (parallelism+1)/2, out[:r+s]) + waitGroup.Done() + }() + merge(perm1[r:], perm2[s:], less, parallelism/2, out[r+s:]) + waitGroup.Wait() +} + +func mergeSerial(perm1, perm2 []int, less func(i, j int) bool, out []int) { + var idx1, idx2, idxOut int + for idx1 < len(perm1) && idx2 < len(perm2) { + if less(perm1[idx1], perm2[idx2]) { + out[idxOut] = perm1[idx1] + idx1++ + } else { + out[idxOut] = perm2[idx2] + idx2++ + } + idxOut++ + } + for idx1 < len(perm1) { + out[idxOut] = perm1[idx1] + idx1++ + idxOut++ + } + for idx2 < len(perm2) { + out[idxOut] = perm2[idx2] + idx2++ + idxOut++ + } +} diff --git a/psort/mergesort_test.go b/psort/mergesort_test.go new file mode 100644 index 00000000..2132697d --- /dev/null +++ b/psort/mergesort_test.go @@ -0,0 +1,173 @@ +package psort + +import ( + "fmt" + "math/rand" + "reflect" + "sort" + "testing" +) + +type TestInput int + +const ( + Random TestInput = iota + Ascending + Descending +) + +func TestSlice(t *testing.T) { + tests := []struct { + input TestInput + size int + parallelism int + reps int + }{ + { + input: Random, + size: 10000, + parallelism: 7, + reps: 100, + }, + { + input: Random, + size: 1000000, + parallelism: 6, + reps: 4, + }, + { + input: Ascending, + size: 10000, + parallelism: 9, + reps: 1, + }, + { + input: Descending, + size: 10000, + parallelism: 8, + reps: 1, + }, + } + + for _, test := range tests { + random := rand.New(rand.NewSource(0)) + for rep := 0; rep < test.reps; rep++ { + in := make([]int, test.size) + switch test.input { + case Random: + for i := range in { + in[i] = random.Intn(test.size) + } + case Ascending: + for i := range in { + in[i] = i + } + case Descending: + for i := range in { + in[i] = len(in) - i + } + } + expected := make([]int, len(in)) + copy(expected, in) + sort.Slice(expected, func(i, j int) bool { + return expected[i] < expected[j] + }) + Slice(in, func(i, j int) bool { + 
return in[i] < in[j] + }, test.parallelism) + if !reflect.DeepEqual(expected, in) { + t.Errorf("Wrong sort result: want %v\n, got %v\n", expected, in) + } + } + } +} + +func BenchmarkSlice(b *testing.B) { + tests := []struct { + size int + parallelism int //parallelism = 0 means use sort.Slice() sort + }{ + { + size: 100000000, + parallelism: 4096, + }, + { + size: 100000000, + parallelism: 2048, + }, + { + size: 100000000, + parallelism: 1024, + }, + { + size: 100000000, + parallelism: 512, + }, + { + size: 100000000, + parallelism: 256, + }, + { + size: 100000000, + parallelism: 128, + }, + { + size: 100000000, + parallelism: 64, + }, + { + size: 100000000, + parallelism: 32, + }, + { + size: 100000000, + parallelism: 16, + }, + { + size: 100000000, + parallelism: 8, + }, + { + size: 100000000, + parallelism: 4, + }, + { + size: 100000000, + parallelism: 2, + }, + { + size: 100000000, + parallelism: 1, + }, + { + size: 100000000, + parallelism: 0, + }, + } + + for _, test := range tests { + b.Run(fmt.Sprintf("size:%d-%d", test.size, test.parallelism), func(b *testing.B) { + data := make([]float64, test.size) + r := rand.New(rand.NewSource(0)) + dataCopy := make([]float64, len(data)) + for i := range data { + data[i] = r.Float64() + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + b.StopTimer() + copy(dataCopy, data) + b.StartTimer() + if test.parallelism == 0 { + sort.Slice(dataCopy, func(i, j int) bool { + return dataCopy[i] < dataCopy[j] + }) + } else { + Slice(dataCopy, func(i, j int) bool { + return dataCopy[i] < dataCopy[j] + }, test.parallelism) + } + } + }) + } +} diff --git a/recordio/README.md b/recordio/README.md index cd458772..c06bb906 100644 --- a/recordio/README.md +++ b/recordio/README.md @@ -86,11 +86,10 @@ Each chunk contains a 28 byte header. 
chunk := magic (8 bytes) CRC32 (4 bytes LE) - flag (4bytes LE) + flag (4 bytes LE) chunk payload size (4 bytes LE) totalChunks (4 bytes LE) chunk index (4 bytes LE) - flag (4 bytes LE) payload (bytes) - The 8-byte magic header tells whether the chunk is part of header, body, or a trailer. @@ -99,7 +98,7 @@ Each chunk contains a 28 byte header. MagicPacked, and MagicTrailer. -- The chunk payload size is (32768 - 32), unless it is for the final chunk of a +- The chunk payload size is (32768 - 28), unless it is for the final chunk of a block. For the final chunk, the "chunk payload size" stores the size of the block contents, and the chunk is filled with garbage to make it 32KiB at rest. diff --git a/recordio/deprecated/index_test.go b/recordio/deprecated/index_test.go index d95dfcca..aa337087 100644 --- a/recordio/deprecated/index_test.go +++ b/recordio/deprecated/index_test.go @@ -269,6 +269,7 @@ func TestIndexErrors(t *testing.T) { } pwr := deprecated.NewLegacyPackedWriter(buf, pwropts) _, err = pwr.Marshal(&TestPB{"x"}) + expect.NoError(t, err, "Marshall packed writer") err = pwr.Flush() expect.HasSubstr(t, err, "packed record index oops") @@ -279,6 +280,7 @@ func TestIndexErrors(t *testing.T) { } pwr = deprecated.NewLegacyPackedWriter(buf, pwropts) _, err = pwr.Marshal(&TestPB{"x"}) + expect.NoError(t, err, "Marshall packed writer first try") _, err = pwr.Marshal(&TestPB{"y"}) expect.HasSubstr(t, err, "packed item index oops") @@ -288,6 +290,7 @@ func TestIndexErrors(t *testing.T) { } pwr = deprecated.NewLegacyPackedWriter(buf, pwropts) _, err = pwr.Marshal(&TestPB{"z"}) + expect.NoError(t, err, "Marshall packed writer") err = pwr.Flush() expect.HasSubstr(t, err, "packed flush oops") } diff --git a/recordio/deprecated/packed_test.go b/recordio/deprecated/packed_test.go index 9de0b6fa..e479a02a 100644 --- a/recordio/deprecated/packed_test.go +++ b/recordio/deprecated/packed_test.go @@ -25,7 +25,7 @@ import ( func TestPackedWriteRead(t *testing.T) { sl := func(s 
...string) [][]byte { - bs := [][]byte{} + var bs [][]byte for _, t := range s { bs = append(bs, []byte(t)) } @@ -265,6 +265,7 @@ func TestPackedErrors(t *testing.T) { wropts := deprecated.LegacyPackedWriterOpts{MaxItems: 1, MaxBytes: 10} wr := deprecated.NewLegacyPackedWriter(ew, wropts) _, err := wr.Write([]byte("hello")) + expect.NoError(t, err, "first write succeeds") _, err = wr.Write([]byte("hello")) expect.HasSubstr(t, err, msg) } @@ -318,7 +319,7 @@ func TestPackedErrors(t *testing.T) { func readAll(t *testing.T, buf *bytes.Buffer, opts deprecated.LegacyPackedScannerOpts) []string { sc := deprecated.NewLegacyPackedScanner(buf, opts) - read := []string{} + var read []string for sc.Scan() { read = append(read, string(sc.Bytes())) } diff --git a/recordio/deprecated/packer_test.go b/recordio/deprecated/packer_test.go index 78c41525..09ebb163 100644 --- a/recordio/deprecated/packer_test.go +++ b/recordio/deprecated/packer_test.go @@ -91,6 +91,9 @@ func TestPacker(t *testing.T) { expectStats(t, 1, wr, 0, 0) // Pack is not idempotent. hdr, _, _, err := wr.Pack() + if err != nil { + t.Fatal(err) + } if got, want := len(hdr), 0; got != want { t.Errorf("got %v, want %v", got, want) } diff --git a/recordio/deprecated/range_test.go b/recordio/deprecated/range_test.go index 27b75ba8..2a10c315 100644 --- a/recordio/deprecated/range_test.go +++ b/recordio/deprecated/range_test.go @@ -10,12 +10,12 @@ import ( "strings" "testing" - "github.com/grailbio/testutil" "github.com/grailbio/base/recordio/deprecated" + "github.com/grailbio/testutil" ) func TestBounded(t *testing.T) { - fail := func(err error) { + failIf := func(err error) { if err != nil { t.Fatalf("%v: %v", testutil.Caller(1), err) } @@ -39,7 +39,7 @@ func TestBounded(t *testing.T) { } } - raw := make([]byte, 255, 255) + raw := make([]byte, 255) for i := 0; i < 255; i++ { raw[i] = byte(i) } @@ -47,12 +47,12 @@ func TestBounded(t *testing.T) { rs := bytes.NewReader(raw) // Negative offset will fail. 
- br, err := deprecated.NewRangeReader(rs, -1, 5) + _, err := deprecated.NewRangeReader(rs, -1, 5) expectError(err, "negative position") // Seeking past the end of the file is the same as seeking to the end. - br, err = deprecated.NewRangeReader(rs, 512, 5) - fail(err) + br, err := deprecated.NewRangeReader(rs, 512, 5) + failIf(err) buf := make([]byte, 3) n, err := br.Read(buf) @@ -60,10 +60,10 @@ func TestBounded(t *testing.T) { expectLen(n, 0) br, err = deprecated.NewRangeReader(rs, 48, 5) - fail(err) + failIf(err) n, err = br.Read(buf) - fail(err) + failIf(err) expectLen(n, 3) expectBuf(buf, '0', '1', '2') @@ -74,7 +74,7 @@ func TestBounded(t *testing.T) { expectBuf(buf[:n], '3', '4') p, err := br.Seek(0, io.SeekStart) - fail(err) + failIf(err) if got, want := p, int64(0); got != want { t.Errorf("got %v, want %v", got, want) } @@ -84,6 +84,7 @@ func TestBounded(t *testing.T) { expectBuf(buf[:n], '0', '1', '2', '3', '4') _, err = br.Seek(2, io.SeekStart) + failIf(err) n, err = br.Read(buf) expectError(err, "EOF") if got, want := buf[:n], []byte{'2', '3', '4'}; !bytes.Equal(got, want) { @@ -91,33 +92,34 @@ func TestBounded(t *testing.T) { } _, err = br.Seek(-2, io.SeekEnd) + failIf(err) n, err = br.Read(buf) expectError(err, "EOF") expectBuf(buf[:n], '3', '4') _, err = br.Seek(1, io.SeekStart) - fail(err) + failIf(err) _, err = br.Seek(1, io.SeekCurrent) - fail(err) + failIf(err) n, err = br.Read(buf[:2]) - fail(err) + failIf(err) expectBuf(buf[:n], '2', '3') _, err = br.Seek(100, io.SeekCurrent) - fail(err) + failIf(err) n, err = br.Read(buf[:2]) expectError(err, "EOF") if got, want := n, 0; got != want { t.Errorf("got %v, want %v", got, want) } _, err = br.Seek(-1, io.SeekEnd) - fail(err) + failIf(err) n, err = br.Read(buf[:1]) expectError(err, "EOF") expectBuf(buf[:n], '4') // Seeking past the end of the stream is the same as seeking to the end. 
_, err = br.Seek(100, io.SeekEnd) - fail(err) + failIf(err) n, err = br.Read(buf[:1]) expectError(err, "EOF") expectLen(n, 0) diff --git a/recordio/deprecated/recordio.go b/recordio/deprecated/recordio.go index 254bbf64..8e451856 100644 --- a/recordio/deprecated/recordio.go +++ b/recordio/deprecated/recordio.go @@ -9,9 +9,8 @@ import ( "fmt" "hash/crc32" "io" - "sync" - "github.com/grailbio/base/errorreporter" + "github.com/grailbio/base/errors" "github.com/grailbio/base/recordio/internal" ) @@ -79,15 +78,14 @@ type LegacyWriter interface { } const ( - sizeOffset = internal.NumMagicBytes - crcOffset = internal.NumMagicBytes + 8 - dataOffset = internal.NumMagicBytes + 8 + crc32.Size + sizeOffset = internal.NumMagicBytes + crcOffset = internal.NumMagicBytes + 8 + dataOffset = internal.NumMagicBytes + 8 + crc32.Size // teaderSize is the size in bytes of the recordio header. headerSize = dataOffset ) type byteWriter struct { - sync.Mutex wr io.Writer magic internal.MagicBytes hdr [headerSize]byte @@ -214,7 +212,7 @@ func isErr(err error) bool { type LegacyScannerImpl struct { rd io.Reader record []byte - err errorreporter.T + err errors.Once opts LegacyScannerOpts hdr [headerSize]byte } @@ -230,7 +228,7 @@ func NewLegacyScanner(rd io.Reader, opts LegacyScannerOpts) LegacyScanner { // Reset implements Scanner.Reset. func (s *LegacyScannerImpl) Reset(rd io.Reader) { s.rd = rd - s.err = errorreporter.T{} + s.err = errors.Once{} } // Unmarshal implements Scanner.Unmarshal. diff --git a/recordio/estimate.go b/recordio/estimate.go new file mode 100644 index 00000000..180452fc --- /dev/null +++ b/recordio/estimate.go @@ -0,0 +1,41 @@ +package recordio + +import ( + "encoding/binary" + + "github.com/grailbio/base/recordio/internal" +) + +// RequiredSpaceUpperBound returns an upper bound on the space required to +// store n items of itemSizes for a specified record size. 
+func RequiredSpaceUpperBound(itemSizes []int64, recordSize int64) int64 { + + // Max number of chunks required per record. + // reqChunksPerRecord is Ceil(recordSize / internal.MaxChunkPayloadSize) + reqChunksPerRecord := recordSize / internal.MaxChunkPayloadSize + if (recordSize % internal.MaxChunkPayloadSize) != 0 { + reqChunksPerRecord++ + } + + // Max payload = UpperBound(header) + payload. + // 1 varint for # items, n for the size of each of n items. + // Using binary.MaxVarintLen64 since we want an upper bound. + hdrSizeUBound := (len(itemSizes) + 1) * binary.MaxVarintLen64 + maxPayload := int64(hdrSizeUBound) + for _, s := range itemSizes { + maxPayload += s + } + + // Max number of records required for payload. + // reqRecordsForPayload is Ceil(maxPayload / recordSize) + reqRecordsForPayload := maxPayload / recordSize + if (maxPayload % recordSize) != 0 { + reqRecordsForPayload++ + } + + // Max number of chunks required = chunks for payload + 2 chunks for header and trailer. + reqChunksForPayload := (reqChunksPerRecord * reqRecordsForPayload) + int64(2) + + // Upper bound on the space required. 
+ return reqChunksForPayload * internal.ChunkSize +} diff --git a/recordio/estimate_test.go b/recordio/estimate_test.go new file mode 100644 index 00000000..e55e9516 --- /dev/null +++ b/recordio/estimate_test.go @@ -0,0 +1,57 @@ +package recordio_test + +import ( + "testing" + + "github.com/grailbio/base/recordio" + "github.com/grailbio/testutil/assert" +) + +const ( + KiB = int64(1024) + MiB = int64(1024 * 1024) +) + +func TestRequiredSpaceUpperBound(t *testing.T) { + for _, test := range []struct { + itemSizes []int64 + recordSize int64 + expectedReqSpace int64 + }{ + // internal.ChunkSize == 32KiB + + { // recordSize < chunkSize + []int64{1 * KiB, 1 * KiB, 1 * KiB}, + 1 * KiB, + 6 * 32 * KiB, + }, + { // recordSize < chunkSize + []int64{1 * KiB, 1 * KiB, 1 * KiB}, + 2 * KiB, + 4 * 32 * KiB, + }, + { // chunkSize < recordSize + []int64{5 * MiB, 2 * KiB, 12 * MiB, 3 * MiB}, + 4 * MiB, + 776 * 32 * KiB, + }, + { // recordSize == chunkSize + []int64{35 * KiB, 9 * KiB, 1 * MiB, 20 * KiB}, + 32 * KiB, + 72 * 32 * KiB, + }, + { // sizes where no-padding of chunks is required + []int64{32736, 32736, 32732 + 32740 + 32736}, + 32*KiB - 32, + 8 * 32 * KiB, + }, + } { + req := recordio.RequiredSpaceUpperBound(test.itemSizes, test.recordSize) + sum := int64(0) + for _, v := range test.itemSizes { + sum += v + } + assert.GT(t, req, sum) + assert.EQ(t, req, test.expectedReqSpace) + } +} diff --git a/recordio/header.go b/recordio/header.go index f3bf96ea..8d9c82e5 100644 --- a/recordio/header.go +++ b/recordio/header.go @@ -10,7 +10,7 @@ import ( "encoding/binary" "fmt" - "github.com/grailbio/base/errorreporter" + "github.com/grailbio/base/errors" ) const ( @@ -138,7 +138,7 @@ func (e *headerEncoder) putKeyValue(key string, v interface{}) error { // Helper for decoding header data produced by headerEncoder. Thread // compatible. 
type headerDecoder struct { - err errorreporter.T + err errors.Once data []byte } diff --git a/recordio/internal/chunk.go b/recordio/internal/chunk.go index b41aba62..547de3ae 100644 --- a/recordio/internal/chunk.go +++ b/recordio/internal/chunk.go @@ -12,16 +12,20 @@ import ( "io" "math" - "github.com/grailbio/base/errorreporter" - "github.com/pkg/errors" + "github.com/grailbio/base/errors" ) type chunkFlag uint32 const ( - chunkHeaderSize = 28 - chunkSize = 32 << 10 - maxChunkPayloadSize = chunkSize - chunkHeaderSize + // ChunkHeaderSize is the fixed header size for a chunk. + ChunkHeaderSize = 28 + + // ChunkSize is the fixed size of a chunk, including its header. + ChunkSize = 32 << 10 + + // MaxChunkPayloadSize is the maximum size of payload a chunk can carry. + MaxChunkPayloadSize = ChunkSize - ChunkHeaderSize ) // Chunk layout: @@ -36,7 +40,7 @@ const ( // padding [32768 - 28 - size] // // magic: one of MagicHeader, MagicPacked, MagicTrailer. -// size: size of the chunk payload (data). size <= (32<<10) - 24 +// size: size of the chunk payload (data). size <= (32<<10) - 28 // padding: garbage data added to make the chunk size exactly 32768B. // // total: the total # of chunks in the blocks. @@ -46,24 +50,7 @@ const ( // // crc: IEEE CRC32 of of the succeeding fields: size, index, flag, and data. // Note: padding is not included in the CRC. 
-type chunkHeader [chunkHeaderSize]byte - -func (h *chunkHeader) Magic() (magic MagicBytes) { - copy(magic[:], h[:]) - return -} - -func (h *chunkHeader) Checksum() uint32 { - return binary.LittleEndian.Uint32(h[8:]) -} - -func (h *chunkHeader) Flag() chunkFlag { - return chunkFlag(binary.LittleEndian.Uint32(h[12:])) -} - -func (h *chunkHeader) Size() uint32 { - return binary.LittleEndian.Uint32(h[16:]) -} +type chunkHeader [ChunkHeaderSize]byte func (h *chunkHeader) TotalChunks() int { return int(binary.LittleEndian.Uint32(h[20:])) @@ -73,7 +60,7 @@ func (h *chunkHeader) Index() int { return int(binary.LittleEndian.Uint32(h[24:])) } -var chunkPadding [maxChunkPayloadSize]byte +var chunkPadding [MaxChunkPayloadSize]byte // Seek to "off". Returns nil iff the seek ptr moves to "off". func Seek(r io.ReadSeeker, off int64) error { @@ -82,7 +69,7 @@ func Seek(r io.ReadSeeker, off int64) error { return err } if n != off { - return errors.Errorf("seek: got %v, expect %v", n, off) + return fmt.Errorf("seek: got %v, expect %v", n, off) } return nil } @@ -99,7 +86,7 @@ func init() { type ChunkWriter struct { nWritten int64 w io.Writer - err *errorreporter.T + err *errors.Once crc hash.Hash32 } @@ -115,17 +102,17 @@ func (w *ChunkWriter) Write(magic MagicBytes, payload []byte) { copy(header[:], magic[:]) chunkIndex := 0 - totalChunks := (len(payload)-1)/maxChunkPayloadSize + 1 + totalChunks := (len(payload)-1)/MaxChunkPayloadSize + 1 for { var chunkPayload []byte lastChunk := false - if len(payload) <= maxChunkPayloadSize { + if len(payload) <= MaxChunkPayloadSize { lastChunk = true chunkPayload = payload payload = nil } else { - chunkPayload = payload[:maxChunkPayloadSize] - payload = payload[maxChunkPayloadSize:] + chunkPayload = payload[:MaxChunkPayloadSize] + payload = payload[MaxChunkPayloadSize:] } binary.LittleEndian.PutUint32(header[12:], uint32(0)) binary.LittleEndian.PutUint32(header[16:], uint32(len(chunkPayload))) @@ -141,7 +128,7 @@ func (w *ChunkWriter) 
Write(magic MagicBytes, payload []byte) { w.doWrite(chunkPayload) chunkIndex++ if lastChunk { - paddingSize := maxChunkPayloadSize - len(chunkPayload) + paddingSize := MaxChunkPayloadSize - len(chunkPayload) if paddingSize > 0 { w.doWrite(chunkPadding[:paddingSize]) } @@ -161,13 +148,13 @@ func (w *ChunkWriter) doWrite(data []byte) { } w.nWritten += int64(len(data)) if n != len(data) { - w.err.Set(errors.Errorf("Failed to write %d bytes (got %d)", len(data), n)) + w.err.Set(fmt.Errorf("Failed to write %d bytes (got %d)", len(data), n)) } } // NewChunkWriter creates a new chunk writer. Any error is reported through // "err". -func NewChunkWriter(w io.Writer, err *errorreporter.T) *ChunkWriter { +func NewChunkWriter(w io.Writer, err *errors.Once) *ChunkWriter { return &ChunkWriter{w: w, err: err, crc: crc32.New(IEEECRC)} } @@ -175,7 +162,7 @@ func NewChunkWriter(w io.Writer, err *errorreporter.T) *ChunkWriter { // block. Thread compatible. type ChunkScanner struct { r io.ReadSeeker - err *errorreporter.T + err *errors.Once fileSize int64 off int64 @@ -190,7 +177,7 @@ type ChunkScanner struct { } // NewChunkScanner creates a new chunk scanner. Any error is reported through "err". -func NewChunkScanner(r io.ReadSeeker, err *errorreporter.T) *ChunkScanner { +func NewChunkScanner(r io.ReadSeeker, err *errors.Once) *ChunkScanner { rx := &ChunkScanner{r: r, err: err} // Compute the file size. var e error @@ -212,11 +199,11 @@ func (r *ChunkScanner) LimitShard(start, limit, nshard int) { // Compute the offset and limit for shard-of-nshard. // Invariant: limit is the offset at or after which a new block // should not be scanned. 
- numChunks := (r.fileSize - r.off) / chunkSize + numChunks := (r.fileSize - r.off) / ChunkSize chunksPerShard := float64(numChunks) / float64(nshard) startOff := r.off - r.off = startOff + int64(float64(start)*chunksPerShard)*chunkSize - r.limit = startOff + int64(float64(limit)*chunksPerShard)*chunkSize + r.off = startOff + int64(float64(start)*chunksPerShard)*ChunkSize + r.limit = startOff + int64(float64(limit)*chunksPerShard)*ChunkSize if start == 0 { // No more work to do. We assume LimitShard is called on a block boundary. return @@ -244,7 +231,7 @@ func (r *ChunkScanner) LimitShard(start, limit, nshard int) { r.err.Set(errors.New("invalid chunk header")) return } - r.off += chunkSize * int64(total-index) + r.off += ChunkSize * int64(total-index) r.err.Set(Seek(r.r, r.off)) } @@ -256,7 +243,7 @@ func (r *ChunkScanner) Tell() int64 { // Seek moves the read pointer so that next Scan() will move to the block at the // given file offset. Any error is reported in r.Err() -func (r *ChunkScanner) Seek(off int64) { +func (r *ChunkScanner) Seek(off int64) { // "go vet" complaint expected r.off = off r.err.Set(Seek(r.r, off)) } @@ -284,17 +271,17 @@ func (r *ChunkScanner) Scan() bool { totalChunks = nchunks } if chunkMagic != r.magic { - r.err.Set(errors.Errorf("Magic number changed in the middle of a chunk sequence, got %v, expect %v", + r.err.Set(fmt.Errorf("Magic number changed in the middle of a chunk sequence, got %v, expect %v", r.magic, chunkMagic)) return false } if len(r.chunks) != index { - r.err.Set(errors.Errorf("Chunk index mismatch, got %v, expect %v for magic %x", + r.err.Set(fmt.Errorf("Chunk index mismatch, got %v, expect %v for magic %x", index, len(r.chunks), r.magic)) return false } if nchunks != totalChunks { - r.err.Set(errors.Errorf("Chunk nchunk mismatch, got %v, expect %v for magic %x", + r.err.Set(fmt.Errorf("Chunk nchunk mismatch, got %v, expect %v for magic %x", nchunks, totalChunks, r.magic)) return false } @@ -319,7 +306,7 @@ func (r 
*ChunkScanner) readChunkHeader(header *chunkHeader) bool { r.err.Set(err) return false } - r.off, err = r.r.Seek(-chunkHeaderSize, io.SeekCurrent) + r.off, err = r.r.Seek(-ChunkHeaderSize, io.SeekCurrent) r.err.Set(err) return true } @@ -334,7 +321,7 @@ func (r *ChunkScanner) readChunk() (MagicBytes, chunkFlag, int, int, []byte) { r.err.Set(err) return MagicInvalid, chunkFlag(0), 0, 0, nil } - header := chunkBuf[:chunkHeaderSize] + header := chunkBuf[:ChunkHeaderSize] var magic MagicBytes copy(magic[:], header[:]) @@ -343,15 +330,15 @@ func (r *ChunkScanner) readChunk() (MagicBytes, chunkFlag, int, int, []byte) { size := binary.LittleEndian.Uint32(header[16:]) totalChunks := int(binary.LittleEndian.Uint32(header[20:])) index := int(binary.LittleEndian.Uint32(header[24:])) - if size > maxChunkPayloadSize { - r.err.Set(errors.Errorf("Invalid chunk size %d", size)) + if size > MaxChunkPayloadSize { + r.err.Set(fmt.Errorf("Invalid chunk size %d", size)) return MagicInvalid, chunkFlag(0), 0, 0, nil } - chunkPayload := chunkBuf[chunkHeaderSize : chunkHeaderSize+size] - actualCsum := crc32.Checksum(chunkBuf[12:chunkHeaderSize+size], IEEECRC) + chunkPayload := chunkBuf[ChunkHeaderSize : ChunkHeaderSize+size] + actualCsum := crc32.Checksum(chunkBuf[12:ChunkHeaderSize+size], IEEECRC) if expectedCsum != actualCsum { - r.err.Set(errors.Errorf("Chunk checksum mismatch, expect %d, got %d", + r.err.Set(fmt.Errorf("Chunk checksum mismatch, expect %d, got %d", actualCsum, expectedCsum)) } return magic, flag, totalChunks, index, chunkPayload @@ -377,11 +364,11 @@ func (r *ChunkScanner) resetChunks() { func (r *ChunkScanner) allocChunk() []byte { for len(r.pool) <= r.unused { - r.pool = append(r.pool, make([]byte, chunkSize)) + r.pool = append(r.pool, make([]byte, ChunkSize)) } b := r.pool[r.unused] r.unused++ - if len(b) != chunkSize { + if len(b) != ChunkSize { panic(r) } return b @@ -392,14 +379,14 @@ func (r *ChunkScanner) allocChunk() []byte { // the user must call Seek() 
explicitly. func (r *ChunkScanner) ReadLastBlock() (MagicBytes, [][]byte) { var err error - r.off, err = r.r.Seek(-chunkSize, io.SeekEnd) + r.off, err = r.r.Seek(-ChunkSize, io.SeekEnd) if err != nil { r.err.Set(err) return MagicInvalid, nil } magic, _, totalChunks, index, payload := r.readChunk() if magic != MagicTrailer { - r.err.Set(errors.Errorf("Missing magic trailer; found %v", magic)) + r.err.Set(fmt.Errorf("Missing magic trailer; found %v", magic)) return MagicInvalid, nil } if index == 0 && totalChunks == 1 { @@ -407,13 +394,13 @@ func (r *ChunkScanner) ReadLastBlock() (MagicBytes, [][]byte) { return magic, [][]byte{payload} } // Seek to the beginning of the block. - r.off, err = r.r.Seek(-int64(index+1)*chunkSize, io.SeekEnd) + r.off, err = r.r.Seek(-int64(index+1)*ChunkSize, io.SeekEnd) if err != nil { r.err.Set(err) return MagicInvalid, nil } if !r.Scan() { - r.err.Set(errors.Errorf("Failed to read trailer")) + r.err.Set(fmt.Errorf("Failed to read trailer")) return MagicInvalid, nil } return r.magic, r.chunks diff --git a/recordio/legacyscanner.go b/recordio/legacyscanner.go index a521fa73..735e45d9 100644 --- a/recordio/legacyscanner.go +++ b/recordio/legacyscanner.go @@ -5,18 +5,18 @@ package recordio import ( + "fmt" "io" - "github.com/grailbio/base/errorreporter" + "github.com/grailbio/base/errors" "github.com/grailbio/base/recordio/deprecated" "github.com/grailbio/base/recordio/internal" - "github.com/pkg/errors" ) // legacyScanner is a ScannerV2 implementation that reads legacy recordio files, // either packed or unpacked. type legacyScannerAdapter struct { - err errorreporter.T + err errors.Once in io.ReadSeeker sc *deprecated.LegacyScannerImpl opts ScannerOpts @@ -67,7 +67,7 @@ func (s *legacyScannerAdapter) seekRaw(off int64) bool { func (s *legacyScannerAdapter) Seek(loc ItemLocation) { // TODO(saito) Avoid seeking the file if loc.Block points to the current block. 
if s.err.Err() == io.EOF { - s.err = errorreporter.T{} + s.err = errors.Once{} } if !s.seekRaw(int64(loc.Block)) { return @@ -76,7 +76,7 @@ func (s *legacyScannerAdapter) Seek(loc ItemLocation) { return } if loc.Item >= len(s.buffered) { - s.err.Set(errors.Errorf("Invalid location %+v, block has only %d items", loc, len(s.buffered))) + s.err.Set(fmt.Errorf("Invalid location %+v, block has only %d items", loc, len(s.buffered))) } s.nextItem = loc.Item } @@ -112,7 +112,7 @@ func (s *legacyScannerAdapter) scanNextBlock() bool { s.nextItem = 0 return true } - s.err.Set(errors.Errorf("recordio: invalid magic number: %v", magic)) + s.err.Set(fmt.Errorf("recordio: invalid magic number: %v", magic)) return false } diff --git a/recordio/recordio.go b/recordio/recordio.go index 14c9ac60..67400c7b 100644 --- a/recordio/recordio.go +++ b/recordio/recordio.go @@ -16,8 +16,6 @@ type TransformFunc func(scratch []byte, in [][]byte) (out []byte, err error) type FormatVersion int const ( - // InvalidFormat is never used. - InvalidFormat FormatVersion = 0 // V1 is pre 2018-02 format V1 FormatVersion = 1 // V2 is post 2018-02 format @@ -34,9 +32,6 @@ var MaxReadRecordSize = internal.MaxReadRecordSize // value. Else, it should allocate a new []byte and return it. type MarshalFunc func(scratch []byte, v interface{}) ([]byte, error) -// UnmarshalFunc is called to deserialize data. -type UnmarshalFunc func(data []byte, v interface{}) error - // MagicPacked is the chunk header for legacy and v2 data chunks. Not for // general use. 
var MagicPacked = internal.MagicPacked diff --git a/recordio/recordioutil/compress.go b/recordio/recordioutil/compress.go index 93dba791..bef0b051 100644 --- a/recordio/recordioutil/compress.go +++ b/recordio/recordioutil/compress.go @@ -35,14 +35,6 @@ func NewFlateTransform(level int) *FlateTransform { return &FlateTransform{level: level} } -// SetPassthrough sets the compressor/decommpressor into passthrough -// mode whereby they are essentially disabled and do not transform the -// data passed to them. SetPassThrough is generally used when -// reading/writing metadata. -func (f *FlateTransform) SetPassthrough(v bool) { - f.passthrough = v -} - // CompressTransform is intended for use Recordio.PackedWriterOpts.Transform. func (f *FlateTransform) CompressTransform(bufs [][]byte) ([]byte, error) { if f.passthrough { diff --git a/recordio/recordioutil/flags.go b/recordio/recordioutil/flags.go index 215992d8..4e863d92 100644 --- a/recordio/recordioutil/flags.go +++ b/recordio/recordioutil/flags.go @@ -55,7 +55,6 @@ func (f *CompressionLevelFlag) String() string { return "huffman-only" } panic(fmt.Sprintf("unrecognised compression constant: %v", f.Level)) - return "unknown" } // WriterFlags represents the flags required to configure a recordioutil.Writer. diff --git a/recordio/recordioutil/rioutil.go b/recordio/recordioutil/rioutil.go index 0543b92a..d127ac38 100644 --- a/recordio/recordioutil/rioutil.go +++ b/recordio/recordioutil/rioutil.go @@ -7,9 +7,9 @@ package recordioutil import ( "io" - "github.com/klauspost/compress/flate" "github.com/grailbio/base/recordio" "github.com/grailbio/base/recordio/deprecated" + "github.com/klauspost/compress/flate" ) // WriterOpts represents the options accepted by NewWriter. 
@@ -40,7 +40,6 @@ type writer struct { deprecated.LegacyPackedWriter compressor *FlateTransform opts WriterOpts - flushed func() error } // NewWriter returns a recordio.LegacyPackedWriter that can optionally compress diff --git a/recordio/recordioutil/v2_test.go b/recordio/recordioutil/v2_test.go index e47e097d..f0512258 100644 --- a/recordio/recordioutil/v2_test.go +++ b/recordio/recordioutil/v2_test.go @@ -10,8 +10,8 @@ import ( "github.com/grailbio/base/fileio" "github.com/grailbio/base/recordio" - "github.com/grailbio/base/recordio/recordioutil" "github.com/grailbio/base/recordio/deprecated" + "github.com/grailbio/base/recordio/recordioutil" "github.com/stretchr/testify/require" ) @@ -33,8 +33,10 @@ func readV1(t *testing.T, format fileio.FileType, buf *bytes.Buffer) (s []string func TestPacked(t *testing.T) { buf := &bytes.Buffer{} w := deprecated.NewLegacyPackedWriter(buf, deprecated.LegacyPackedWriterOpts{}) - w.Write([]byte("Foo")) - w.Write([]byte("Baz")) + _, err := w.Write([]byte("Foo")) + require.NoError(t, err) + _, err = w.Write([]byte("Baz")) + require.NoError(t, err) w.Flush() require.Equal(t, []string{"Foo", "Baz"}, readV1(t, fileio.GrailRIOPacked, buf)) } @@ -42,8 +44,10 @@ func TestPacked(t *testing.T) { func TestUnpacked(t *testing.T) { buf := &bytes.Buffer{} w := deprecated.NewLegacyWriter(buf, deprecated.LegacyWriterOpts{}) - w.Write([]byte("Foo")) - w.Write([]byte("Baz")) + _, err := w.Write([]byte("Foo")) + require.NoError(t, err) + _, err = w.Write([]byte("Baz")) + require.NoError(t, err) require.Equal(t, []string{"Foo", "Baz"}, readV1(t, fileio.GrailRIO, buf)) } @@ -51,8 +55,10 @@ func TestCompressed(t *testing.T) { buf := &bytes.Buffer{} w := deprecated.NewLegacyPackedWriter(buf, deprecated.LegacyPackedWriterOpts{ Transform: recordioutil.NewFlateTransform(-1).CompressTransform}) - w.Write([]byte("Foo")) - w.Write([]byte("Baz")) + _, err := w.Write([]byte("Foo")) + require.NoError(t, err) + _, err = w.Write([]byte("Baz")) + 
require.NoError(t, err) w.Flush() require.Equal(t, []string{"Foo", "Baz"}, readV1(t, fileio.GrailRIOPackedCompressed, buf)) } diff --git a/recordio/recordiozstd/recordiozstd.go b/recordio/recordiozstd/recordiozstd.go index 2f5da6fa..695d7c8c 100644 --- a/recordio/recordiozstd/recordiozstd.go +++ b/recordio/recordiozstd/recordiozstd.go @@ -4,5 +4,98 @@ package recordiozstd +import ( + "strconv" + "sync" + + "github.com/grailbio/base/compress/zstd" + "github.com/grailbio/base/recordio" + "github.com/grailbio/base/recordio/recordioiov" +) + // Name is the registered name of the zstd transformer. const Name = "zstd" + +func parseConfig(config string) (level int, err error) { + level = -1 + if config != "" { + level, err = strconv.Atoi(config) + } + return +} + +var tmpBufPool = sync.Pool{New: func() interface{} { return &[]byte{} }} + +// As of 2018-03, zstd.{Compress,Decompress} is much faster than +// io.{Reader,Writer}-based implementations, even though the former incurs extra +// copying. 
+// +// Reader/Writer impl: +// BenchmarkWrite-56 20 116151712 ns/op +// BenchmarkRead-56 30 45302918 ns/op +// +// Compress/Decompress impl: +// BenchmarkWrite-56 50 30034396 ns/op +// BenchmarkRead-56 50 23871334 ns/op +func flattenIov(in [][]byte) []byte { + totalBytes := recordioiov.TotalBytes(in) + + // storing only pointers in sync.Pool per https://github.com/golang/go/issues/16323 + slicePtr := tmpBufPool.Get().(*[]byte) + tmp := recordioiov.Slice(*slicePtr, totalBytes) + n := 0 + for _, inbuf := range in { + copy(tmp[n:], inbuf) + n += len(inbuf) + } + return tmp +} + +func zstdCompress(level int, scratch []byte, in [][]byte) ([]byte, error) { + if len(in) == 0 { + return zstd.CompressLevel(scratch, nil, level) + } + if len(in) == 1 { + return zstd.CompressLevel(scratch, in[0], level) + } + tmp := flattenIov(in) + d, err := zstd.CompressLevel(scratch, tmp, level) + tmpBufPool.Put(&tmp) + return d, err +} + +func zstdUncompress(scratch []byte, in [][]byte) ([]byte, error) { + if len(in) == 0 { + return zstd.Decompress(scratch, nil) + } + if len(in) == 1 { + return zstd.Decompress(scratch, in[0]) + } + tmp := flattenIov(in) + d, err := zstd.Decompress(scratch, tmp) + tmpBufPool.Put(&tmp) + return d, err +} + +var once = sync.Once{} + +// Init installs the zstd transformer in recordio. It can be called multiple +// times, but 2nd and later calls have no effect. 
+func Init() { + once.Do(func() { + recordio.RegisterTransformer( + Name, + func(config string) (recordio.TransformFunc, error) { + level, err := parseConfig(config) + if err != nil { + return nil, err + } + return func(scratch []byte, in [][]byte) ([]byte, error) { + return zstdCompress(level, scratch, in) + }, nil + }, + func(string) (recordio.TransformFunc, error) { + return zstdUncompress, nil + }) + }) +} diff --git a/recordio/recordiozstd/recordiozstd_cgo.go b/recordio/recordiozstd/recordiozstd_cgo.go deleted file mode 100644 index 20842560..00000000 --- a/recordio/recordiozstd/recordiozstd_cgo.go +++ /dev/null @@ -1,97 +0,0 @@ -// Copyright 2018 GRAIL, Inc. All rights reserved. -// Use of this source code is governed by the Apache-2.0 -// license that can be found in the LICENSE file. - -// +build cgo - -// Package recordiozstd implements zstd compression and decompression. -// -// Adding "zstd" to WriterV2Opts.Transformer enables zstd with default -// compression level. "zstd 6" will enable zstd with compression level 6. -package recordiozstd - -import ( - "strconv" - "sync" - - "github.com/DataDog/zstd" - "github.com/grailbio/base/recordio" - "github.com/grailbio/base/recordio/recordioiov" -) - -var tmpBufPool = sync.Pool{New: func() interface{} { return []byte{} }} - -// As of 2018-03, zstd.{Compress,Decompress} is much faster than -// io.{Reader,Writer}-based implementations, even though the former incurs extra -// copying. 
-// -// Reader/Writer impl: -// BenchmarkWrite-56 20 116151712 ns/op -// BenchmarkRead-56 30 45302918 ns/op -// -// Compress/Decompress impl: -// BenchmarkWrite-56 50 30034396 ns/op -// BenchmarkRead-56 50 23871334 ns/op -func flattenIov(in [][]byte) []byte { - totalBytes := recordioiov.TotalBytes(in) - tmp := recordioiov.Slice(tmpBufPool.Get().([]byte), totalBytes) - n := 0 - for _, inbuf := range in { - copy(tmp[n:], inbuf) - n += len(inbuf) - } - return tmp -} - -func zstdCompress(level int, scratch []byte, in [][]byte) ([]byte, error) { - if len(in) == 0 { - return zstd.Compress(scratch, nil) - } - if len(in) == 1 { - return zstd.Compress(scratch, in[0]) - } - tmp := flattenIov(in) - d, err := zstd.CompressLevel(scratch, tmp, level) - tmpBufPool.Put(tmp) - return d, err -} - -func zstdUncompress(scratch []byte, in [][]byte) ([]byte, error) { - if len(in) == 0 { - return zstd.Decompress(scratch, nil) - } - if len(in) == 1 { - return zstd.Decompress(scratch, in[0]) - } - tmp := flattenIov(in) - d, err := zstd.Decompress(scratch, tmp) - tmpBufPool.Put(tmp) - return d, err -} - -var once = sync.Once{} - -// Init installs the zstd transformer in recordio. It can be called multiple -// times, but 2nd and later calls have no effect. 
-func Init() { - once.Do(func() { - recordio.RegisterTransformer( - Name, - func(config string) (recordio.TransformFunc, error) { - level := zstd.DefaultCompression - if config != "" { - var err error - level, err = strconv.Atoi(config) - if err != nil { - return nil, err - } - } - return func(scratch []byte, in [][]byte) ([]byte, error) { - return zstdCompress(level, scratch, in) - }, nil - }, - func(string) (recordio.TransformFunc, error) { - return zstdUncompress, nil - }) - }) -} diff --git a/recordio/recordiozstd/recordiozstd_nocgo.go b/recordio/recordiozstd/recordiozstd_nocgo.go deleted file mode 100644 index d63f70b9..00000000 --- a/recordio/recordiozstd/recordiozstd_nocgo.go +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright 2018 GRAIL, Inc. All rights reserved. -// Use of this source code is governed by the Apache-2.0 -// license that can be found in the LICENSE file. - -// +build !cgo - -// Copyright 2018 GRAIL, Inc. All rights reserved. -// Use of this source code is governed by the Apache 2.0 -// license that can be found in the LICENSE file. - -package recordiozstd - -import ( - "errors" - "sync" - - "github.com/grailbio/base/recordio" -) - -var once sync.Once - -// Init registers a dummy implementation for recordio zstd compression. -// The registered transformers always return an error. 
-func Init() { - once.Do(func() { - recordio.RegisterTransformer( - Name, - func(config string) (recordio.TransformFunc, error) { - return nil, errors.New("zstd not supported on non-cgo platforms") - }, - func(string) (recordio.TransformFunc, error) { - return nil, errors.New("zstd not supported on non-cgo platforms") - }) - }) -} diff --git a/recordio/scannerv2.go b/recordio/scannerv2.go index 8bb0de45..f2616228 100644 --- a/recordio/scannerv2.go +++ b/recordio/scannerv2.go @@ -10,9 +10,8 @@ import ( "io" "sync" - "github.com/grailbio/base/errorreporter" + "github.com/grailbio/base/errors" "github.com/grailbio/base/recordio/internal" - "github.com/pkg/errors" ) var scannerFreePool = sync.Pool{ @@ -70,7 +69,7 @@ func parseChunksToItems(rawItems *rawItemList, chunks [][]byte, transform Transf block := rawItems.bytes unItems, n := binary.Uvarint(block) if n <= 0 { - return errors.Errorf("recordio: failed to read number of packed items: %v", n) + return fmt.Errorf("recordio: failed to read number of packed items: %v", n) } nItems := int(unItems) pos := n @@ -84,7 +83,7 @@ func parseChunksToItems(rawItems *rawItemList, chunks [][]byte, transform Transf for i := 0; i < nItems; i++ { size, n := binary.Uvarint(block[pos:]) if n <= 0 { - return errors.Errorf("recordio: likely corrupt data, failed to read size of packed item %v: %v", i, n) + return fmt.Errorf("recordio: likely corrupt data, failed to read size of packed item %v: %v", i, n) } total += int(size) rawItems.cumSize[i] = total @@ -92,7 +91,7 @@ func parseChunksToItems(rawItems *rawItemList, chunks [][]byte, transform Transf } rawItems.firstOff = pos if total+pos != len(block) { - return errors.Errorf("recordio: corrupt block header, got block size %d, expected %d", len(block), total+pos) + return fmt.Errorf("recordio: corrupt block header, got block size %d, expected %d", len(block), total+pos) } return nil } @@ -162,7 +161,7 @@ type Scanner interface { } type scannerv2 struct { - err errorreporter.T + err 
errors.Once sc *internal.ChunkScanner opts ScannerOpts untransform TransformFunc @@ -240,7 +239,7 @@ func newScanner(in io.ReadSeeker, start, limit, nshard int, opts ScannerOpts) Sc if s == nil { panic("newScannerV2") } - s.err = errorreporter.T{Ignored: []error{io.EOF}} + s.err = errors.Once{Ignored: []error{io.EOF}} s.opts = opts s.untransform = nil s.header = nil @@ -260,12 +259,12 @@ func newScanner(in io.ReadSeeker, start, limit, nshard int, opts ScannerOpts) Sc func (s *scannerv2) readSpecialBlock(expectedMagic internal.MagicBytes, tr TransformFunc) []byte { if !s.sc.Scan() { - s.err.Set(errors.Errorf("Failed to read block %v", expectedMagic)) + s.err.Set(fmt.Errorf("Failed to read block %v", expectedMagic)) return nil } magic, chunks := s.sc.Block() if magic != expectedMagic { - s.err.Set(errors.Errorf("Failed to read block, expect %v, got %v", expectedMagic, magic)) + s.err.Set(fmt.Errorf("Failed to read block, expect %v, got %v", expectedMagic, magic)) return nil } rawItems := rawItemList{} @@ -275,7 +274,7 @@ func (s *scannerv2) readSpecialBlock(expectedMagic internal.MagicBytes, tr Trans return nil } if rawItems.len() != 1 { - s.err.Set(errors.Errorf("Wrong # of items in header block, %d", rawItems.len())) + s.err.Set(fmt.Errorf("Wrong # of items in header block, %d", rawItems.len())) return nil } return rawItems.item(0) @@ -295,7 +294,7 @@ func (s *scannerv2) readHeader() { if h.Key == KeyTransformer { str, ok := h.Value.(string) if !ok { - s.err.Set(errors.Errorf("Expect string value for key %v, but found %v", h.Key, h.Value)) + s.err.Set(fmt.Errorf("Expect string value for key %v, but found %v", h.Key, h.Value)) return } transformers = append(transformers, str) @@ -326,7 +325,7 @@ func (s *scannerv2) Trailer() []byte { return nil } if magic != internal.MagicTrailer { - s.err.Set(errors.Errorf("Did not found the trailer, instead found magic %v", magic)) + s.err.Set(fmt.Errorf("Did not found the trailer, instead found magic %v", magic)) return nil } 
rawItems := rawItemList{} @@ -336,7 +335,7 @@ func (s *scannerv2) Trailer() []byte { return nil } if rawItems.len() != 1 { - s.err.Set(errors.Errorf("Expect exactly one trailer item, but found %d", rawItems.len())) + s.err.Set(fmt.Errorf("Expect exactly one trailer item, but found %d", rawItems.len())) return nil } return rawItems.item(0) @@ -349,14 +348,14 @@ func (s *scannerv2) Get() interface{} { func (s *scannerv2) Seek(loc ItemLocation) { // TODO(saito) Avoid seeking the file if loc.Block points to the current block. if s.err.Err() == io.EOF { - s.err = errorreporter.T{} + s.err = errors.Once{} } s.sc.Seek(int64(loc.Block)) if !s.scanNextBlock() { return } if loc.Item >= s.rawItems.len() { - s.err.Set(errors.Errorf("Invalid location %+v, block has only %d items", loc, s.rawItems.len())) + s.err.Set(fmt.Errorf("Invalid location %+v, block has only %d items", loc, s.rawItems.len())) } s.nextItem = loc.Item } @@ -384,7 +383,7 @@ func (s *scannerv2) scanNextBlock() bool { // EOF return false } - s.err.Set(errors.Errorf("recordio: invalid magic number: %v", magic)) + s.err.Set(fmt.Errorf("recordio: invalid magic number: %v", magic)) return false } @@ -414,7 +413,7 @@ func (s *scannerv2) Err() error { func (s *scannerv2) Finish() error { err := s.Err() - s.err = errorreporter.T{} + s.err = errors.Once{} s.opts = ScannerOpts{} s.sc = nil s.untransform = nil diff --git a/recordio/transformer_test.go b/recordio/transformer_test.go index 9d269936..593b5d16 100644 --- a/recordio/transformer_test.go +++ b/recordio/transformer_test.go @@ -23,8 +23,8 @@ func transformerTest(t *testing.T, name string) float64 { Transformers: []string{name}, }) // Write lots of compressible data. 
- const itemSize = 16 << 10 - const nRecs = 1000 + const itemSize = 16 << 8 + const nRecs = 300 for i := 0; i < nRecs; i++ { data := make([]byte, itemSize) for j := range data { diff --git a/recordio/v2_test.go b/recordio/v2_test.go index 06755573..afaa13c4 100644 --- a/recordio/v2_test.go +++ b/recordio/v2_test.go @@ -14,13 +14,14 @@ import ( "github.com/grailbio/base/recordio" "github.com/grailbio/base/recordio/deprecated" + "github.com/grailbio/base/recordio/internal" "github.com/grailbio/base/recordio/recordioiov" + "github.com/grailbio/base/recordio/recordiozstd" "github.com/grailbio/testutil/assert" "github.com/grailbio/testutil/expect" ) -// The recordio chunk size -const chunkSize = 32 << 10 +func init() { recordiozstd.Init() } func marshalString(scratch []byte, v interface{}) ([]byte, error) { return []byte(v.(string)), nil @@ -80,7 +81,7 @@ func TestEmptyBody(t *testing.T) { buf := &bytes.Buffer{} wr := recordio.NewWriter(buf, recordio.WriterOpts{Marshal: marshalString}) assert.NoError(t, wr.Finish()) - assert.EQ(t, len(buf.Bytes()), chunkSize) // one header chunk + assert.EQ(t, len(buf.Bytes()), internal.ChunkSize) // one header chunk header, body, trailer := readAllV2(t, buf) assert.EQ(t, recordio.ParsedHeader(nil), header) assert.EQ(t, []string(nil), body) @@ -103,7 +104,7 @@ func TestV2NonEmptyHeaderEmptyBody(t *testing.T) { wr := recordio.NewWriter(buf, recordio.WriterOpts{Marshal: marshalString}) wr.AddHeader("Foo", "Hah") assert.NoError(t, wr.Finish()) - assert.EQ(t, len(buf.Bytes()), chunkSize) // one header chunk + assert.EQ(t, len(buf.Bytes()), internal.ChunkSize) // one header chunk header, body, trailer := readAllV2(t, buf) assert.EQ(t, recordio.ParsedHeader{recordio.KeyValue{"Foo", "Hah"}}, header) assert.EQ(t, []string(nil), body) @@ -112,11 +113,10 @@ func TestV2NonEmptyHeaderEmptyBody(t *testing.T) { func TestV2EmptyBodyNonEmptyTrailer(t *testing.T) { buf := &bytes.Buffer{} - wr := recordio.NewWriter(buf, recordio.WriterOpts{Marshal: 
marshalString}) - wr.AddHeader(recordio.KeyTrailer, true) + wr := recordio.NewWriter(buf, recordio.WriterOpts{Marshal: marshalString, KeyTrailer: true}) wr.SetTrailer([]byte("TTT")) assert.NoError(t, wr.Finish()) - assert.EQ(t, len(buf.Bytes()), 2*chunkSize) // header+trailer + assert.EQ(t, len(buf.Bytes()), 2*internal.ChunkSize) // header+trailer header, body, trailer := readAllV2(t, buf) assert.EQ(t, recordio.ParsedHeader{recordio.KeyValue{recordio.KeyTrailer, true}}, header) assert.EQ(t, []string(nil), body) @@ -125,12 +125,11 @@ func TestV2EmptyBodyNonEmptyTrailer(t *testing.T) { func TestV2LargeTrailer(t *testing.T) { buf := &bytes.Buffer{} - wr := recordio.NewWriter(buf, recordio.WriterOpts{Marshal: marshalString}) - wr.AddHeader(recordio.KeyTrailer, true) + wr := recordio.NewWriter(buf, recordio.WriterOpts{Marshal: marshalString, KeyTrailer: true}) wr.Append("XX") rnd := rand.New(rand.NewSource(0)) - largeData := randomString(chunkSize*10+100, rnd) + largeData := randomString(internal.ChunkSize*10+100, rnd) wr.SetTrailer([]byte(largeData)) assert.NoError(t, wr.Finish()) header, body, trailer := readAllV2(t, buf) @@ -149,8 +148,8 @@ func TestV2WriteRead(t *testing.T) { index[v.(string)] = loc return nil }, + KeyTrailer: true, }) - wr.AddHeader(recordio.KeyTrailer, true) wr.AddHeader("hh0", "vv0") wr.AddHeader("hh1", 12345) wr.AddHeader("hh2", uint16(234)) @@ -170,11 +169,11 @@ func TestV2WriteRead(t *testing.T) { recordio.KeyValue{"hh1", int64(12345)}, recordio.KeyValue{"hh2", uint64(234)}, }, header) - expect.EQ(t, "Trailer2", trailer) - expect.EQ(t, []string{"F0", "F1", "F2", "F3"}, body) + expect.EQ(t, trailer, "Trailer2") + expect.EQ(t, body, []string{"F0", "F1", "F2", "F3"}) // Test seeking - expect.EQ(t, 4, len(index)) + expect.EQ(t, len(index), 4) sc := recordio.NewScanner(bytes.NewReader(buf.Bytes()), recordio.ScannerOpts{ Unmarshal: unmarshalString, }) @@ -184,7 +183,73 @@ func TestV2WriteRead(t *testing.T) { sc.Seek(loc) expect.NoError(t, sc.Err()) 
expect.True(t, sc.Scan()) - expect.EQ(t, value, sc.Get().(string)) + expect.EQ(t, sc.Get().(string), value) + } +} + +func TestV2RestartWithSkipHeader(t *testing.T) { + ogBuf := &bytes.Buffer{} + index := make(map[string]recordio.ItemLocation) + + writerOpts := recordio.WriterOpts{ + Marshal: marshalString, + Index: func(loc recordio.ItemLocation, v interface{}) error { + index[v.(string)] = loc + return nil + }, + KeyTrailer: true, + } + + wr := recordio.NewWriter(ogBuf, writerOpts) + wr.AddHeader("hh0", "vv0") + wr.AddHeader("hh1", 12345) + wr.AddHeader("hh2", uint16(234)) + wr.Append("F0") + wr.Append("F1") + wr.Flush() + wr.Append("F2") + wr.Flush() + wr.Wait() + + bytesWrittenSoFar := uint64(32768 * 3) // 3 blocks have been written, 1 for header, 2 for data + + writerOpts.Index = func(loc recordio.ItemLocation, v interface{}) error { + loc.Block += bytesWrittenSoFar + index[v.(string)] = loc + return nil + } + writerOpts.SkipHeader = true + + // new buffer with the originally written bytes pre-populated + restartBuf := bytes.NewBuffer(ogBuf.Bytes()) + restartWriter := recordio.NewWriter(restartBuf, writerOpts) + + restartWriter.Append("F3") + restartWriter.SetTrailer([]byte("Trailer2")) + assert.NoError(t, restartWriter.Finish()) + + header, body, trailer := readAllV2(t, restartBuf) + expect.EQ(t, recordio.ParsedHeader{ + recordio.KeyValue{"trailer", true}, + recordio.KeyValue{"hh0", "vv0"}, + recordio.KeyValue{"hh1", int64(12345)}, + recordio.KeyValue{"hh2", uint64(234)}, + }, header) + expect.EQ(t, trailer, "Trailer2") + expect.EQ(t, body, []string{"F0", "F1", "F2", "F3"}) + + // Test seeking + expect.EQ(t, len(index), 4) + sc := recordio.NewScanner(bytes.NewReader(restartBuf.Bytes()), recordio.ScannerOpts{ + Unmarshal: unmarshalString, + }) + + for _, value := range body { + loc := index[value] + sc.Seek(loc) + expect.NoError(t, sc.Err()) + expect.True(t, sc.Scan()) + expect.EQ(t, sc.Get().(string), value) } } @@ -224,8 +289,8 @@ func 
TestV2TransformerError(t *testing.T) { assert.Regexp(t, wr.Err(), "synthetic transformer error") } -func TestV2Transformer(t *testing.T) { - bytewiseTransform := func(scratch []byte, in [][]byte, tr func(uint8) uint8) ([]byte, error) { +func getBytewiseTransformFunc() func(scratch []byte, in [][]byte, tr func(uint8) uint8) ([]byte, error) { + return func(scratch []byte, in [][]byte, tr func(uint8) uint8) ([]byte, error) { nBytes := recordioiov.TotalBytes(in) out := recordioiov.Slice(scratch, nBytes) n := 0 @@ -237,6 +302,10 @@ func TestV2Transformer(t *testing.T) { } return out, nil } +} + +func TestV2Transformer(t *testing.T) { + bytewiseTransform := getBytewiseTransformFunc() var nPlus, nMinus, nXor int32 // A transformer that adds N to every byte. @@ -279,17 +348,17 @@ func TestV2Transformer(t *testing.T) { wr := recordio.NewWriter(buf, recordio.WriterOpts{ Marshal: marshalString, Transformers: []string{"testplus 3", "testxor 111"}, + KeyTrailer: true, }) - wr.AddHeader(recordio.KeyTrailer, true) wr.Append("F0") wr.Append("F1") wr.Flush() wr.Append("F2") wr.SetTrailer([]byte("Trailer2")) assert.NoError(t, wr.Finish()) - assert.EQ(t, int32(3), nPlus) // two data + one trailer block - assert.EQ(t, int32(3), nXor) + assert.EQ(t, nPlus, int32(3)) // two data + one trailer block + assert.EQ(t, nXor, int32(3)) header, body, _ := readAllV2(t, buf) expect.EQ(t, recordio.ParsedHeader{ @@ -297,9 +366,85 @@ func TestV2Transformer(t *testing.T) { recordio.KeyValue{"transformer", "testxor 111"}, recordio.KeyValue{"trailer", true}, }, header) - expect.EQ(t, []string{"F0", "F1", "F2"}, body) - assert.EQ(t, int32(3), nPlus) - assert.EQ(t, int32(6), nXor) + expect.EQ(t, body, []string{"F0", "F1", "F2"}) + assert.EQ(t, nPlus, int32(3)) + assert.EQ(t, nXor, int32(6)) +} + +func TestV2TransformerWithRestart(t *testing.T) { + bytewiseTransform := getBytewiseTransformFunc() + var nPlus, nMinus, nXor int32 + + // A transformer that adds N to every byte. 
+ recordio.RegisterTransformer("restart-testplus", + func(config string) (recordio.TransformFunc, error) { + delta, err := strconv.Atoi(config) + if err != nil { + return nil, err + } + return func(scratch []byte, in [][]byte) ([]byte, error) { + atomic.AddInt32(&nPlus, 1) + return bytewiseTransform(scratch, in, func(b uint8) uint8 { return b + uint8(delta) }) + }, nil + }, + func(config string) (recordio.TransformFunc, error) { + delta, err := strconv.Atoi(config) + if err != nil { + return nil, err + } + return func(scratch []byte, in [][]byte) ([]byte, error) { + atomic.AddInt32(&nMinus, 1) + return bytewiseTransform(scratch, in, func(b uint8) uint8 { return b - uint8(delta) }) + }, nil + }) + + // A transformer that xors every byte. + xorTransformerFactory := func(config string) (recordio.TransformFunc, error) { + delta, err := strconv.Atoi(config) + if err != nil { + return nil, err + } + return func(scratch []byte, in [][]byte) ([]byte, error) { + atomic.AddInt32(&nXor, 1) + return bytewiseTransform(scratch, in, func(b uint8) uint8 { return b ^ uint8(delta) }) + }, nil + } + recordio.RegisterTransformer("restart-testxor", xorTransformerFactory, xorTransformerFactory) + + ogBuf := &bytes.Buffer{} + writerOpts := recordio.WriterOpts{ + Marshal: marshalString, + Transformers: []string{"restart-testplus 3", "restart-testxor 111"}, + KeyTrailer: true, + } + wr := recordio.NewWriter(ogBuf, writerOpts) + + wr.Append("F0") + wr.Append("F1") + wr.Flush() + wr.Wait() + + restartBuf := bytes.NewBuffer(ogBuf.Bytes()) + + writerOpts.SkipHeader = true + wr = recordio.NewWriter(restartBuf, writerOpts) + + wr.Append("F2") + wr.SetTrailer([]byte("Trailer2")) + assert.NoError(t, wr.Finish()) + + assert.EQ(t, nPlus, int32(3)) // two data + one trailer block + assert.EQ(t, nXor, int32(3)) + + header, body, _ := readAllV2(t, restartBuf) + expect.EQ(t, recordio.ParsedHeader{ + recordio.KeyValue{"transformer", "restart-testplus 3"}, + recordio.KeyValue{"transformer", 
"restart-testxor 111"}, + recordio.KeyValue{"trailer", true}, + }, header) + expect.EQ(t, body, []string{"F0", "F1", "F2"}) + assert.EQ(t, nPlus, int32(3)) + assert.EQ(t, nXor, int32(6)) } func randomString(n int, r *rand.Rand) string { @@ -371,26 +516,29 @@ func doRandomTest( maxrecords int, datasize int, wopts recordio.WriterOpts) { - t.Logf("Start test with wopt %+v, nshards %d, maxrecords %d, datasize %d", wopts, nshard, maxrecords, datasize) - - rnd := rand.New(rand.NewSource(seed)) - var nRecords int - if maxrecords > 0 { - nRecords = rnd.Intn(maxrecords) + 1 - } - data, items, index := generateRandomRecordio(t, rnd, flushProbability, nRecords, datasize, wopts) + t.Run("r", func(t *testing.T) { + t.Parallel() + t.Logf("Start test with wopt %+v, nshards %d, maxrecords %d, datasize %d", wopts, nshard, maxrecords, datasize) + + rnd := rand.New(rand.NewSource(seed)) + var nRecords int + if maxrecords > 0 { + nRecords = rnd.Intn(maxrecords) + 1 + } + data, items, index := generateRandomRecordio(t, rnd, flushProbability, nRecords, datasize, wopts) - doShardedReads(t, data, 1, nshard, items) + doShardedReads(t, data, 1, nshard, items) - ropts := recordio.ScannerOpts{Unmarshal: unmarshalString} - sc := recordio.NewScanner(bytes.NewReader(data), ropts) - for _, value := range items { - loc := index[value] - sc.Seek(loc) - expect.NoError(t, sc.Err()) - expect.True(t, sc.Scan()) - expect.EQ(t, value, sc.Get().(string)) - } + ropts := recordio.ScannerOpts{Unmarshal: unmarshalString} + sc := recordio.NewScanner(bytes.NewReader(data), ropts) + for _, value := range items { + loc := index[value] + sc.Seek(loc) + expect.NoError(t, sc.Err()) + expect.True(t, sc.Scan()) + expect.EQ(t, value, sc.Get().(string)) + } + }) } func TestV2Random(t *testing.T) { @@ -398,23 +546,29 @@ func TestV2Random(t *testing.T) { maxrecords = 2000 datasize = 30 ) - doRandomTest(t, 0, 0.001, 2000, maxrecords, 40<<10, recordio.WriterOpts{}) - - doRandomTest(t, 0, 0.1, 1, maxrecords, datasize, 
recordio.WriterOpts{}) - doRandomTest(t, 0, 1.0, 1, maxrecords, datasize, recordio.WriterOpts{}) - doRandomTest(t, 0, 0.0, 1, maxrecords, datasize, recordio.WriterOpts{}) - doRandomTest(t, 0, 0.1, 1, maxrecords, datasize, recordio.WriterOpts{ - MaxFlushParallelism: 1, - }) - doRandomTest(t, 0, 0.1, 1000, maxrecords, datasize, recordio.WriterOpts{}) - doRandomTest(t, 0, 1.0, 3, maxrecords, 30, recordio.WriterOpts{}) - doRandomTest(t, 0, 0.0, 2, maxrecords, 30, recordio.WriterOpts{}) - // Make sure we generate blocks big enough so that - // shards have to straddle block boundaries. - // Make sure that lots of shards with a single record reads correctly. - doRandomTest(t, 0, 0.001, 2000, 1, datasize, recordio.WriterOpts{}) - // Same with an empty recordio file. - doRandomTest(t, 0, 0.001, 2000, 0, datasize, recordio.WriterOpts{}) + for wo := 0; wo < 2; wo++ { + opts := recordio.WriterOpts{} + if wo == 1 { + opts.Transformers = []string{"zstd"} + } + doRandomTest(t, 0, 0.001, 2000, maxrecords, 10<<10, opts) + doRandomTest(t, 0, 0.1, 1, maxrecords, datasize, opts) + doRandomTest(t, 0, 1.0, 1, maxrecords, datasize, opts) + doRandomTest(t, 0, 0.0, 1, maxrecords, datasize, opts) + + opts.MaxFlushParallelism = 1 + doRandomTest(t, 0, 0.1, 1, maxrecords, datasize, opts) + opts.MaxFlushParallelism = 0 + doRandomTest(t, 0, 0.1, 1000, maxrecords, datasize, opts) + doRandomTest(t, 0, 1.0, 3, maxrecords, 30, opts) + doRandomTest(t, 0, 0.0, 2, maxrecords, 30, opts) + // Make sure we generate blocks big enough so that + // shards have to straddle block boundaries. + // Make sure that lots of shards with a single record reads correctly. + doRandomTest(t, 0, 0.001, 2000, 1, datasize, opts) + // Same with an empty recordio file. 
+ doRandomTest(t, 0, 0.001, 2000, 0, datasize, opts) + } } func TestRandomLargeWrites(t *testing.T) { diff --git a/recordio/writerv2.go b/recordio/writerv2.go index 6641bacd..9057a972 100644 --- a/recordio/writerv2.go +++ b/recordio/writerv2.go @@ -10,7 +10,7 @@ import ( "io" "sync" - "github.com/grailbio/base/errorreporter" + "github.com/grailbio/base/errors" "github.com/grailbio/base/recordio/internal" ) @@ -47,11 +47,14 @@ type IndexFunc func(loc ItemLocation, item interface{}) error type WriterOpts struct { // Marshal is called for every item added by Append. It serializes the the // record. If Marshal is nil, it defaults to a function that casts the value - // to []byte and returns it. + // to []byte and returns it. Marshal may be called concurrently. Marshal MarshalFunc // Index is called for every item added, just before it is written to - // storage. After Index is called, the Writer guarantees that it never touches + // storage. Index callback may be called concurrently and out of order of + // locations. + // + // After Index is called, the Writer guarantees that it never touches // the value again. The application may recycle the value in a freepool, if it // desires. Index may be nil. Index IndexFunc @@ -70,10 +73,17 @@ type WriterOpts struct { // If len(Transformers)==0, then an identity transformer is used. It will // return the block as is. // - // The following transformers are supported by default: + // Recordio package includes the following standard transformers: + // + // "zstd N" (N is -1 or an integer from 0 to 22): zstd compression level N. + // If " N" part is omitted or N=-1, the default compression level is used. + // To use zstd, import the 'recordiozstd' package and call + // 'recordiozstd.Init()' in an init() function. // // "flate N" (N is -1 or an integer from 0 to 9): flate compression level N. - // If " N" part is omitted or N=-1, the default compression level is used. 
+ // If " N" part is omitted or N=-1, the default compression level is used. + // To use flate, import the 'recordioflate' package and call + // 'recordioflate.Init()' in an init() function. Transformers []string // MaxItems is the maximum number of items to pack into a single record. @@ -86,6 +96,14 @@ type WriterOpts struct { // DefaultMaxFlushParallelism. MaxFlushParallelism uint32 + // REQUIRES: AddHeader(KeyTrailer, true) has been called or the KeyTrailer + // option set to true. + KeyTrailer bool + + // SkipHeader skips writing out the header and starts in the + // `wStateWritingBody` state. + SkipHeader bool + // TODO(saito) Consider providing a flag to allow out-of-order writes, like // ConcurrentPackedWriter. } @@ -205,7 +223,7 @@ type writerv2 struct { // opts.MaxFlushParallelism. freeBlocks chan *writerv2Block opts WriterOpts - err errorreporter.T + err errors.Once fq flushQueue mu sync.Mutex @@ -218,7 +236,7 @@ type writerv2 struct { type flushQueue struct { freeBlocks chan *writerv2Block // Copy of writerv2.freeBlocks. opts WriterOpts // Copy of writerv2.opts. - err *errorreporter.T // Copy of writerv2.err. + err *errors.Once // Copy of writerv2.err. wr *internal.ChunkWriter // Raw chunk writer. 
transform TransformFunc @@ -272,6 +290,18 @@ func NewWriter(wr io.Writer, opts WriterOpts) Writer { opts: opts, freeBlocks: make(chan *writerv2Block, opts.MaxFlushParallelism), } + + if opts.SkipHeader { + w.state = wStateWritingBody + } else { + for _, val := range opts.Transformers { + w.header = append(w.header, KeyValue{KeyTransformer, val}) + } + } + if opts.KeyTrailer { + w.header = append(w.header, KeyValue{KeyTrailer, true}) + } + w.fq = flushQueue{ wr: internal.NewChunkWriter(wr, &w.err), opts: opts, @@ -289,9 +319,6 @@ func NewWriter(wr io.Writer, opts WriterOpts) Writer { if w.fq.transform, err = registry.getTransformer(opts.Transformers); err != nil { w.err.Set(err) } - for _, val := range opts.Transformers { - w.header = append(w.header, KeyValue{KeyTransformer, val}) - } return w } diff --git a/retry/retry.go b/retry/retry.go index 6a5cc87b..bcb00459 100644 --- a/retry/retry.go +++ b/retry/retry.go @@ -7,7 +7,10 @@ package retry import ( "context" + "fmt" "math" + "math/rand" + "reflect" "time" "github.com/grailbio/base/errors" @@ -30,7 +33,7 @@ type Policy interface { func Wait(ctx context.Context, policy Policy, retry int) error { keepgoing, wait := policy.Retry(retry) if !keepgoing { - return errors.E(errors.TooManyTries, "retry policy ran out of tries") + return errors.E(errors.TooManyTries, fmt.Sprintf("gave up after %d tries", retry)) } if deadline, ok := ctx.Deadline(); ok && time.Until(deadline) < wait { return errors.E(errors.Timeout, "ran out of time while waiting for retry") @@ -43,15 +46,55 @@ func Wait(ctx context.Context, policy Policy, retry int) error { } } +// WaitForFn uses the above Wait function taking the same policy and retry +// number and generalizes it for a use of a function. 
Just like Wait it +// errors in the cases of extra tries, context cancel, or if its deadline +// runs out waiting for the next try +func WaitForFn(ctx context.Context, policy Policy, fn interface{}, params ...interface{}) (result []reflect.Value) { + var out []reflect.Value + f := reflect.ValueOf(fn) + inputs := make([]reflect.Value, len(params)) + for i, in := range params { + inputs[i] = reflect.ValueOf(in) + } + + // will break out of loop if function doesn't error + for retries := 0; ; retries++ { + out = f.Call(inputs) + if out[len(out)-1].IsNil() { // assumes last output value of function is an error object + break + } + if retryErr := Wait(ctx, policy, retries); retryErr != nil { + return out + } + } + + return out +} + type backoff struct { factor float64 initial, max time.Duration } -// Backoff returns a Policy that nitially waits for the amount of +// maxInt64Convertible is the maximum float64 that can be converted to an int64 +// accurately. We use this to prevent overflow when computing the exponential +// backoff, which we compute with float64s. It is important that we push it +// through float64 then int64 so that we get compilation error if we use a +// value that cannot be represented as an int64. This value was produced with: +// math.Nextafter(float64(math.MaxInt64), 0) +const maxInt64Convertible = int64(float64(9223372036854774784)) + +// MaxBackoffMax is the maximum value that can be passed as max to Backoff. +const MaxBackoffMax = time.Duration(maxInt64Convertible) + +// Backoff returns a Policy that initially waits for the amount of // time specified by parameter initial; on each try this value is // multiplied by the provided factor, up to the max duration. 
func Backoff(initial, max time.Duration, factor float64) Policy { + if max > MaxBackoffMax { + panic("max > MaxBackoffMax") + } return &backoff{ initial: initial, max: max, @@ -60,9 +103,72 @@ func Backoff(initial, max time.Duration, factor float64) Policy { } func (b *backoff) Retry(retries int) (bool, time.Duration) { - wait := time.Duration(float64(b.initial) * math.Pow(b.factor, float64(retries))) - if wait > b.max { - wait = b.max + if retries < 0 { + panic("retries < 0") + } + nsfloat64 := float64(b.initial) * math.Pow(b.factor, float64(retries)) + nsfloat64 = math.Min(nsfloat64, float64(b.max)) + return true, time.Duration(int64(nsfloat64)) +} + +// BackoffWithTimeout returns a Policy that initially waits for the amount of +// time specified by parameter initial; on each try this value is +// multiplied by the provided factor, up to the max duration. +// After the max duration, the Policy will timeout and return an error. +func BackoffWithTimeout(initial, max time.Duration, factor float64) Policy { + n := int(math.Floor(math.Log(float64(max/initial))/math.Log(factor))) + 1 + return MaxRetries(Backoff(initial, max, factor), n) +} + +type jitter struct { + policy Policy + // frac is the fraction of the wait time to "jitter". + // Eg: if frac is 0.2, the policy will retain 80% of the wait time + // and jitter the remaining 20% + frac float64 +} + +// Jitter returns a policy that jitters 'frac' fraction of the wait times +// returned by the provided policy. For example, setting frac to 1.0 and 0.5 +// will implement "full jitter" and "equal jitter" approaches respectively. 
+// These approaches are described here:
+// https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/
+func Jitter(policy Policy, frac float64) Policy {
+	return &jitter{policy, frac}
+}
+
+func (b *jitter) Retry(retries int) (bool, time.Duration) {
+	ok, wait := b.policy.Retry(retries)
+	if wait > 0 {
+		prop := time.Duration(b.frac * float64(wait))
+		wait = wait - prop + time.Duration(rand.Int63n(prop.Nanoseconds()))
+	}
+	return ok, wait
+}
+
+type maxtries struct {
+	policy Policy
+	max    int
+}
+
+// MaxRetries returns a policy that enforces a maximum number of
+// attempts. The provided policy is invoked when the current number
+// of tries is within the permissible limit. If policy is nil, the
+// returned policy will permit an immediate retry when the number of
+// tries is within the allowable limits.
+func MaxRetries(policy Policy, n int) Policy {
+	if n < 1 {
+		panic("retry.MaxRetries: n < 1")
+	}
+	return &maxtries{policy, n - 1}
+}
+
+func (m *maxtries) Retry(retries int) (bool, time.Duration) {
+	if retries > m.max {
+		return false, time.Duration(0)
+	}
+	if m.policy != nil {
+		return m.policy.Retry(retries)
 	}
-	return true, wait
+	return true, time.Duration(0)
 }
diff --git a/retry/retry_test.go b/retry/retry_test.go
index 9027eda6..a93d93ae 100644
--- a/retry/retry_test.go
+++ b/retry/retry_test.go
@@ -6,9 +6,12 @@ package retry
 import (
 	"context"
+	"fmt"
 	"testing"
 	"time"
 
+	"github.com/stretchr/testify/require"
+
 	"github.com/grailbio/base/errors"
 )
 
@@ -33,6 +36,103 @@ func TestBackoff(t *testing.T) {
 	}
 }
 
+// TestBackoffOverflow tests the behavior of exponential backoff for large
+// numbers of retries.
+func TestBackoffOverflow(t *testing.T) {
+	policy := Backoff(time.Second, 10*time.Second, 2)
+	expect := []time.Duration{
+		10 * time.Second,
+		10 * time.Second,
+		10 * time.Second,
+		10 * time.Second,
+	}
+	for retries, wait := range expect {
+		// Use a large number of retries that might overflow exponential
+		// calculations.
+ keepgoing, dur := policy.Retry(1000 + retries) + if !keepgoing { + t.Fatal("!keepgoing") + } + if got, want := dur, wait; got != want { + t.Errorf("retry %d: got %v, want %v", retries, got, want) + } + } +} + +func TestBackoffWithFullJitter(t *testing.T) { + policy := Jitter(Backoff(time.Second, 10*time.Second, 2), 1.0) + checkWithin := func(t *testing.T, wantMin, wantMax, got time.Duration) { + if got < wantMin || got > wantMax { + t.Errorf("got %v, want within (%v, %v)", got, wantMin, wantMax) + } + } + expect := []time.Duration{ + time.Second, + 2 * time.Second, + 4 * time.Second, + 8 * time.Second, + 10 * time.Second, + 10 * time.Second, + } + for retries, wait := range expect { + keepgoing, dur := policy.Retry(retries) + if !keepgoing { + t.Fatal("!keepgoing") + } + checkWithin(t, 0, wait, dur) + } +} + +func TestBackoffWithEqualJitter(t *testing.T) { + policy := Jitter(Backoff(time.Second, 10*time.Second, 2), 0.5) + checkWithin := func(t *testing.T, wantMin, wantMax, got time.Duration) { + if got < wantMin || got > wantMax { + t.Errorf("got %v, want within (%v, %v)", got, wantMin, wantMax) + } + } + expect := []time.Duration{ + time.Second, + 2 * time.Second, + 4 * time.Second, + 8 * time.Second, + 10 * time.Second, + 10 * time.Second, + } + for retries, wait := range expect { + keepgoing, dur := policy.Retry(retries) + if !keepgoing { + t.Fatal("!keepgoing") + } + checkWithin(t, wait/2, wait, dur) + } +} + +func TestBackoffWithTimeout(t *testing.T) { + policy := BackoffWithTimeout(time.Second, 10*time.Second, 2) + expect := []time.Duration{ + time.Second, + 2 * time.Second, + 4 * time.Second, + 8 * time.Second, + } + var retries = 0 + for _, wait := range expect { + keepgoing, dur := policy.Retry(retries) + if !keepgoing { + t.Fatal("!keepgoing") + } + if got, want := dur, wait; got != want { + t.Errorf("retry %d: got %v, want %v", retries, got, want) + } + retries++ + } + keepgoing, _ := policy.Retry(retries) + if keepgoing { + t.Errorf("keepgoing: got 
%v, want %v", keepgoing, false) + } + +} + func TestWaitCancel(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) policy := Backoff(time.Hour, time.Hour, 1) @@ -50,3 +150,89 @@ func TestWaitDeadline(t *testing.T) { t.Errorf("got %v, want %v", got, want) } } + +func testWrapperHelper(i int) (int, error) { + if i == 0 { + return 0, fmt.Errorf("This is an Error") + } + return 9999, nil +} + +func testWrapperHelperLong(i int) (int, int, error) { + if i == 0 { + return 0, 0, fmt.Errorf("This is an Error") + } + return 1, 2, nil +} + +func TestWaitForFn(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + policy := Backoff(time.Hour, time.Hour, 1) + cancel() + + output := WaitForFn(ctx, policy, testWrapperHelper, 0) + require.EqualError(t, output[1].Interface().(error), "This is an Error") + + output = WaitForFn(ctx, policy, testWrapperHelper, 55) + require.Equal(t, 9999, int(output[0].Int())) + + var err error + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("wrong number of input, expected: 1, actual: 3") + } + }() + WaitForFn(ctx, policy, testWrapperHelper, 1, 2, 3) + require.EqualError(t, err, "wrong number of input, expected: 1, actual: 3") +} + +func TestWaitForFnLong(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + policy := Backoff(time.Hour, time.Hour, 1) + cancel() + + output := WaitForFn(ctx, policy, testWrapperHelperLong, 0) + require.EqualError(t, output[2].Interface().(error), "This is an Error") + + output = WaitForFn(ctx, policy, testWrapperHelperLong, 55) + require.Equal(t, 1, int(output[0].Int())) + require.Equal(t, 2, int(output[1].Int())) + +} + +func TestMaxRetries(t *testing.T) { + retryImmediately := Backoff(0, 0, 0) + + type testArgs struct { + retryPolicy Policy + fn func(*int) error + } + testCases := []struct { + testName string + args testArgs + expected int + }{ + { + testName: "function always fails", + args: testArgs{ + retryPolicy: 
MaxRetries(retryImmediately, 1), + fn: func(callCount *int) error { + *callCount++ + + return fmt.Errorf("always fail") + }, + }, + expected: 2, + }, + } + + for _, tc := range testCases { + t.Run(tc.testName, func(t *testing.T) { + callCount := 0 + + WaitForFn(context.Background(), tc.args.retryPolicy, tc.args.fn, &callCount) + + require.Equal(t, tc.expected, callCount) + }) + } +} diff --git a/s3util/s3copy.go b/s3util/s3copy.go new file mode 100644 index 00000000..b7994791 --- /dev/null +++ b/s3util/s3copy.go @@ -0,0 +1,219 @@ +package s3util + +import ( + "context" + "fmt" + "net/url" + "strings" + "time" + + "github.com/grailbio/base/errors" + "github.com/grailbio/base/retry" + "github.com/grailbio/base/traverse" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/service/s3" + "github.com/aws/aws-sdk-go/service/s3/s3iface" +) + +const ( + // DefaultS3ObjectCopySizeLimit is the max size of object for a single PUT Object Copy request. + // As per AWS: https://docs.aws.amazon.com/AmazonS3/latest/API/RESTObjectCOPY.html + // the max size allowed is 5GB, but we use a smaller size here to speed up large file copies. + DefaultS3ObjectCopySizeLimit = 256 << 20 // 256MiB + + // defaultS3MultipartCopyPartSize is the max size of each part when doing a multi-part copy. + // Note: Though we can do parts of size up to defaultS3ObjectCopySizeLimit, for large files + // using smaller size parts (concurrently) is much faster. + DefaultS3MultipartCopyPartSize = 128 << 20 // 128MiB + + // s3MultipartCopyConcurrencyLimit is the number of concurrent parts to do during a multi-part copy. 
+ s3MultipartCopyConcurrencyLimit = 100 + + defaultMaxRetries = 3 +) + +var ( + // DefaultRetryPolicy is the default retry policy + DefaultRetryPolicy = retry.MaxRetries(retry.Jitter(retry.Backoff(1*time.Second, time.Minute, 2), 0.25), defaultMaxRetries) +) + +type Debugger interface { + Debugf(format string, args ...interface{}) +} + +type noOpDebugger struct{} + +func (d noOpDebugger) Debugf(format string, args ...interface{}) {} + +// Copier supports operations to copy S3 objects (within or across buckets) +// by using S3 APIs that support the same (ie, without having to stream the data by reading and writing). +// +// Since AWS doesn't allow copying large files in a single operation, +// this will do a multi-part copy object in those cases. +// However, this behavior can also be controlled by setting appropriate values +// for S3ObjectCopySizeLimit and S3MultipartCopyPartSize. + +type Copier struct { + client s3iface.S3API + retrier retry.Policy + + // S3ObjectCopySizeLimit is the max size of object for a single PUT Object Copy request. + S3ObjectCopySizeLimit int64 + // S3MultipartCopyPartSize is the max size of each part when doing a multi-part copy. + S3MultipartCopyPartSize int64 + + Debugger +} + +func NewCopier(client s3iface.S3API) *Copier { + return NewCopierWithParams(client, DefaultRetryPolicy, DefaultS3ObjectCopySizeLimit, DefaultS3MultipartCopyPartSize, nil) +} + +func NewCopierWithParams(client s3iface.S3API, retrier retry.Policy, s3ObjectCopySizeLimit int64, s3MultipartCopyPartSize int64, debugger Debugger) *Copier { + if debugger == nil { + debugger = noOpDebugger{} + } + return &Copier{ + client: client, + retrier: retrier, + S3ObjectCopySizeLimit: s3ObjectCopySizeLimit, + S3MultipartCopyPartSize: s3MultipartCopyPartSize, + Debugger: debugger, + } +} + +// Copy copies the S3 object from srcUrl to dstUrl (both expected to be full S3 URLs) +// The size of the source object (srcSize) determines behavior (whether done as single or multi-part copy). 
+// +// dstMetadata must be set if the caller wishes to set the metadata on the dstUrl object. +// While the AWS API will copy the metadata over if done using CopyObject, but NOT when multi-part copy is done, +// this method requires that dstMetadata be always provided to remove ambiguity. +// So if metadata is desired on dstUrl object, *it must always be provided*. +func (c *Copier) Copy(ctx context.Context, srcUrl, dstUrl string, srcSize int64, dstMetadata map[string]*string) error { + copySrc := strings.TrimPrefix(srcUrl, "s3://") + dstBucket, dstKey, err := bucketKey(dstUrl) + if err != nil { + return err + } + if srcSize <= c.S3ObjectCopySizeLimit { + // Do single copy + input := &s3.CopyObjectInput{ + Bucket: aws.String(dstBucket), + Key: aws.String(dstKey), + CopySource: aws.String(copySrc), + Metadata: dstMetadata, + } + for retries := 0; ; retries++ { + _, err = c.client.CopyObjectWithContext(ctx, input) + err = CtxErr(ctx, err) + if err == nil { + break + } + severity := Severity(err) + if severity != errors.Temporary && severity != errors.Retriable { + break + } + c.Debugf("s3copy.Copy: attempt (%d): %s -> %s\n%v\n", retries, srcUrl, dstUrl, err) + if err = retry.Wait(ctx, c.retrier, retries); err != nil { + break + } + } + if err == nil { + c.Debugf("s3copy.Copy: done: %s -> %s", srcUrl, dstUrl) + } + return err + } + // Do a multi-part copy + numParts := (srcSize + c.S3MultipartCopyPartSize - 1) / c.S3MultipartCopyPartSize + input := &s3.CreateMultipartUploadInput{ + Bucket: aws.String(dstBucket), + Key: aws.String(dstKey), + Metadata: dstMetadata, + } + createOut, err := c.client.CreateMultipartUploadWithContext(ctx, input) + if err != nil { + return errors.E(fmt.Sprintf("CreateMultipartUpload: %s -> %s", srcUrl, dstUrl), err) + } + completedParts := make([]*s3.CompletedPart, numParts) + err = traverse.Limit(s3MultipartCopyConcurrencyLimit).Each(int(numParts), func(ti int) error { + i := int64(ti) + firstByte := i * c.S3MultipartCopyPartSize + 
lastByte := firstByte + c.S3MultipartCopyPartSize - 1 + if lastByte >= srcSize { + lastByte = srcSize - 1 + } + var partErr error + var uploadOut *s3.UploadPartCopyOutput + for retries := 0; ; retries++ { + uploadOut, partErr = c.client.UploadPartCopyWithContext(ctx, &s3.UploadPartCopyInput{ + Bucket: aws.String(dstBucket), + Key: aws.String(dstKey), + CopySource: aws.String(copySrc), + UploadId: createOut.UploadId, + PartNumber: aws.Int64(i + 1), + CopySourceRange: aws.String(fmt.Sprintf("bytes=%d-%d", firstByte, lastByte)), + }) + partErr = CtxErr(ctx, partErr) + if partErr == nil { + break + } + severity := Severity(partErr) + if severity != errors.Temporary && severity != errors.Retriable { + break + } + c.Debugf("s3copy.Copy: attempt (%d) (part %d/%d): %s -> %s\n%v\n", retries, i, numParts, srcUrl, dstUrl, partErr) + if partErr = retry.Wait(ctx, c.retrier, retries); partErr != nil { + break + } + } + if partErr == nil { + completedParts[i] = &s3.CompletedPart{ETag: uploadOut.CopyPartResult.ETag, PartNumber: aws.Int64(i + 1)} + c.Debugf("s3copy.Copy: done (part %d/%d): %s -> %s", i, numParts, srcUrl, dstUrl) + return nil + } + return errors.E(fmt.Sprintf("upload part copy (part %d/%d) %s -> %s", i, numParts, srcUrl, dstUrl), partErr) + }) + if err == nil { + // Complete the multi-part copy + for retries := 0; ; retries++ { + _, err = c.client.CompleteMultipartUploadWithContext(ctx, &s3.CompleteMultipartUploadInput{ + Bucket: aws.String(dstBucket), + Key: aws.String(dstKey), + UploadId: createOut.UploadId, + MultipartUpload: &s3.CompletedMultipartUpload{Parts: completedParts}, + }) + if err == nil || Severity(err) != errors.Temporary { + break + } + c.Debugf("s3copy.Copy complete upload: attempt (%d): %s -> %s\n%v\n", retries, srcUrl, dstUrl, err) + if err = retry.Wait(ctx, c.retrier, retries); err != nil { + break + } + } + if err == nil { + c.Debugf("s3copy.Copy: done (all %d parts): %s -> %s", numParts, srcUrl, dstUrl) + return nil + } + err = 
errors.E(fmt.Sprintf("complete multipart upload %s -> %s", srcUrl, dstUrl), Severity(err), err) + } + // Abort the multi-part copy + if _, er := c.client.AbortMultipartUploadWithContext(ctx, &s3.AbortMultipartUploadInput{ + Bucket: aws.String(dstBucket), + Key: aws.String(dstKey), + UploadId: createOut.UploadId, + }); er != nil { + err = fmt.Errorf("abort multipart copy %v (aborting due to original error: %v)", er, err) + } + return err +} + +// bucketKey returns the bucket and key for the given S3 object url and error (if any). +func bucketKey(rawurl string) (string, string, error) { + u, err := url.Parse(rawurl) + if err != nil { + return "", "", errors.E(errors.Invalid, errors.Fatal, fmt.Sprintf("cannot determine bucket and key from rawurl %s", rawurl), err) + } + bucket := u.Host + return bucket, strings.TrimPrefix(rawurl, "s3://"+bucket+"/"), nil +} diff --git a/s3util/s3copy_test.go b/s3util/s3copy_test.go new file mode 100644 index 00000000..c04cf50e --- /dev/null +++ b/s3util/s3copy_test.go @@ -0,0 +1,245 @@ +package s3util + +import ( + "bytes" + "context" + "fmt" + "io/ioutil" + "math/rand" + "testing" + "time" + + "github.com/grailbio/base/retry" + "github.com/grailbio/testutil" + "github.com/grailbio/testutil/s3test" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/awserr" + "github.com/aws/aws-sdk-go/aws/request" + "github.com/aws/aws-sdk-go/service/s3" +) + +const testBucket = "test_bucket" + +var ( + testKeys = map[string]*testutil.ByteContent{"test/x": content("some sample content")} + errorKeys = map[string]error{ + "key_awscanceled": awserr.New(request.CanceledErrorCode, "test", nil), + "key_nosuchkey": awserr.New(s3.ErrCodeNoSuchKey, "test", nil), + "key_badrequest": awserr.New("BadRequest", "test", nil), + "key_canceled": context.Canceled, + "key_deadlineexceeded": context.DeadlineExceeded, + "key_awsrequesttimeout": awserr.New("RequestTimeout", "test", nil), + "key_nestedEOFrequest": awserr.New("MultipartUpload", "test", 
awserr.New("SerializationError", "test2", fmt.Errorf("unexpected EOF"))), + "key_awsinternalerror": awserr.New("InternalError", "test", nil), + } +) + +func newTestClient(t *testing.T) *s3test.Client { + t.Helper() + client := s3test.NewClient(t, testBucket) + client.Region = "us-west-2" + for k, v := range testKeys { + client.SetFileContentAt(k, v, "") + } + return client +} + +func newFailingTestClient(t *testing.T, fn *failN) *s3test.Client { + t.Helper() + client := newTestClient(t) + client.Err = func(api string, input interface{}) error { + switch api { + case "UploadPartCopyWithContext": + if upc, ok := input.(*s3.UploadPartCopyInput); ok { + // Possibly fail the first part with an error based on the key + if *upc.PartNumber == int64(1) && fn.fail() { + return errorKeys[*upc.Key] + } + } + case "CopyObjectRequest": + if req, ok := input.(*s3.CopyObjectInput); ok && fn.fail() { + return errorKeys[*req.Key] + } + } + return nil + } + return client +} + +func TestBucketKey(t *testing.T) { + for _, tc := range []struct { + url, wantBucket, wantKey string + wantErr bool + }{ + {"s3://bucket/key", "bucket", "key", false}, + {"s3://some_other-bucket/very/long/key", "some_other-bucket", "very/long/key", false}, + } { + gotB, gotK, gotE := bucketKey(tc.url) + if tc.wantErr && gotE == nil { + t.Errorf("%s got no error, want error", tc.url) + continue + } + if got, want := gotB, tc.wantBucket; got != want { + t.Errorf("got %s want %s", got, want) + } + if got, want := gotK, tc.wantKey; got != want { + t.Errorf("got %s want %s", got, want) + } + } +} + +func TestCopy(t *testing.T) { + client := newTestClient(t) + copier := NewCopier(client) + + srcKey, srcSize, dstKey := "test/x", testKeys["test/x"].Size(), "test/x_copy" + srcUrl := fmt.Sprintf("s3://%s/%s", testBucket, srcKey) + dstUrl := fmt.Sprintf("s3://%s/%s", testBucket, dstKey) + + checkObject(t, client, srcKey, testKeys[srcKey]) + if err := copier.Copy(context.Background(), srcUrl, dstUrl, srcSize, nil); err != 
nil { + t.Fatal(err) + } + checkObject(t, client, dstKey, testKeys[srcKey]) +} + +func TestCopyWithRetry(t *testing.T) { + client := newFailingTestClient(t, &failN{n: 2}) + retrier := retry.MaxRetries(retry.Jitter(retry.Backoff(10*time.Millisecond, 50*time.Millisecond, 2), 0.25), 4) + copier := NewCopierWithParams(client, retrier, 1<<10, 1<<10, testDebugger{t}) + + for _, tc := range []struct { + srcKey string + dstKey string + srcSize int64 + }{ + { + srcKey: "test/x", + dstKey: "key_awsrequesttimeout", + srcSize: testKeys["test/x"].Size(), + }, + { + srcKey: "test/x", + dstKey: "key_awsinternalerror", + srcSize: testKeys["test/x"].Size(), + }, + } { + srcUrl := fmt.Sprintf("s3://%s/%s", testBucket, tc.srcKey) + dstUrl := fmt.Sprintf("s3://%s/%s", testBucket, tc.dstKey) + checkObject(t, client, tc.srcKey, testKeys[tc.srcKey]) + if err := copier.Copy(context.Background(), srcUrl, dstUrl, tc.srcSize, nil); err != nil { + t.Fatal(err) + } + checkObject(t, client, tc.dstKey, testKeys[tc.srcKey]) + } + +} + +func TestCopyMultipart(t *testing.T) { + bctx := context.Background() + for _, tc := range []struct { + client *s3test.Client + dstKey string + size, limit, partsize int64 + useShortCtx, cancelCtx bool + wantErr bool + }{ + // 100KiB of data, multi-part limit 50KiB, part size 10KiB + {newTestClient(t), "dst1", 100 << 10, 50 << 10, 10 << 10, false, false, false}, + // 50KiB of data, multi-part limit 50KiB, part size 10KiB + {newTestClient(t), "dst2", 50 << 10, 50 << 10, 10 << 10, false, false, false}, + {newTestClient(t), "dst3", 100 << 10, 50 << 10, 10 << 10, true, false, true}, + {newTestClient(t), "dst4", 100 << 10, 50 << 10, 10 << 10, false, true, true}, + {newFailingTestClient(t, &failN{n: 2}), "key_badrequest", 100 << 10, 50 << 10, 10 << 10, false, false, false}, + {newFailingTestClient(t, &failN{n: 2}), "key_deadlineexceeded", 100 << 10, 50 << 10, 10 << 10, false, false, false}, + {newFailingTestClient(t, &failN{n: 2}), "key_awsrequesttimeout", 100 << 10, 50 
<< 10, 10 << 10, false, false, false}, + {newFailingTestClient(t, &failN{n: 2}), "key_nestedEOFrequest", 100 << 10, 50 << 10, 10 << 10, false, false, false}, + {newFailingTestClient(t, &failN{n: 2}), "key_canceled", 100 << 10, 50 << 10, 10 << 10, false, false, true}, + {newFailingTestClient(t, &failN{n: defaultMaxRetries + 1}), "key_badrequest", 100 << 10, 50 << 10, 10 << 10, false, false, true}, + } { + client := tc.client + b := make([]byte, tc.size) + if _, err := rand.Read(b); err != nil { + t.Fatal(err) + } + srcKey, srcContent := "src", &testutil.ByteContent{Data: b} + client.SetFileContentAt(srcKey, srcContent, "") + checkObject(t, client, srcKey, srcContent) + + retrier := retry.MaxRetries(retry.Jitter(retry.Backoff(10*time.Millisecond, 50*time.Millisecond, 2), 0.25), defaultMaxRetries) + copier := NewCopierWithParams(client, retrier, tc.limit, tc.partsize, testDebugger{t}) + + ctx := bctx + var cancel context.CancelFunc + if tc.useShortCtx { + ctx, cancel = context.WithTimeout(bctx, 10*time.Nanosecond) + } else if tc.cancelCtx { + ctx, cancel = context.WithCancel(bctx) + cancel() + } + srcUrl := fmt.Sprintf("s3://%s/%s", testBucket, srcKey) + dstUrl := fmt.Sprintf("s3://%s/%s", testBucket, tc.dstKey) + + err := copier.Copy(ctx, srcUrl, dstUrl, tc.size, nil) + if cancel != nil { + cancel() + } + if tc.wantErr { + if err == nil { + t.Errorf("%s got no error, want error", tc.dstKey) + } + continue + } + if err != nil { + t.Fatal(err) + } + checkObject(t, client, tc.dstKey, srcContent) + if t.Failed() { + t.Logf("case: %v", tc) + } + } +} + +func content(s string) *testutil.ByteContent { + return &testutil.ByteContent{Data: []byte(s)} +} + +func checkObject(t *testing.T, client *s3test.Client, key string, c *testutil.ByteContent) { + t.Helper() + out, err := client.GetObject(&s3.GetObjectInput{ + Bucket: aws.String(testBucket), + Key: aws.String(key), + }) + if err != nil { + t.Fatal(err) + } + p, err := ioutil.ReadAll(out.Body) + if err != nil { + 
t.Fatal(err) + } + if got, want := p, c.Data; !bytes.Equal(got, want) { + t.Errorf("got %v, want %v", got, want) + } +} + +// failN returns true n times when fail() is called and then returns false, until its reset. +type failN struct { + n, i int +} + +func (p *failN) fail() bool { + if p.i < p.n { + p.i++ + return true + } + return false +} + +func (p *failN) reset() { + p.i = 0 +} + +type testDebugger struct{ *testing.T } + +func (d testDebugger) Debugf(format string, args ...interface{}) { d.T.Logf(format, args...) } diff --git a/s3util/s3error.go b/s3util/s3error.go new file mode 100644 index 00000000..2278d055 --- /dev/null +++ b/s3util/s3error.go @@ -0,0 +1,107 @@ +package s3util + +import ( + "context" + + "github.com/grailbio/base/errors" + + "github.com/aws/aws-sdk-go/aws/awserr" + "github.com/aws/aws-sdk-go/aws/request" + "github.com/aws/aws-sdk-go/service/s3" +) + +// CtxErr will return the context's error (if any) or the other error. +// This is particularly useful to interpret AWS S3 API call errors +// because AWS sometimes wraps context errors (context.Canceled or context.DeadlineExceeded). +func CtxErr(ctx context.Context, other error) error { + if ctx.Err() != nil { + return ctx.Err() + } + return other +} + +// KindAndSeverity interprets a given error and returns errors.Severity. +// This is particularly useful to interpret AWS S3 API call errors. +func Severity(err error) errors.Severity { + if aerr, ok := err.(awserr.Error); ok { + _, severity := KindAndSeverity(aerr) + return severity + } + if re := errors.Recover(err); re != nil { + return re.Severity + } + return errors.Unknown +} + +// KindAndSeverity interprets a given error and returns errors.Kind and errors.Severity. +// This is particularly useful to interpret AWS S3 API call errors. 
+func KindAndSeverity(err error) (errors.Kind, errors.Severity) { + for { + if request.IsErrorThrottle(err) { + return errors.ResourcesExhausted, errors.Temporary + } + if request.IsErrorRetryable(err) { + return errors.Other, errors.Temporary + } + aerr, ok := err.(awserr.Error) + if !ok { + break + } + if aerr.Code() == request.CanceledErrorCode { + return errors.Canceled, errors.Fatal + } + // The underlying error was an S3 error. Try to classify it. + // Best guess based on Amazon's descriptions: + switch aerr.Code() { + // Code NotFound is not documented, but it's what the API actually returns. + case s3.ErrCodeNoSuchBucket, "NoSuchVersion", "NotFound": + return errors.NotExist, errors.Fatal + case s3.ErrCodeNoSuchKey: + // Treat as temporary because sometimes they are, due to S3's eventual consistency model + // https://aws.amazon.com/premiumsupport/knowledge-center/404-error-nosuchkey-s3/ + return errors.NotExist, errors.Temporary + case "AccessDenied": + return errors.NotAllowed, errors.Fatal + case "InvalidRequest", "InvalidArgument", "EntityTooSmall", "EntityTooLarge", "KeyTooLong", "MethodNotAllowed": + return errors.Invalid, errors.Fatal + case "ExpiredToken", "AccountProblem", "ServiceUnavailable", "TokenRefreshRequired", "OperationAborted": + return errors.Unavailable, errors.Fatal + case "PreconditionFailed": + return errors.Precondition, errors.Fatal + case "SlowDown": + return errors.ResourcesExhausted, errors.Temporary + case "BadRequest": + return errors.Other, errors.Temporary + case "InternalError": + // AWS recommends retrying InternalErrors: + // https://aws.amazon.com/premiumsupport/knowledge-center/s3-resolve-200-internalerror/ + // https://aws.amazon.com/premiumsupport/knowledge-center/http-5xx-errors-s3/ + return errors.Other, errors.Retriable + case "XAmzContentSHA256Mismatch": + // Example: + // + // XAmzContentSHA256Mismatch: The provided 'x-amz-content-sha256' header + // does not match what was computed. 
+ // + // Happens sporadically for no discernible reason. Just retry. + return errors.Other, errors.Temporary + // "RequestError"s are not considered retryable by `request.IsErrorRetryable(err)` + // if the underlying cause is due to a "read: connection reset". For explanation, see: + // https://github.com/aws/aws-sdk-go/issues/2525#issuecomment-519263830 + // So we catch all "RequestError"s here as temporary. + case request.ErrCodeRequestError: + return errors.Other, errors.Temporary + // "SerializationError"s are not considered retryable by `request.IsErrorRetryable(err)` + // if the underlying cause is due to a "read: connection reset". For explanation, see: + // https://github.com/aws/aws-sdk-go/issues/2525#issuecomment-519263830 + // So we catch all "SerializationError"s here as temporary. + case request.ErrCodeSerialization: + return errors.Other, errors.Temporary + } + if aerr.OrigErr() == nil { + break + } + err = aerr.OrigErr() + } + return errors.Other, errors.Unknown +} diff --git a/security/identity/identity.vdl b/security/identity/identity.vdl index ea74f418..34aa4bf7 100644 --- a/security/identity/identity.vdl +++ b/security/identity/identity.vdl @@ -18,3 +18,8 @@ type Ec2Blesser interface { type GoogleBlesser interface { BlessGoogle(idToken string) (blessing security.WireBlessings | error) {access.Read} } + +// K8sBlesser returns a blessing giving the provided Kubernetes service accountop token. +type K8sBlesser interface { + BlessK8s(caCrt string, namespace string, token string, region string) (blessing security.WireBlessings | error) {access.Read} +} \ No newline at end of file diff --git a/security/identity/identity.vdl.go b/security/identity/identity.vdl.go index 24fb5f4b..32325564 100644 --- a/security/identity/identity.vdl.go +++ b/security/identity/identity.vdl.go @@ -1,15 +1,12 @@ -// Copyright 2018 GRAIL, Inc. All rights reserved. -// Use of this source code is governed by the Apache-2.0 -// license that can be found in the LICENSE file. 
- // This file was auto-generated by the vanadium vdl tool. // Package: identity // Package identity defines interfaces for Vanadium identity providers. +//nolint:golint package identity import ( - "v.io/v23" + v23 "v.io/v23" "v.io/v23/context" "v.io/v23/rpc" "v.io/v23/security" @@ -17,10 +14,10 @@ import ( "v.io/v23/vdl" ) -var _ = __VDLInit() // Must be first; see __VDLInit comments for details. +var _ = initializeVDL() // Must be first; see initializeVDL comments for details. -////////////////////////////////////////////////// // Interface definitions +// ===================== // Ec2BlesserClientMethods is the client interface // containing Ec2Blesser methods. @@ -33,10 +30,10 @@ type Ec2BlesserClientMethods interface { BlessEc2(_ *context.T, pkcs7b64 string, _ ...rpc.CallOpt) (blessing security.Blessings, _ error) } -// Ec2BlesserClientStub adds universal methods to Ec2BlesserClientMethods. +// Ec2BlesserClientStub embeds Ec2BlesserClientMethods and is a +// placeholder for additional management operations. type Ec2BlesserClientStub interface { Ec2BlesserClientMethods - rpc.UniversalServiceMethods } // Ec2BlesserClient returns a client stub for Ec2Blesser. @@ -73,7 +70,7 @@ type Ec2BlesserServerStubMethods Ec2BlesserServerMethods // Ec2BlesserServerStub adds universal methods to Ec2BlesserServerStubMethods. type Ec2BlesserServerStub interface { Ec2BlesserServerStubMethods - // Describe the Ec2Blesser interfaces. + // DescribeInterfaces the Ec2Blesser interfaces. 
Describe__() []rpc.InterfaceDesc } @@ -124,10 +121,10 @@ var descEc2Blesser = rpc.InterfaceDesc{ Name: "BlessEc2", Doc: "// BlessEc2 uses the provided EC2 instance identity document in PKCS#7\n// format to return a blessing to the client.", InArgs: []rpc.ArgDesc{ - {"pkcs7b64", ``}, // string + {Name: "pkcs7b64", Doc: ``}, // string }, OutArgs: []rpc.ArgDesc{ - {"blessing", ``}, // security.Blessings + {Name: "blessing", Doc: ``}, // security.Blessings }, Tags: []*vdl.Value{vdl.ValueOf(access.Tag("Read"))}, }, @@ -142,10 +139,10 @@ type GoogleBlesserClientMethods interface { BlessGoogle(_ *context.T, idToken string, _ ...rpc.CallOpt) (blessing security.Blessings, _ error) } -// GoogleBlesserClientStub adds universal methods to GoogleBlesserClientMethods. +// GoogleBlesserClientStub embeds GoogleBlesserClientMethods and is a +// placeholder for additional management operations. type GoogleBlesserClientStub interface { GoogleBlesserClientMethods - rpc.UniversalServiceMethods } // GoogleBlesserClient returns a client stub for GoogleBlesser. @@ -179,7 +176,7 @@ type GoogleBlesserServerStubMethods GoogleBlesserServerMethods // GoogleBlesserServerStub adds universal methods to GoogleBlesserServerStubMethods. type GoogleBlesserServerStub interface { GoogleBlesserServerStubMethods - // Describe the GoogleBlesser interfaces. + // DescribeInterfaces the GoogleBlesser interfaces. Describe__() []rpc.InterfaceDesc } @@ -229,23 +226,131 @@ var descGoogleBlesser = rpc.InterfaceDesc{ { Name: "BlessGoogle", InArgs: []rpc.ArgDesc{ - {"idToken", ``}, // string + {Name: "idToken", Doc: ``}, // string + }, + OutArgs: []rpc.ArgDesc{ + {Name: "blessing", Doc: ``}, // security.Blessings + }, + Tags: []*vdl.Value{vdl.ValueOf(access.Tag("Read"))}, + }, + }, +} + +// K8sBlesserClientMethods is the client interface +// containing K8sBlesser methods. +// +// K8sBlesser returns a blessing giving the provided Kubernetes service accountop token. 
+type K8sBlesserClientMethods interface { + BlessK8s(_ *context.T, caCrt string, namespace string, token string, region string, _ ...rpc.CallOpt) (blessing security.Blessings, _ error) +} + +// K8sBlesserClientStub embeds K8sBlesserClientMethods and is a +// placeholder for additional management operations. +type K8sBlesserClientStub interface { + K8sBlesserClientMethods +} + +// K8sBlesserClient returns a client stub for K8sBlesser. +func K8sBlesserClient(name string) K8sBlesserClientStub { + return implK8sBlesserClientStub{name} +} + +type implK8sBlesserClientStub struct { + name string +} + +func (c implK8sBlesserClientStub) BlessK8s(ctx *context.T, i0 string, i1 string, i2 string, i3 string, opts ...rpc.CallOpt) (o0 security.Blessings, err error) { + err = v23.GetClient(ctx).Call(ctx, c.name, "BlessK8s", []interface{}{i0, i1, i2, i3}, []interface{}{&o0}, opts...) + return +} + +// K8sBlesserServerMethods is the interface a server writer +// implements for K8sBlesser. +// +// K8sBlesser returns a blessing giving the provided Kubernetes service accountop token. +type K8sBlesserServerMethods interface { + BlessK8s(_ *context.T, _ rpc.ServerCall, caCrt string, namespace string, token string, region string) (blessing security.Blessings, _ error) +} + +// K8sBlesserServerStubMethods is the server interface containing +// K8sBlesser methods, as expected by rpc.Server. +// There is no difference between this interface and K8sBlesserServerMethods +// since there are no streaming methods. +type K8sBlesserServerStubMethods K8sBlesserServerMethods + +// K8sBlesserServerStub adds universal methods to K8sBlesserServerStubMethods. +type K8sBlesserServerStub interface { + K8sBlesserServerStubMethods + // DescribeInterfaces the K8sBlesser interfaces. + Describe__() []rpc.InterfaceDesc +} + +// K8sBlesserServer returns a server stub for K8sBlesser. +// It converts an implementation of K8sBlesserServerMethods into +// an object that may be used by rpc.Server. 
+func K8sBlesserServer(impl K8sBlesserServerMethods) K8sBlesserServerStub { + stub := implK8sBlesserServerStub{ + impl: impl, + } + // Initialize GlobState; always check the stub itself first, to handle the + // case where the user has the Glob method defined in their VDL source. + if gs := rpc.NewGlobState(stub); gs != nil { + stub.gs = gs + } else if gs := rpc.NewGlobState(impl); gs != nil { + stub.gs = gs + } + return stub +} + +type implK8sBlesserServerStub struct { + impl K8sBlesserServerMethods + gs *rpc.GlobState +} + +func (s implK8sBlesserServerStub) BlessK8s(ctx *context.T, call rpc.ServerCall, i0 string, i1 string, i2 string, i3 string) (security.Blessings, error) { + return s.impl.BlessK8s(ctx, call, i0, i1, i2, i3) +} + +func (s implK8sBlesserServerStub) Globber() *rpc.GlobState { + return s.gs +} + +func (s implK8sBlesserServerStub) Describe__() []rpc.InterfaceDesc { + return []rpc.InterfaceDesc{K8sBlesserDesc} +} + +// K8sBlesserDesc describes the K8sBlesser interface. +var K8sBlesserDesc rpc.InterfaceDesc = descK8sBlesser + +// descK8sBlesser hides the desc to keep godoc clean. +var descK8sBlesser = rpc.InterfaceDesc{ + Name: "K8sBlesser", + PkgPath: "github.com/grailbio/base/security/identity", + Doc: "// K8sBlesser returns a blessing giving the provided Kubernetes service accountop token.", + Methods: []rpc.MethodDesc{ + { + Name: "BlessK8s", + InArgs: []rpc.ArgDesc{ + {Name: "caCrt", Doc: ``}, // string + {Name: "namespace", Doc: ``}, // string + {Name: "token", Doc: ``}, // string + {Name: "region", Doc: ``}, // string }, OutArgs: []rpc.ArgDesc{ - {"blessing", ``}, // security.Blessings + {Name: "blessing", Doc: ``}, // security.Blessings }, Tags: []*vdl.Value{vdl.ValueOf(access.Tag("Read"))}, }, }, } -var __VDLInitCalled bool +var initializeVDLCalled bool -// __VDLInit performs vdl initialization. It is safe to call multiple times. +// initializeVDL performs vdl initialization. It is safe to call multiple times. 
// If you have an init ordering issue, just insert the following line verbatim // into your source files in this package, right after the "package foo" clause: // -// var _ = __VDLInit() +// var _ = initializeVDL() // // The purpose of this function is to ensure that vdl initialization occurs in // the right order, and very early in the init sequence. In particular, vdl @@ -254,11 +359,11 @@ var __VDLInitCalled bool // // This function returns a dummy value, so that it can be used to initialize the // first var in the file, to take advantage of Go's defined init order. -func __VDLInit() struct{} { - if __VDLInitCalled { +func initializeVDL() struct{} { + if initializeVDLCalled { return struct{}{} } - __VDLInitCalled = true + initializeVDLCalled = true return struct{}{} } diff --git a/security/keycrypt/keychain/keychain_darwin.go b/security/keycrypt/keychain/keychain_darwin.go index 2dae8c41..58327d9b 100644 --- a/security/keycrypt/keychain/keychain_darwin.go +++ b/security/keycrypt/keychain/keychain_darwin.go @@ -2,7 +2,7 @@ // Use of this source code is governed by the Apache-2.0 // license that can be found in the LICENSE file. -// +build darwin +// +build darwin,cgo // Secrets are stored directly into the macOS Keychain under the name // com.grail.keycrypt.$namespace; Keycrypt names are stored into the @@ -11,7 +11,6 @@ package keychain import ( "github.com/grailbio/base/security/keycrypt" - keychain "github.com/keybase/go-keychain" ) diff --git a/security/keycrypt/keychain/keychain_fallback.go b/security/keycrypt/keychain/keychain_fallback.go index 6d86d2b5..0639ba03 100644 --- a/security/keycrypt/keychain/keychain_fallback.go +++ b/security/keycrypt/keychain/keychain_fallback.go @@ -2,6 +2,6 @@ // Use of this source code is governed by the Apache-2.0 // license that can be found in the LICENSE file. 
-// +build !darwin +// +build !darwin !cgo package keychain diff --git a/security/ssh/certificateauthority/ssh.go b/security/ssh/certificateauthority/ssh.go new file mode 100644 index 00000000..429ce984 --- /dev/null +++ b/security/ssh/certificateauthority/ssh.go @@ -0,0 +1,172 @@ +// Copyright 2018 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache-2.0 +// license that can be found in the LICENSE file. + +// Package certificateauthority implements an x509 certificate authority. +package certificateauthority + +import ( + "crypto/rand" + + "golang.org/x/crypto/ssh" + + "github.com/grailbio/base/security/keycrypt" + + "errors" + "time" +) + +// CertificateAuthority is a ssh certificate authority. +type CertificateAuthority struct { + // The amount of allowable clock drift between the systems between + // which certificates are exchanged. + DriftMargin time.Duration + + // The keycrypt secret that contains the PEM-encoded private key. + PrivateKey keycrypt.Secret + + // Contains the PEM-encoded Certificate. + Certificate string + + // The ssh certificate signer. Populated by Init(). 
+ Signer ssh.Signer +} + +type CertificateRequest struct { + // SSH Public Key that is being signed + SshPublicKey []byte + + // List of host names, or usernames that will be added to the cert + Principals []string + + // How long this certificate should be valid for + Ttl time.Duration + + // What identifier should be included in the request + // This value will be used in logging + KeyID string + + CertType string // either "user" or "host" + + CriticalOptions []string + + // Extensions to assign to the ssh Certificate + // The default allow basic function - permit-pty is usually required + // map[string]string{ + // "permit-X11-forwarding": "", + // "permit-agent-forwarding": "", + // "permit-port-forwarding": "", + // "permit-pty": "", + // "permit-user-rc": "", + // } + Extensions []string +} + +const sshSignAlg = ssh.SigAlgoRSASHA2256 + +func validateCertType(certType string) (uint32, error) { + switch certType { + case "user": + return ssh.UserCert, nil + case "host": + return ssh.HostCert, nil + } + return 0, errors.New("CertType must be either 'user' or 'host'") +} + +// Init initializes the certificate authority. Init extracts the +// authority certificate and private key from ca.Signer. 
+func (ca *CertificateAuthority) Init() error { + pkPemBlock, err := ca.PrivateKey.Get() + if err != nil { + return err + } + + // Load the private key + privateSigner, err := ssh.ParsePrivateKey(pkPemBlock) + if err != nil { + return err + } + + // Load the Certificate + certificate, _, _, _, err := ssh.ParseAuthorizedKey([]byte(ca.Certificate)) + if err != nil { + return err + } + // Link the private key with its matching Authority Certificate + ca.Signer, err = ssh.NewCertSigner(certificate.(*ssh.Certificate), privateSigner) + if err != nil { + return err + } + + ca.Signer = privateSigner + + return nil +} + +func (ca CertificateAuthority) IssueWithKeyUsage(cr CertificateRequest) (string, error) { + return ca.issueWithKeyUsage(time.Now(), cr) +} + +func (ca CertificateAuthority) issueWithKeyUsage(now time.Time, cr CertificateRequest) (string, error) { + + // Load the Certificate + pubKey, _, _, _, err := ssh.ParseAuthorizedKey(cr.SshPublicKey) + if err != nil { + return "", err + } + + now = now.Add(-ca.DriftMargin) + + certType, err := validateCertType(cr.CertType) + if err != nil { + return "", err + } + + certificate := &ssh.Certificate{ + Serial: 0, + Key: pubKey, + KeyId: cr.KeyID, // Descriptive name of the key (shown in logs) + ValidPrincipals: cr.Principals, // hostnames (for host cert), or usernames (for client cert) + ValidAfter: uint64(now.In(time.UTC).Unix()), + ValidBefore: uint64(now.Add(ca.DriftMargin + cr.Ttl).In(time.UTC).Unix()), + CertType: certType, // int representing a "user" or "host" type + Permissions: ssh.Permissions{ + CriticalOptions: convertArrayToMap(cr.CriticalOptions), + Extensions: convertArrayToMap(cr.Extensions), + }, + } + + // Replicate the certificate.SignCert functions but with a custom algorithm + certificate.Nonce = make([]byte, 32) + if _, err = rand.Read(certificate.Nonce); err != nil { + return "", err + } + certificate.SignatureKey = ca.Signer.PublicKey() + + // based on Certificate.bytesForSigning() + 
certificateBytes := certificate.Marshal() + // Drop trailing signature length + certificateBytes = certificateBytes[:len(certificateBytes)-4] + + certificate.Signature, err = ca.Signer.(ssh.AlgorithmSigner).SignWithAlgorithm(rand.Reader, certificateBytes, sshSignAlg) + if err != nil { + return "", err + } + + return string(ssh.MarshalAuthorizedKey(certificate)), err +} + +// Convert an array of strings into a map of string value pairs +// Value of the key is set to "" which is what the SSH Library wants for extensions and CriticalOptions as flags +func convertArrayToMap(initial []string) map[string]string { + if initial == nil { + return nil + } + + results := map[string]string{} + for _, key := range initial { + results[key] = "" + } + return results +} diff --git a/security/ssh/certificateauthority/ssh_test.go b/security/ssh/certificateauthority/ssh_test.go new file mode 100644 index 00000000..0f5fd6db --- /dev/null +++ b/security/ssh/certificateauthority/ssh_test.go @@ -0,0 +1,99 @@ +// Copyright 2018 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache-2.0 +// license that can be found in the LICENSE file. 
+ +package certificateauthority + +import ( + "io/ioutil" + "testing" + + "bytes" + "os/exec" + "time" + + "github.com/grailbio/base/security/keycrypt" + "github.com/grailbio/testutil/assert" +) + +//ssh-keygen -s testdata/ssh_key.pem -I CA -O clear testdata/ssh_key.pem.pub +func TestAuthority(t *testing.T) { + sshPEM, err := ioutil.ReadFile("testdata/ssh_key.pem") + assert.NoError(t, err) + sshCERT, err := ioutil.ReadFile("testdata/ssh_key.pem-cert.pub") + assert.NoError(t, err) + userPubKey, err := ioutil.ReadFile("testdata/user_ssh_key.pem.pub") + assert.NoError(t, err) + + ca := CertificateAuthority{PrivateKey: keycrypt.Static(sshPEM), Certificate: string(sshCERT)} + err = ca.Init() + assert.NoError(t, err) + + cr := CertificateRequest{ + // SSH Public Key that is being signed + SshPublicKey: []byte(userPubKey), + + // List of host names, or usernames that will be added to the cert + Principals: []string{"ubuntu"}, + Ttl: time.Duration(3600) * time.Second, + KeyID: "foo", + + CertType: "user", + + CriticalOptions: nil, + + // Extensions to assign to the ssh Certificate + // The default allow basic function - permit-pty is usually required + Extensions: []string{ + "permit-X11-forwarding", + "permit-agent-forwarding", + "permit-port-forwarding", + "permit-pty", + "permit-user-rc", + }, + } + execTime := time.Date(2020, time.January, 19, 0, 0, 0, 0, time.UTC) + sshCert, err := ca.issueWithKeyUsage(execTime, cr) + assert.NoError(t, err) + + // Check the golang created certificate against the one created with + // ssh-keygen -t rsa-sha2-256 -s testdata/ssh_key.pem -I foo -V 20200118160000:20200118170000 -n ubuntu testdata/user_ssh_key.pem + preCreatedUserCert, err := ioutil.ReadFile("testdata/user_ssh_key.pem-cert.pub") + assert.NoError(t, err) + + cmd := exec.Command("ssh-keygen", "-L", "-f", "-") + cmd.Stdin = bytes.NewBuffer([]byte(preCreatedUserCert)) + preCreatedOutput, err := cmd.Output() + assert.NoError(t, err) + + cmd = exec.Command("ssh-keygen", "-L", 
"-f", "-") + cmd.Stdin = bytes.NewBuffer([]byte(sshCert)) + output, err := cmd.Output() + assert.NoError(t, err) + + if string(preCreatedOutput) != string(output) { + t.Errorf("IssueWithKeyUsage: got %q, want %q", output, preCreatedOutput) + } +} + +func TestValidateCertType(t *testing.T) { + cases := []struct { + input string + expected uint32 + expectErr bool + }{ + {"user", 1, false}, + {"host", 2, false}, + {"FOO", 0, true}, + } + for _, c := range cases { + got, err := validateCertType(c.input) + if got != c.expected { + t.Errorf("TestValidateCertType(%q): got %q, want %q", c.input, got, c.expected) + } + errCheck := (err != nil) + if errCheck != c.expectErr { + t.Errorf("TestValidateCertType(%q): got err %q, want %t", c.input, err, c.expectErr) + } + } +} diff --git a/security/ssh/certificateauthority/testdata/ssh_key.pem b/security/ssh/certificateauthority/testdata/ssh_key.pem new file mode 100644 index 00000000..c1fa3c99 --- /dev/null +++ b/security/ssh/certificateauthority/testdata/ssh_key.pem @@ -0,0 +1,27 @@ +-----BEGIN OPENSSH PRIVATE KEY----- +b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAABFwAAAAdzc2gtcn +NhAAAAAwEAAQAAAQEA1Z4EOQZZqhxfNIzB1ogM8AZuNRc/qLiCP61GAUtUviTrrWlcTbPG +o3jzBs6DbvoeP6Djr8uCSGW4Aqq1N67DMeBn5iAhrF1SQjNgEyvMITZ4voye+OS5s5tsJF +FwczzQPiSv+ha3Qn/0+ej/IQMibnQvbyVpBxsgTwlbfNQc93zxb9YyZVW12D5xKyAmIHeG +trujpAU6vgnI+nxEZe9udVdgvHlYDWVNIbOsr25X9UtEC1rBS3Mebc7vFgE7XV2Ih4ydKa +G3oub26tCqRJbo4c/wmQ2U/jrUJfuWTEVnFR2JO/vm0PO+NKak2p/Jits1Yb2Nr3MP8C65 +MUV3uC6aoQAAA9AXn2NDF59jQwAAAAdzc2gtcnNhAAABAQDVngQ5BlmqHF80jMHWiAzwBm +41Fz+ouII/rUYBS1S+JOutaVxNs8ajePMGzoNu+h4/oOOvy4JIZbgCqrU3rsMx4GfmICGs +XVJCM2ATK8whNni+jJ745Lmzm2wkUXBzPNA+JK/6FrdCf/T56P8hAyJudC9vJWkHGyBPCV +t81Bz3fPFv1jJlVbXYPnErICYgd4a2u6OkBTq+Ccj6fERl7251V2C8eVgNZU0hs6yvblf1 +S0QLWsFLcx5tzu8WATtdXYiHjJ0pobei5vbq0KpElujhz/CZDZT+OtQl+5ZMRWcVHYk7++ +bQ8740pqTan8mK2zVhvY2vcw/wLrkxRXe4LpqhAAAAAwEAAQAAAQBCNcUHS8GU6VBVAF/A 
+N9ESwFt+VyNjDzVHuVQeaJPuj5EA4cw7RUKHPqDM9ktkpV+kxyBrR3+tIuIC6Zhblu5nUi +/B8ymcvBwX6saXipatoK2SGhHYAUTRt5WwOBHLlgjRihuFP28zzGdI4n/ZYphUdeyR3Z2N +F0pPVQ4nRbNovZ6yAPgvmlmQsCI/2VQ3VweBEXeVReh3M/TwMoyzApXKeG/+9bIm5irlJp +I11MtHpnIf6leEXivYbjeJQ6sI3s/9l18GQyQayy01x7YsRTIezjzYoeB7dYELSwthxezX +XNG/JmeLaj690KPAR4vMz+w9lYLepD2TLpGrATSfGp71AAAAgGsO5TfxVRBfQZHGdE/lAz +K/L7Q0U2Jeshl0iuuGosEdpaOgNjgAnqUyYkRfQJR8RKJKv72d2/kuEF63vEQWZ1uFiQlw +j/2ixIMihakPGGZo1Nhd6UDhfyQuzx2YC4VEfQ+0JvfHDO0O5B+P8Yb1Z8MIoU8ACtzRFR +xsiDz3hkdmAAAAgQD++WJckx3uUNORTIHS/H0rXZnnv/itLaPMrV1KNiD9crDJENTgI/UK +O/20qoRwHaE8rwuFPCqEqeATWTT8qr7JjpzLPyx7EJTo+6EjALVEOW/ZWrQb0xpXuwOmTI +J9XBdAGrcZSjrw7y2yIo6j87kfC2lrC4ND5cYARaTurSwCKwAAAIEA1noJMDA6T3l7jHox +PAGLSmVutBSTe0NwyMqzkLHM5NEEp1Spe59HofGASgH6KWvwbj65lXrgmZTSt3+b2++H/L +VRgkVfSmKQqjjGMe6bQHc4VV0+DFoY+yQnBoj9MWqQYqneRUR1jS+AQ2Cza0d2Qq4SfpAi +6T+g9TE1ZC0PTGMAAAAWYWVpc2VyQEdCLUMwMllSNE05TFZDRwECAwQF +-----END OPENSSH PRIVATE KEY----- diff --git a/security/ssh/certificateauthority/testdata/ssh_key.pem-cert.pub b/security/ssh/certificateauthority/testdata/ssh_key.pem-cert.pub new file mode 100644 index 00000000..d48dbd5a --- /dev/null +++ b/security/ssh/certificateauthority/testdata/ssh_key.pem-cert.pub @@ -0,0 +1 @@ +ssh-rsa-cert-v01@openssh.com 
AAAAHHNzaC1yc2EtY2VydC12MDFAb3BlbnNzaC5jb20AAAAguGFSceeSbkfzhN5OpCA451FEP7pbmUgT91ZFC/OuyDIAAAADAQABAAABAQDVngQ5BlmqHF80jMHWiAzwBm41Fz+ouII/rUYBS1S+JOutaVxNs8ajePMGzoNu+h4/oOOvy4JIZbgCqrU3rsMx4GfmICGsXVJCM2ATK8whNni+jJ745Lmzm2wkUXBzPNA+JK/6FrdCf/T56P8hAyJudC9vJWkHGyBPCVt81Bz3fPFv1jJlVbXYPnErICYgd4a2u6OkBTq+Ccj6fERl7251V2C8eVgNZU0hs6yvblf1S0QLWsFLcx5tzu8WATtdXYiHjJ0pobei5vbq0KpElujhz/CZDZT+OtQl+5ZMRWcVHYk7++bQ8740pqTan8mK2zVhvY2vcw/wLrkxRXe4LpqhAAAAAAAAAAAAAAABAAAAAkNBAAAAAAAAAAAAAAAA//////////8AAAAAAAAAAAAAAAAAAAEXAAAAB3NzaC1yc2EAAAADAQABAAABAQDVngQ5BlmqHF80jMHWiAzwBm41Fz+ouII/rUYBS1S+JOutaVxNs8ajePMGzoNu+h4/oOOvy4JIZbgCqrU3rsMx4GfmICGsXVJCM2ATK8whNni+jJ745Lmzm2wkUXBzPNA+JK/6FrdCf/T56P8hAyJudC9vJWkHGyBPCVt81Bz3fPFv1jJlVbXYPnErICYgd4a2u6OkBTq+Ccj6fERl7251V2C8eVgNZU0hs6yvblf1S0QLWsFLcx5tzu8WATtdXYiHjJ0pobei5vbq0KpElujhz/CZDZT+OtQl+5ZMRWcVHYk7++bQ8740pqTan8mK2zVhvY2vcw/wLrkxRXe4LpqhAAABDwAAAAdzc2gtcnNhAAABAHnqiyQ536XkZLITz3ajE5FG/9RCVGqqsiLAJVWY0SKlUwX1O5JE84qzuA45tBhn+3akUMmgyqr4Xq1ObX6LDpvjEvGDCsDo227Jtp9M4m6TpaUxHK7qo4PDi9alYtPfoMsDgT7ZQ0ph2tK/++BTQGJuKr3MDtgPBsb5cw60Z2b3bz+I6LqHN/a0wgv24xxhuOlcRfYdxYnRyTOhqkAdst2cRmUbKU57in8JIZe3zFu5/hes9M9A8Hb0eoMBRHAUZ6fy/NxoaRAaWcj1bWq5hVlDOGZq6qzJGUQ0tkJLDUJOa2qPHQ3r5WsBvfla1BtNWnGrM7cRoQwCWM+OYAl2Z/Q= aeiser@GB-C02YR4M9LVCG \ No newline at end of file diff --git a/security/ssh/certificateauthority/testdata/ssh_key.pem.pub b/security/ssh/certificateauthority/testdata/ssh_key.pem.pub new file mode 100644 index 00000000..6ccf0e6f --- /dev/null +++ b/security/ssh/certificateauthority/testdata/ssh_key.pem.pub @@ -0,0 +1 @@ +ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDVngQ5BlmqHF80jMHWiAzwBm41Fz+ouII/rUYBS1S+JOutaVxNs8ajePMGzoNu+h4/oOOvy4JIZbgCqrU3rsMx4GfmICGsXVJCM2ATK8whNni+jJ745Lmzm2wkUXBzPNA+JK/6FrdCf/T56P8hAyJudC9vJWkHGyBPCVt81Bz3fPFv1jJlVbXYPnErICYgd4a2u6OkBTq+Ccj6fERl7251V2C8eVgNZU0hs6yvblf1S0QLWsFLcx5tzu8WATtdXYiHjJ0pobei5vbq0KpElujhz/CZDZT+OtQl+5ZMRWcVHYk7++bQ8740pqTan8mK2zVhvY2vcw/wLrkxRXe4Lpqh aeiser@GB-C02YR4M9LVCG diff --git 
a/security/ssh/certificateauthority/testdata/user_ssh_key.pem b/security/ssh/certificateauthority/testdata/user_ssh_key.pem new file mode 100644 index 00000000..62e1adaa --- /dev/null +++ b/security/ssh/certificateauthority/testdata/user_ssh_key.pem @@ -0,0 +1,27 @@ +-----BEGIN OPENSSH PRIVATE KEY----- +b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAABFwAAAAdzc2gtcn +NhAAAAAwEAAQAAAQEA3RST+BXApGV/KGmflZv4l19y6rsduYLKKxE0fltrIPYxOqsHqijP +od1cjEH1q4BeG5pvMcgzlCgVHM3gpyy0eo2M4ADzp3DYBhnhfgqPj/7Sxek05c7i40pR50 +kbXExcOyc3Kq+cza2IY0DUJmhXY4JiPoxipWkNv6jqfm5NdkWtfz+VfKqxAJjC5MfMOd/2 +aSmXnKG1BiVcwtbxzfxSAZeiAmDFvhpJ9FgvDSxbbMbZZIbHz8L4KP44kPDNJmCZg2WNts +qBhuWyrh/6yyrJqHRbTtLFlGckdRELsiYbOR5ja2+6cTh2d5Q5WT1v0i3t31QajGBx6pNO +0bKZMkSzIQAAA9BphDyWaYQ8lgAAAAdzc2gtcnNhAAABAQDdFJP4FcCkZX8oaZ+Vm/iXX3 +Lqux25gsorETR+W2sg9jE6qweqKM+h3VyMQfWrgF4bmm8xyDOUKBUczeCnLLR6jYzgAPOn +cNgGGeF+Co+P/tLF6TTlzuLjSlHnSRtcTFw7Jzcqr5zNrYhjQNQmaFdjgmI+jGKlaQ2/qO +p+bk12Ra1/P5V8qrEAmMLkx8w53/ZpKZecobUGJVzC1vHN/FIBl6ICYMW+Gkn0WC8NLFts +xtlkhsfPwvgo/jiQ8M0mYJmDZY22yoGG5bKuH/rLKsmodFtO0sWUZyR1EQuyJhs5HmNrb7 +pxOHZ3lDlZPW/SLe3fVBqMYHHqk07RspkyRLMhAAAAAwEAAQAAAQByRdeyDPRVRU0zw1zE +hSk6fRC2Od/EatE675q1kWVPVVHe5FaC4rNoFDZpHRLyAdki5XGCRtw6QXmgON5dKuNi0V +W212cZ7l8K0EfY0XahVHL373HzMzvdhiXNqeEllSa7QKroOnuPaJoty22dKKO7AMLtV70J +iMKdhZ8nmLYbYjS0jMFmTFb2SBRUWZBycxVCsCdemwaRrY57zwJpJOiqv+O+AKy9BjDINa +vOFAcCPDXtPmHsvOBvzPiGyomKxcz4+4OoCFWPLcV777UEwCwYeRcj+iemwKulQUd22lN+ +BIMk0gwPUaOGvLjkR98HX5kw2E5B7FLJp5lIgrc9WZABAAAAgCtAFfzi85ZxeZacTmXIu6 +1+qYJ0qz24XxxIgwLCSgLa23O8HzOTJ3bYWMbcTt2IkJfJlk1nityePdbZrzlzEX5f2s8u +rWOlkB4VBXLHz5csBPG/+KXiyjc/4BNmgn6Pu5w1E30+wEn7ogJfzLsTF447AciYCXEM7D +QOIeCw5XzHAAAAgQDxP8cYfI9yRyaE06GRT1ZOAQuY8W6Fxcmc8AuEb8zlsH0ls50ml5M7 +zCBifLi/+8py7L2y4qlfuvrtztKqAHtf1cvYNzcnFB0NBAHl75F9fP4RPvKPDdWfr6GRN1 +MqOxlgIdReC63pXgC1+KDjNRd97ct7Se/yYu6ws0OXi+croQAAAIEA6pkaYu9sD0h/AH7K +OQq4RdcGABGTBBwPa6A9pw7yFTFGoYkE+IBiFkNQjH0bP3AlOwaVet1BCP0yKrXC38QG8o 
+s78xfI/1PP2BiMotLmJ8bphEwYzdgV4UG6rqVnHslJElyS8W/YDSdb0ciEZtVWswxMKey+ +d8qcqv7aI550V4EAAAAWYWVpc2VyQEdCLUMwMllSNE05TFZDRwECAwQF +-----END OPENSSH PRIVATE KEY----- diff --git a/security/ssh/certificateauthority/testdata/user_ssh_key.pem-cert.pub b/security/ssh/certificateauthority/testdata/user_ssh_key.pem-cert.pub new file mode 100644 index 00000000..c5f007a5 --- /dev/null +++ b/security/ssh/certificateauthority/testdata/user_ssh_key.pem-cert.pub @@ -0,0 +1 @@ +ssh-rsa-cert-v01@openssh.com AAAAHHNzaC1yc2EtY2VydC12MDFAb3BlbnNzaC5jb20AAAAg6R0/RgZA09Nd+5yUloPMeQl1F343thf6RT24aho4geMAAAADAQABAAABAQDdFJP4FcCkZX8oaZ+Vm/iXX3Lqux25gsorETR+W2sg9jE6qweqKM+h3VyMQfWrgF4bmm8xyDOUKBUczeCnLLR6jYzgAPOncNgGGeF+Co+P/tLF6TTlzuLjSlHnSRtcTFw7Jzcqr5zNrYhjQNQmaFdjgmI+jGKlaQ2/qOp+bk12Ra1/P5V8qrEAmMLkx8w53/ZpKZecobUGJVzC1vHN/FIBl6ICYMW+Gkn0WC8NLFtsxtlkhsfPwvgo/jiQ8M0mYJmDZY22yoGG5bKuH/rLKsmodFtO0sWUZyR1EQuyJhs5HmNrb7pxOHZ3lDlZPW/SLe3fVBqMYHHqk07RspkyRLMhAAAAAAAAAAAAAAABAAAAA2ZvbwAAAAoAAAAGdWJ1bnR1AAAAAF4jnAAAAAAAXiOqEAAAAAAAAACCAAAAFXBlcm1pdC1YMTEtZm9yd2FyZGluZwAAAAAAAAAXcGVybWl0LWFnZW50LWZvcndhcmRpbmcAAAAAAAAAFnBlcm1pdC1wb3J0LWZvcndhcmRpbmcAAAAAAAAACnBlcm1pdC1wdHkAAAAAAAAADnBlcm1pdC11c2VyLXJjAAAAAAAAAAAAAAEXAAAAB3NzaC1yc2EAAAADAQABAAABAQDVngQ5BlmqHF80jMHWiAzwBm41Fz+ouII/rUYBS1S+JOutaVxNs8ajePMGzoNu+h4/oOOvy4JIZbgCqrU3rsMx4GfmICGsXVJCM2ATK8whNni+jJ745Lmzm2wkUXBzPNA+JK/6FrdCf/T56P8hAyJudC9vJWkHGyBPCVt81Bz3fPFv1jJlVbXYPnErICYgd4a2u6OkBTq+Ccj6fERl7251V2C8eVgNZU0hs6yvblf1S0QLWsFLcx5tzu8WATtdXYiHjJ0pobei5vbq0KpElujhz/CZDZT+OtQl+5ZMRWcVHYk7++bQ8740pqTan8mK2zVhvY2vcw/wLrkxRXe4LpqhAAABFAAAAAxyc2Etc2hhMi0yNTYAAAEAc9S58z1U5k5XWaJ5xrJUexxqbpYoXuiLY3PjvenNsExIl3drAUS7WA7IR3RyiUrj+qiZpqT9vS+IfGzkjehrjX/Uk2l0/THsZXj+iLZDCqPU841eR2e6DvSdJbuqsDMiDCi0Z/Qj0a5Wd8uBqv2wg7tKwbvwt8JvXTM/D3s3XMxzvSCtKUE6iEPoTMS6hnXcgFqN243wGgckL/3OtVn0zzi8/FTPM7e22n+9gxZAKrE3xBodBcuzzZV868KXmwwcLYjWAeOReLVkCE2XYJEMQQXFpmmjO/dpu3kOE5Q2n6uMa3HVBfUldLmNnC6JRvjTPr4bzvYQFNMd47BjgZy8YA== aeiser@GB-C02YR4M9LVCG diff --git 
a/security/ssh/certificateauthority/testdata/user_ssh_key.pem.pub b/security/ssh/certificateauthority/testdata/user_ssh_key.pem.pub new file mode 100644 index 00000000..20a44892 --- /dev/null +++ b/security/ssh/certificateauthority/testdata/user_ssh_key.pem.pub @@ -0,0 +1 @@ +ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDdFJP4FcCkZX8oaZ+Vm/iXX3Lqux25gsorETR+W2sg9jE6qweqKM+h3VyMQfWrgF4bmm8xyDOUKBUczeCnLLR6jYzgAPOncNgGGeF+Co+P/tLF6TTlzuLjSlHnSRtcTFw7Jzcqr5zNrYhjQNQmaFdjgmI+jGKlaQ2/qOp+bk12Ra1/P5V8qrEAmMLkx8w53/ZpKZecobUGJVzC1vHN/FIBl6ICYMW+Gkn0WC8NLFtsxtlkhsfPwvgo/jiQ8M0mYJmDZY22yoGG5bKuH/rLKsmodFtO0sWUZyR1EQuyJhs5HmNrb7pxOHZ3lDlZPW/SLe3fVBqMYHHqk07RspkyRLMh aeiser@GB-C02YR4M9LVCG diff --git a/security/ticket/aws.go b/security/ticket/aws.go index eee7841c..527adeec 100644 --- a/security/ticket/aws.go +++ b/security/ticket/aws.go @@ -5,16 +5,21 @@ package ticket import ( + "errors" "strings" "time" "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/client" "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/credentials/stscreds" "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/ec2" "github.com/aws/aws-sdk-go/service/ecr" "github.com/aws/aws-sdk-go/service/sts" + "github.com/grailbio/base/cloud/ec2util" + "github.com/grailbio/base/common/log" "github.com/grailbio/base/ttlcache" - "v.io/x/lib/vlog" ) type cacheKey struct { @@ -57,24 +62,24 @@ func (b *AwsAssumeRoleBuilder) newS3Ticket(ctx *TicketContext) (TicketS3Ticket, } func (b *AwsAssumeRoleBuilder) newEcrTicket(ctx *TicketContext) (TicketEcrTicket, error) { + log.Debug(ctx.ctx, "generating ECR ticket", "AwsAssumeRoleBuilder", b) awsCredentials, err := b.genAwsCredentials(ctx) if err != nil { return TicketEcrTicket{}, err } - return TicketEcrTicket{ - Value: newEcrTicket(awsCredentials), + Value: newEcrTicket(ctx, awsCredentials), }, nil } func (b *AwsAssumeRoleBuilder) genAwsCredentials(ctx *TicketContext) (AwsCredentials, error) { - 
vlog.Infof("AwsAssumeRoleBuilder: %+v", b) + log.Debug(ctx.ctx, "generating AWS credentials", "AwsAssumeRoleBuilder", b) empty := AwsCredentials{} sessionName := strings.Replace(ctx.remoteBlessings.String(), ":", ",", -1) // AWS session names must be 64 characters or less - if runes := []rune(sessionName); len(runes) > 64 { + if runes := []rune(sessionName); len(runes) > 64 { // Some risk with simple truncation - two large IAM role's would overlap // for example. This is mitigated by the format which includes instance id // as the last component. Ability to determine exactly which instance made @@ -84,10 +89,10 @@ func (b *AwsAssumeRoleBuilder) genAwsCredentials(ctx *TicketContext) (AwsCredent } key := cacheKey{b.Region, b.Role, sessionName} if v, ok := cache.Get(key); ok { - vlog.VI(1).Infof("cache hit for %+v", key) + log.Debug(ctx.ctx, "AWS credentials lookup cache hit", "key", key) return v.(AwsCredentials), nil } - vlog.VI(1).Infof("cache miss for %+v", key) + log.Debug(ctx.ctx, "AWS credentials lookup cache miss", "key", key) s := ctx.session if aws.StringValue(s.Config.Region) != b.Region { @@ -95,6 +100,7 @@ func (b *AwsAssumeRoleBuilder) genAwsCredentials(ctx *TicketContext) (AwsCredent var err error s, err = session.NewSession(s.Config.WithRegion(b.Region)) if err != nil { + log.Error(ctx.ctx, "error creating AWS session", "err", err.Error()) return empty, err } } @@ -114,6 +120,7 @@ func (b *AwsAssumeRoleBuilder) genAwsCredentials(ctx *TicketContext) (AwsCredent assumeRoleOutput, err := client.AssumeRole(assumeRoleInput) if err != nil { + log.Error(ctx.ctx, "error in AssumeRole API call", "key", key) return empty, err } @@ -125,7 +132,7 @@ func (b *AwsAssumeRoleBuilder) genAwsCredentials(ctx *TicketContext) (AwsCredent Expiration: assumeRoleOutput.Credentials.Expiration.Format(time.RFC3339Nano), } - vlog.VI(1).Infof("add to cache %+v", key) + log.Debug(ctx.ctx, "adding AWS credentials to cache", "key", key) cache.Set(key, result) return result, nil @@ 
-160,13 +167,13 @@ func (b *AwsSessionBuilder) newS3Ticket(ctx *TicketContext) (TicketS3Ticket, err } func (b *AwsSessionBuilder) genAwsSession(ctx *TicketContext) (AwsCredentials, error) { - vlog.Infof("AwsSessionBuilder: %s", b.AwsCredentials.AccessKeyId) + log.Debug(ctx.ctx, "enerating AWS session", "AwsAssumeRoleBuilder", b.AwsCredentials.AccessKeyId) empty := AwsCredentials{} awsCredentials := b.AwsCredentials sessionName := strings.Replace(ctx.remoteBlessings.String(), ":", ",", -1) // AWS session names must be 64 characters or less - if runes := []rune(sessionName); len(runes) > 64 { + if runes := []rune(sessionName); len(runes) > 64 { // Some risk with simple truncation - two large IAM role's would overlap // for example. This is mitigated by the format which includes instance id // as the last component. Ability to determine exactly which instance made @@ -176,10 +183,10 @@ func (b *AwsSessionBuilder) genAwsSession(ctx *TicketContext) (AwsCredentials, e } key := cacheKey{awsCredentials.Region, awsCredentials.AccessKeyId, sessionName} if v, ok := cache.Get(key); ok { - vlog.VI(1).Infof("cache hit for %+v", key) + log.Debug(ctx.ctx, "AWS session lookup cache hit", "key", key) return v.(AwsCredentials), nil } - vlog.VI(1).Infof("cache miss for %+v", key) + log.Debug(ctx.ctx, "AWS session lookup cache miss", "key", key) s, err := session.NewSession(&aws.Config{ Region: aws.String(awsCredentials.Region), Credentials: credentials.NewStaticCredentials( @@ -209,13 +216,13 @@ func (b *AwsSessionBuilder) genAwsSession(ctx *TicketContext) (AwsCredentials, e Expiration: sessionTokenOutput.Credentials.Expiration.Format(time.RFC3339Nano), } - vlog.VI(1).Infof("add to cache %+v", key) + log.Debug(ctx.ctx, "Adding AWS session to cache", "key", key) cache.Set(key, result) return result, nil } -func newEcrTicket(awsCredentials AwsCredentials) EcrTicket { +func newEcrTicket(ctx *TicketContext, awsCredentials AwsCredentials) EcrTicket { empty := EcrTicket{} s, err := 
session.NewSession(&aws.Config{ Region: aws.String(awsCredentials.Region), @@ -224,18 +231,22 @@ func newEcrTicket(awsCredentials AwsCredentials) EcrTicket { awsCredentials.SecretAccessKey, awsCredentials.SessionToken), }) + if err != nil { + log.Error(ctx.ctx, "error creating AWS session", "err", err.Error()) + return empty + } r, err := ecr.New(s).GetAuthorizationToken(&ecr.GetAuthorizationTokenInput{}) if err != nil { - vlog.Error(err) + log.Error(ctx.ctx, "error fetching ECR authorization token", "err", err.Error()) return empty } if len(r.AuthorizationData) == 0 { - vlog.Errorf("no authorization data from ECR") + log.Error(ctx.ctx, "no authorization data from ECR") return empty } auth := r.AuthorizationData[0] if auth.AuthorizationToken == nil || auth.ProxyEndpoint == nil || auth.ExpiresAt == nil { - vlog.Errorf("bad authorization data from ECR") + log.Error(ctx.ctx, "bad authorization data from ECR") return empty } return EcrTicket{ @@ -244,3 +255,103 @@ func newEcrTicket(awsCredentials AwsCredentials) EcrTicket { Endpoint: *auth.ProxyEndpoint, } } + +// Returns a list of Compute Instances that match the filter +func AwsEc2InstanceLookup(ctx *TicketContext, builder *AwsComputeInstancesBuilder) ([]ComputeInstance, error) { + var instances []ComputeInstance + + if len(builder.InstanceFilters) == 0 { + return instances, errors.New("An instance filters is required") + } + + // Create the STS session with the provided lookup role + config := aws.Config{ + Region: aws.String(builder.Region), + Credentials: stscreds.NewCredentials(ctx.session, builder.AwsAccountLookupRole), + Retryer: client.DefaultRetryer{ + NumMaxRetries: 100, + }, + } + + s, err := session.NewSession(&config) + if err != nil { + log.Error(ctx.ctx, "error creating AWS session", "err", err.Error()) + return instances, err + } + + var filters []*ec2.Filter + filters = append(filters, + &ec2.Filter{ + Name: aws.String("instance-state-name"), + Values: []*string{ + aws.String("running"), + }, + }, + ) 
+ + for _, f := range builder.InstanceFilters { + filters = append(filters, + &ec2.Filter{ + Name: aws.String(f.Key), + Values: []*string{ + aws.String(f.Value), + }, + }, + ) + } + + output, err := ec2.New(s, &config).DescribeInstances(&ec2.DescribeInstancesInput{ + Filters: filters, + }) + if err != nil { + log.Error(ctx.ctx, "error describing EC2 instance", "err", err.Error()) + return instances, err + } + + for _, reservations := range output.Reservations { + for _, instance := range reservations.Instances { + var params []Parameter + publicIp, err := ec2util.GetPublicIPAddress(instance) + if err != nil { + log.Error(ctx.ctx, "error fetching EC2 public IP address. Continuing anyways.", "err", err.Error()) + continue // parse error skip + } + + privateIp, err := ec2util.GetPrivateIPAddress(instance) + if err != nil { + log.Error(ctx.ctx, "error fetching EC2 private IP address. Continuing anyways.", "err", err.Error()) + continue // parse error skip + } + + ec2Tags, err := ec2util.GetTags(instance) + if err != nil { + log.Error(ctx.ctx, "error fetching EC2 tags. Continuing anyways.", "err", err.Error()) + continue // parse error skip + } + for _, tag := range ec2Tags { + params = append(params, + Parameter{ + Key: *tag.Key, + Value: *tag.Value, + }) + } + + instanceId, err := ec2util.GetInstanceId(instance) + if err != nil { + log.Error(ctx.ctx, "error fetching EC2 instance ID. 
Continuing anyways.", "err", err.Error()) + continue // parse error skip + } + + instances = append(instances, + ComputeInstance{ + PublicIp: publicIp, + PrivateIp: privateIp, + InstanceId: instanceId, + Tags: params, + }) + } + } + + log.Debug(ctx.ctx, "AWS EC2 instances", "instances", instances) + return instances, nil +} diff --git a/security/ticket/b2.go b/security/ticket/b2.go index 72772648..397d4509 100644 --- a/security/ticket/b2.go +++ b/security/ticket/b2.go @@ -10,18 +10,18 @@ import ( "fmt" "net/http" + "github.com/grailbio/base/common/log" "github.com/grailbio/base/security/keycrypt" - "v.io/x/lib/vlog" ) const ( b2AuthorizeURL = "https://api.backblazeb2.com/b2api/v1/b2_authorize_account" ) -func (b *B2AccountAuthorizationBuilder) newB2Ticket() (TicketB2Ticket, error) { - vlog.Infof("B2AccountAuthorizationBuilder: %+v", b) +func (b *B2AccountAuthorizationBuilder) newB2Ticket(ctx *TicketContext) (TicketB2Ticket, error) { + log.Info(ctx.ctx, "Creating BackBlaze ticket.", "B2AccountAuthorizationBuilder", b) - b2Ticket, err := b.genB2Ticket() + b2Ticket, err := b.genB2Ticket(ctx) if err != nil { return TicketB2Ticket{}, err } @@ -31,7 +31,7 @@ func (b *B2AccountAuthorizationBuilder) newB2Ticket() (TicketB2Ticket, error) { }, nil } -func (b *B2AccountAuthorizationBuilder) genB2Ticket() (*B2Ticket, error) { +func (b *B2AccountAuthorizationBuilder) genB2Ticket(ctx *TicketContext) (*B2Ticket, error) { secret, err := keycrypt.Lookup(b.ApplicationKey) if err != nil { return nil, err @@ -46,14 +46,14 @@ func (b *B2AccountAuthorizationBuilder) genB2Ticket() (*B2Ticket, error) { req, err := http.NewRequest("GET", b2AuthorizeURL, nil) if err != nil { - vlog.Errorf("Cannot create new request: %s (%s)", b2AuthorizeURL, err) + log.Error(ctx.ctx, "Failed to create new request.", "b2AuthorizeURL", b2AuthorizeURL, "err", err.Error()) return nil, err } req.Header.Set("Authorization", headerForAuthorizeAccount) client := &http.Client{} resp, err := client.Do(req) if err != 
nil { - vlog.Errorf("Cannot authorize: %s (%s)", b2AuthorizeURL, err) + log.Error(ctx.ctx, "Failed to authorize.", "b2AuthorizeURL", b2AuthorizeURL, "err", err.Error()) return nil, err } defer resp.Body.Close() @@ -66,17 +66,17 @@ func (b *B2AccountAuthorizationBuilder) genB2Ticket() (*B2Ticket, error) { } var er ErrorResponse if err := json.NewDecoder(resp.Body).Decode(&er); err != nil { - vlog.Errorf("Cannot decode: %v (%s)", er, err) + log.Error(ctx.ctx, "Failed to decode response.", "errResponse", er, "err", err.Error()) return nil, err } - err := fmt.Errorf("Status %d: %s", er.Status, er.Message) - vlog.Error(err) + err := fmt.Errorf("status %d: %s", er.Status, er.Message) + log.Error(ctx.ctx, "Request failed.", "err", err.Error()) return nil, err } var b2Ticket B2Ticket if err := json.NewDecoder(resp.Body).Decode(&b2Ticket); err != nil { - vlog.Errorf("Cannot decode: %v (%s)", b2Ticket, err) + log.Error(ctx.ctx, "Failed to decode BackBlaze ticket.", "b2Ticket", b2Ticket, "err", err.Error()) return nil, err } return &b2Ticket, nil diff --git a/security/ticket/helper.go b/security/ticket/helper.go new file mode 100644 index 00000000..8a618b41 --- /dev/null +++ b/security/ticket/helper.go @@ -0,0 +1,250 @@ +package ticket + +import ( + "fmt" + "reflect" + "strings" + "v.io/v23/context" +) + +// An UnexpectedTicketType error is produced when a ticket cannot be cast to the expected type. +type UnexpectedTicketType struct { + Expected string + Actual string +} + +func (err UnexpectedTicketType) Error() string { + return fmt.Sprintf("ticket was a %q, not a %q", err.Actual, err.Expected) +} + +func expected(expected interface{}, actual interface{}) UnexpectedTicketType { + return UnexpectedTicketType{ + Expected: reflect.TypeOf(expected).Name(), + Actual: reflect.TypeOf(actual).Name(), + } +} + +// A Getter retrieves a ticket value for the key. +// +// Users of this package should use the default Client. 
+// This type exists primarily for unit tests which do not rely on the ticket-server. +type Getter func(ctx *context.T, key string) (Ticket, error) + +/* +Client is the default Getter which uses Vanadium to interact with the ticket-server. + +For example, to get a string value: + + myValue, err := ticket.Client.GetString(ctx, "ticket/path") +*/ +var Client Getter = func(ctx *context.T, key string) (Ticket, error) { + return TicketServiceClient(key).Get(ctx) +} + +func (g Getter) getTicket(ctx *context.T, path ...string) (Ticket, error) { + key := strings.Join(path, "/") + return g(ctx, key) +} + +// GetData for key from the ticket-server. It must be stored in a GenericTicket. +// Path components will be joined with a `/`. +func (g Getter) GetData(ctx *context.T, path ...string) (data []byte, err error) { + tick, err := g.getTicket(ctx, path...) + if err != nil { + return nil, err + } + + cast, ok := tick.(TicketGenericTicket) + if !ok { + return nil, expected(TicketGenericTicket{}, tick) + } + + return cast.Value.Data, nil +} + +// GetString for key from the ticket-server. It must be stored in a GenericTicket. +// Path components will be joined with a `/`. +func (g Getter) GetString(ctx *context.T, path ...string) (value string, err error) { + data, err := g.GetData(ctx, path...) + if err != nil { + return "", err + } + + return string(data), nil +} + +// GetAws credentials and helpers for key from the ticket-server. +// Path components will be joined with a `/`. +func (g Getter) GetAws(ctx *context.T, path ...string) (aws AwsTicket, err error) { + tick, err := g.getTicket(ctx, path...) + if err != nil { + return aws, err + } + + cast, ok := tick.(TicketAwsTicket) + if !ok { + return aws, expected(TicketAwsTicket{}, cast) + } + + return cast.Value, nil +} + +// GetS3 credentials and helpers for key from the ticket-server. +// Path components will be joined with a `/`. 
+func (g Getter) GetS3(ctx *context.T, path ...string) (S3 S3Ticket, err error) { + tick, err := g.getTicket(ctx, path...) + if err != nil { + return S3, err + } + + cast, ok := tick.(TicketS3Ticket) + if !ok { + return S3, expected(TicketS3Ticket{}, cast) + } + + return cast.Value, nil +} + +// GetSshCertificate for key from the ticket-server. +// Path components will be joined with a `/`. +func (g Getter) GetSshCertificate(ctx *context.T, path ...string) (SshCertificate SshCertificateTicket, err error) { + tick, err := g.getTicket(ctx, path...) + if err != nil { + return SshCertificate, err + } + + cast, ok := tick.(TicketSshCertificateTicket) + if !ok { + return SshCertificate, expected(TicketSshCertificateTicket{}, cast) + } + + return cast.Value, nil +} + +// GetEcr endpoint and helpers for key from the ticket-server. +// Path components will be joined with a `/`. +func (g Getter) GetEcr(ctx *context.T, path ...string) (Ecr EcrTicket, err error) { + tick, err := g.getTicket(ctx, path...) + if err != nil { + return Ecr, err + } + + cast, ok := tick.(TicketEcrTicket) + if !ok { + return Ecr, expected(TicketEcrTicket{}, cast) + } + + return cast.Value, nil +} + +// GetTlsServer credentials and helpers for key from the ticket-server. +// Path components will be joined with a `/`. +func (g Getter) GetTlsServer(ctx *context.T, path ...string) (TlsServer TlsServerTicket, err error) { + tick, err := g.getTicket(ctx, path...) + if err != nil { + return TlsServer, err + } + + cast, ok := tick.(TicketTlsServerTicket) + if !ok { + return TlsServer, expected(TicketTlsServerTicket{}, cast) + } + + return cast.Value, nil +} + +// GetTlsClient credentials and helpers for key from the ticket-server. +// Path components will be joined with a `/`. +func (g Getter) GetTlsClient(ctx *context.T, path ...string) (TlsClient TlsClientTicket, err error) { + tick, err := g.getTicket(ctx, path...) 
+ if err != nil { + return TlsClient, err + } + + cast, ok := tick.(TicketTlsClientTicket) + if !ok { + return TlsClient, expected(TicketTlsClientTicket{}, cast) + } + + return cast.Value, nil +} + +// GetDocker credentials and helpers for key from the ticket-server. +// Path components will be joined with a `/`. +func (g Getter) GetDocker(ctx *context.T, path ...string) (Docker DockerTicket, err error) { + tick, err := g.getTicket(ctx, path...) + if err != nil { + return Docker, err + } + + cast, ok := tick.(TicketDockerTicket) + if !ok { + return Docker, expected(TicketDockerTicket{}, cast) + } + + return cast.Value, nil +} + +// GetDockerServer credentials and helpers for key from the ticket-server. +// Path components will be joined with a `/`. +func (g Getter) GetDockerServer(ctx *context.T, path ...string) (DockerServer DockerServerTicket, err error) { + tick, err := g.getTicket(ctx, path...) + if err != nil { + return DockerServer, err + } + + cast, ok := tick.(TicketDockerServerTicket) + if !ok { + return DockerServer, expected(TicketDockerServerTicket{}, cast) + } + + return cast.Value, nil +} + +// GetDockerClient credentials and helpers for key from the ticket-server. +// Path components will be joined with a `/`. +func (g Getter) GetDockerClient(ctx *context.T, path ...string) (DockerClient DockerClientTicket, err error) { + tick, err := g.getTicket(ctx, path...) + if err != nil { + return DockerClient, err + } + + cast, ok := tick.(TicketDockerClientTicket) + if !ok { + return DockerClient, expected(TicketDockerClientTicket{}, cast) + } + + return cast.Value, nil +} + +// GetB2 credentials and helpers for key from the ticket-server. +// Path components will be joined with a `/`. +func (g Getter) GetB2(ctx *context.T, path ...string) (B2 B2Ticket, err error) { + tick, err := g.getTicket(ctx, path...) 
+ if err != nil { + return B2, err + } + + cast, ok := tick.(TicketB2Ticket) + if !ok { + return B2, expected(TicketB2Ticket{}, cast) + } + + return cast.Value, nil +} + +// GetVanadium blessing and helpers for key from the ticket-server. +// Path components will be joined with a `/`. +func (g Getter) GetVanadium(ctx *context.T, path ...string) (Vanadium VanadiumTicket, err error) { + tick, err := g.getTicket(ctx, path...) + if err != nil { + return Vanadium, err + } + + cast, ok := tick.(TicketVanadiumTicket) + if !ok { + return Vanadium, expected(TicketVanadiumTicket{}, cast) + } + + return cast.Value, nil +} diff --git a/security/ticket/helper_test.go b/security/ticket/helper_test.go new file mode 100644 index 00000000..aac9e7a3 --- /dev/null +++ b/security/ticket/helper_test.go @@ -0,0 +1,345 @@ +package ticket_test + +import ( + "bytes" + "fmt" + "github.com/grailbio/base/security/ticket" + "reflect" + "testing" + "v.io/v23/context" +) + +func TestGetter_path(t *testing.T) { + t.Run("it joins with slashes", func(t *testing.T) { + want := "ok" + client := mockString("string/key", want) + got, err := client.GetString(testContext(), "string", "key") + if err != nil { + t.Fatal(err) + } + if got != want { + t.Errorf("got %v, want %v", got, want) + } + }) +} + +func TestGetter_getTicket(t *testing.T) { + t.Run("it can error", func(t *testing.T) { + client := mockString("some/key", "ok") + _, err := client.GetString(testContext(), "other/key") + if err == nil { + t.Fatal("want error, got nil") + } + }) +} + +func TestGetter_GetData(t *testing.T) { + key := "data/key" + want := []byte{1, 2, 3} + client := mockData(key, want) + + got, err := client.GetData(testContext(), key) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(got, want) { + t.Errorf("got %v, want %v", got, want) + } +} + +func TestGetter_GetString(t *testing.T) { + key := "string/key" + want := "this is just a test" + client := mockString(key, want) + + got, err := client.GetString(testContext(), 
key) + if err != nil { + t.Fatal(err) + } + if got != want { + t.Errorf("got %v, want %v", got, want) + } +} + +func TestGetter_GetAws(t *testing.T) { + key := "aws/key" + want := ticket.AwsTicket{ + AwsAssumeRoleBuilder: &ticket.AwsAssumeRoleBuilder{ + Region: "region", + Role: "role", + TtlSec: 123, + }, + AwsCredentials: ticket.AwsCredentials{ + Region: "region", + AccessKeyId: "accessKeyID", + SecretAccessKey: "secretAccessKey", + SessionToken: "sessionToken", + Expiration: "expiration", + }, + } + client := mockAws(key, want) + + got, err := client.GetAws(testContext(), key) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(got, want) { + t.Errorf("got %v, want %v", got, want) + } +} + +func TestGetter_GetS3(t *testing.T) { + key := "s3/key" + want := ticket.S3Ticket{ + AwsAssumeRoleBuilder: &ticket.AwsAssumeRoleBuilder{ + Region: "region", + Role: "role", + TtlSec: 123, + }, + Endpoint: "endpoint", + Bucket: "bucket", + Prefix: "prefix", + } + client := mockS3(key, want) + + got, err := client.GetS3(testContext(), key) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(got, want) { + t.Errorf("got %v, want %v", got, want) + } +} + +func TestGetter_GetSshCertificate(t *testing.T) { + key := "ssh/key" + want := ticket.SshCertificateTicket{ + Username: "username", + } + client := mockSshCertificate(key, want) + + got, err := client.GetSshCertificate(testContext(), key) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(got, want) { + t.Errorf("got %v, want %v", got, want) + } +} + +func TestGetter_GetEcr(t *testing.T) { + key := "ecr/key" + want := ticket.EcrTicket{ + AwsAssumeRoleBuilder: &ticket.AwsAssumeRoleBuilder{ + Region: "region", + Role: "role", + TtlSec: 123, + }, + AuthorizationToken: "authorizationToken", + Expiration: "expiration", + Endpoint: "endpoint", + } + client := mockEcr(key, want) + + got, err := client.GetEcr(testContext(), key) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(got, want) { + 
t.Errorf("got %v, want %v", got, want) + } +} + +func TestGetter_GetTlsServer(t *testing.T) { + key := "TlsServer/key" + want := ticket.TlsServerTicket{ + TlsCertAuthorityBuilder: &ticket.TlsCertAuthorityBuilder{ + Authority: "authority", + TtlSec: 123, + CommonName: "commonName", + }, + Credentials: ticket.TlsCredentials{ + AuthorityCert: "authorityCert", + Cert: "cert", + Key: "key", + }, + } + client := mockTlsServer(key, want) + + got, err := client.GetTlsServer(testContext(), key) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(got, want) { + t.Errorf("got %v, want %v", got, want) + } +} + +func TestGetter_GetTlsClient(t *testing.T) { + key := "TlsClient/key" + want := ticket.TlsClientTicket{ + TlsCertAuthorityBuilder: &ticket.TlsCertAuthorityBuilder{ + Authority: "authority", + TtlSec: 123, + CommonName: "commonName", + }, + Credentials: ticket.TlsCredentials{ + AuthorityCert: "authorityCert", + Cert: "cert", + Key: "key", + }, + } + client := mockTlsClient(key, want) + + got, err := client.GetTlsClient(testContext(), key) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(got, want) { + t.Errorf("got %v, want %v", got, want) + } +} + +func TestGetter_GetDocker(t *testing.T) { + key := "Docker/key" + want := ticket.DockerTicket{ + TlsCertAuthorityBuilder: &ticket.TlsCertAuthorityBuilder{ + Authority: "authority", + TtlSec: 123, + CommonName: "commonName", + }, + Credentials: ticket.TlsCredentials{ + AuthorityCert: "authorityCert", + Cert: "cert", + Key: "key", + }, + Url: "url", + } + client := mockDocker(key, want) + + got, err := client.GetDocker(testContext(), key) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(got, want) { + t.Errorf("got %v, want %v", got, want) + } +} + +func TestGetter_GetDockerServer(t *testing.T) { + key := "DockerServer/key" + want := ticket.DockerServerTicket{} + client := mockDockerServer(key, want) + + got, err := client.GetDockerServer(testContext(), key) + if err != nil { + t.Fatal(err) + } 
+ if !reflect.DeepEqual(got, want) { + t.Errorf("got %v, want %v", got, want) + } +} + +func TestGetter_GetDockerClient(t *testing.T) { + key := "DockerClient/key" + want := ticket.DockerClientTicket{} + client := mockDockerClient(key, want) + + got, err := client.GetDockerClient(testContext(), key) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(got, want) { + t.Errorf("got %v, want %v", got, want) + } +} + +func TestGetter_GetB2(t *testing.T) { + key := "B2/key" + want := ticket.B2Ticket{} + client := mockB2(key, want) + + got, err := client.GetB2(testContext(), key) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(got, want) { + t.Errorf("got %v, want %v", got, want) + } +} + +func TestGetter_GetVanadium(t *testing.T) { + key := "Vanadium/key" + want := ticket.VanadiumTicket{} + client := mockVanadium(key, want) + + got, err := client.GetVanadium(testContext(), key) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(got, want) { + t.Errorf("got %v, want %v", got, want) + } +} + +// testContext creates a nil context which is safe to use with a mock client. +func testContext() *context.T { + return nil +} + +// mock is a shim to convert a ticket value to a ticket.Getter. 
+func mock(expectKey string, value interface{}) ticket.Getter { + return func(_ *context.T, gotKey string) (ticket.Ticket, error) { + if gotKey == expectKey { + return value.(ticket.Ticket), nil + } + + return nil, fmt.Errorf("ticket not found") + } +} + +func mockData(key string, data []byte) ticket.Getter { + return mock(key, ticket.TicketGenericTicket{ + Value: ticket.GenericTicket{Data: data}, + }) +} + +func mockString(key string, s string) ticket.Getter { + return mock(key, ticket.TicketGenericTicket{ + Value: ticket.GenericTicket{Data: []byte(s)}, + }) +} + +func mockAws(key string, aws ticket.AwsTicket) ticket.Getter { + return mock(key, ticket.TicketAwsTicket{Value: aws}) +} +func mockS3(key string, s3 ticket.S3Ticket) ticket.Getter { + return mock(key, ticket.TicketS3Ticket{Value: s3}) +} +func mockSshCertificate(key string, ssh ticket.SshCertificateTicket) ticket.Getter { + return mock(key, ticket.TicketSshCertificateTicket{Value: ssh}) +} +func mockEcr(key string, ecr ticket.EcrTicket) ticket.Getter { + return mock(key, ticket.TicketEcrTicket{Value: ecr}) +} +func mockTlsServer(key string, TlsServer ticket.TlsServerTicket) ticket.Getter { + return mock(key, ticket.TicketTlsServerTicket{Value: TlsServer}) +} +func mockTlsClient(key string, TlsClient ticket.TlsClientTicket) ticket.Getter { + return mock(key, ticket.TicketTlsClientTicket{Value: TlsClient}) +} +func mockDocker(key string, Docker ticket.DockerTicket) ticket.Getter { + return mock(key, ticket.TicketDockerTicket{Value: Docker}) +} +func mockDockerServer(key string, DockerServer ticket.DockerServerTicket) ticket.Getter { + return mock(key, ticket.TicketDockerServerTicket{Value: DockerServer}) +} +func mockDockerClient(key string, DockerClient ticket.DockerClientTicket) ticket.Getter { + return mock(key, ticket.TicketDockerClientTicket{Value: DockerClient}) +} +func mockB2(key string, b2 ticket.B2Ticket) ticket.Getter { + return mock(key, ticket.TicketB2Ticket{Value: b2}) +} +func 
mockVanadium(key string, vanadium ticket.VanadiumTicket) ticket.Getter { + return mock(key, ticket.TicketVanadiumTicket{Value: vanadium}) +} diff --git a/security/ticket/ssh.go b/security/ticket/ssh.go new file mode 100644 index 00000000..e52ce189 --- /dev/null +++ b/security/ticket/ssh.go @@ -0,0 +1,80 @@ +// Copyright 2020 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache-2.0 +// license that can be found in the LICENSE file. + +package ticket + +import ( + "time" + + "github.com/grailbio/base/common/log" + "github.com/grailbio/base/security/keycrypt" + "github.com/grailbio/base/security/ssh/certificateauthority" +) + +const sshDriftMargin = 10 * time.Minute + +func (b *SshCertAuthorityBuilder) newSshCertificateTicket(ctx *TicketContext) (TicketSshCertificateTicket, error) { + sshCert, err := b.genSshCertWithKeyUsage(ctx) + + if err != nil { + return TicketSshCertificateTicket{}, err + } + + return TicketSshCertificateTicket{ + Value: SshCertificateTicket{ + Credentials: sshCert, + }, + }, nil +} + +func (b *SshCertAuthorityBuilder) genSshCertWithKeyUsage(ctx *TicketContext) (SshCert, error) { + log.Info(ctx.ctx, "Generating SSH certificate.", "SshCertAuthorityBuilder", b) + empty := SshCert{} + + CaPrivateKey, err := keycrypt.Lookup(b.CaPrivateKey) + if err != nil { + return empty, err + } + + authority := certificateauthority.CertificateAuthority{DriftMargin: sshDriftMargin, PrivateKey: CaPrivateKey, Certificate: b.CaCertificate} + if err = authority.Init(); err != nil { + return empty, err + } + + ttl := time.Duration(b.TtlMin) * time.Minute + + cr := certificateauthority.CertificateRequest{ + // SSH Public Key that is being signed + SshPublicKey: []byte(b.PublicKey), + + // List of host names, or usernames that will be added to the cert + Principals: b.Principals, + Ttl: ttl, + KeyID: ctx.remoteBlessings.String(), + + CertType: "user", + + CriticalOptions: b.CriticalOptions, + + // Extensions to assign to the ssh 
Certificate + // The default allow basic function - permit-pty is usually required + // Recommended values are: + // []string{ + // "permit-X11-forwarding", + // "permit-agent-forwarding", + // "permit-port-forwarding", + // "permit-pty", + // "permit-user-rc", + // } + Extensions: b.ExtensionsOptions, + } + + sshCert, err := authority.IssueWithKeyUsage(cr) + if err != nil { + return empty, err + } + + r := SshCert{Cert: sshCert} + return r, nil +} diff --git a/security/ticket/ticket.go b/security/ticket/ticket.go index c7f94a86..da6fb3d8 100644 --- a/security/ticket/ticket.go +++ b/security/ticket/ticket.go @@ -7,14 +7,15 @@ package ticket import ( "bytes" "fmt" + "os" "reflect" "strings" "github.com/aws/aws-sdk-go/aws/session" + "github.com/grailbio/base/common/log" "github.com/grailbio/base/security/keycrypt" "v.io/v23/context" "v.io/v23/security" - "v.io/x/lib/vlog" ) // TicketContext wraps the informations that needs to carry around between @@ -37,12 +38,13 @@ func NewTicketContext(ctx *context.T, session *session.Session, remoteBlessings // Builder is the interface for building a Ticket. type Builder interface { - Build(ctx *TicketContext) (Ticket, error) + Build(ctx *TicketContext, parameters []Parameter) (Ticket, error) } var ( _ Builder = (*TicketAwsTicket)(nil) _ Builder = (*TicketS3Ticket)(nil) + _ Builder = (*TicketSshCertificateTicket)(nil) _ Builder = (*TicketEcrTicket)(nil) _ Builder = (*TicketTlsServerTicket)(nil) _ Builder = (*TicketTlsClientTicket)(nil) @@ -55,7 +57,7 @@ var ( ) // Build builds a Ticket by running all the builders. 
-func (t TicketAwsTicket) Build(ctx *TicketContext) (Ticket, error) { +func (t TicketAwsTicket) Build(ctx *TicketContext, _ []Parameter) (Ticket, error) { r := TicketAwsTicket{} var err error if t.Value.AwsAssumeRoleBuilder != nil { @@ -76,7 +78,7 @@ func (t TicketAwsTicket) Build(ctx *TicketContext) (Ticket, error) { } t.Value.AwsSessionBuilder = nil } - r = *mergeOrDie(&r, &t).(*TicketAwsTicket) + r = *mergeOrDie(ctx, &r, &t).(*TicketAwsTicket) err = r.Value.AwsCredentials.kmsInterpolate() return r, err @@ -88,7 +90,7 @@ func (t *AwsCredentials) kmsInterpolate() (err error) { } // Build builds a Ticket by running all the builders. -func (t TicketS3Ticket) Build(ctx *TicketContext) (Ticket, error) { +func (t TicketS3Ticket) Build(ctx *TicketContext, _ []Parameter) (Ticket, error) { r := TicketS3Ticket{} var err error if t.Value.AwsAssumeRoleBuilder != nil { @@ -109,13 +111,54 @@ func (t TicketS3Ticket) Build(ctx *TicketContext) (Ticket, error) { } t.Value.AwsSessionBuilder = nil } - r = *mergeOrDie(&r, &t).(*TicketS3Ticket) + r = *mergeOrDie(ctx, &r, &t).(*TicketS3Ticket) err = r.Value.AwsCredentials.kmsInterpolate() return r, err } // Build builds a Ticket by running all the builders. 
-func (t TicketEcrTicket) Build(ctx *TicketContext) (Ticket, error) { +func (t TicketSshCertificateTicket) Build(ctx *TicketContext, parameters []Parameter) (Ticket, error) { + rCompute := TicketSshCertificateTicket{} + + // Populate the ComputeInstances first as input to the SSH CertBuilder + if t.Value.AwsComputeInstancesBuilder != nil { + var instanceBuilder = t.Value.AwsComputeInstancesBuilder + if instanceBuilder.AwsAccountLookupRole != "" { + instances, err := AwsEc2InstanceLookup(ctx, instanceBuilder) + if err != nil { + return nil, err + } + rCompute.Value.ComputeInstances = instances + } else { + return rCompute, fmt.Errorf("AwsAccountLookupRole required for AwsComputeInstancesBuilder.") + } + } + + rSsh := TicketSshCertificateTicket{} + if t.Value.SshCertAuthorityBuilder != nil { + + // Set the PublicKey parameter on the builder from the input parameters + // NOTE: If multiple publicKeys are provided as input, use the last one + for _, param := range parameters { + if param.Key == "PublicKey" { + t.Value.SshCertAuthorityBuilder.PublicKey = param.Value + } + } + + var err error + rSsh, err = t.Value.SshCertAuthorityBuilder.newSshCertificateTicket(ctx) + if err != nil { + return rSsh, err + } + t.Value.SshCertAuthorityBuilder = nil + } + + r := *mergeOrDie(ctx, &rCompute, &rSsh).(*TicketSshCertificateTicket) + return *mergeOrDie(ctx, &r, &t).(*TicketSshCertificateTicket), nil +} + +// Build builds a Ticket by running all the builders. +func (t TicketEcrTicket) Build(ctx *TicketContext, _ []Parameter) (Ticket, error) { r := TicketEcrTicket{} if t.Value.AwsAssumeRoleBuilder != nil { var err error @@ -125,95 +168,95 @@ func (t TicketEcrTicket) Build(ctx *TicketContext) (Ticket, error) { } t.Value.AwsAssumeRoleBuilder = nil } - return *mergeOrDie(&r, &t).(*TicketEcrTicket), nil + return *mergeOrDie(ctx, &r, &t).(*TicketEcrTicket), nil } // Build builds a Ticket by running all the builders. 
-func (t TicketTlsServerTicket) Build(_ *TicketContext) (Ticket, error) { +func (t TicketTlsServerTicket) Build(ctx *TicketContext, _ []Parameter) (Ticket, error) { r := TicketTlsServerTicket{} if t.Value.TlsCertAuthorityBuilder != nil { var err error - r, err = t.Value.TlsCertAuthorityBuilder.newTlsServerTicket() + r, err = t.Value.TlsCertAuthorityBuilder.newTlsServerTicket(ctx) if err != nil { return r, err } t.Value.TlsCertAuthorityBuilder = nil } - return *mergeOrDie(&r, &t).(*TicketTlsServerTicket), nil + return *mergeOrDie(ctx, &r, &t).(*TicketTlsServerTicket), nil } // Build builds a Ticket by running all the builders. -func (t TicketTlsClientTicket) Build(_ *TicketContext) (Ticket, error) { +func (t TicketTlsClientTicket) Build(ctx *TicketContext, _ []Parameter) (Ticket, error) { r := TicketTlsClientTicket{} if t.Value.TlsCertAuthorityBuilder != nil { var err error - r, err = t.Value.TlsCertAuthorityBuilder.newTlsClientTicket() + r, err = t.Value.TlsCertAuthorityBuilder.newTlsClientTicket(ctx) if err != nil { return r, err } t.Value.TlsCertAuthorityBuilder = nil } - return *mergeOrDie(&r, &t).(*TicketTlsClientTicket), nil + return *mergeOrDie(ctx, &r, &t).(*TicketTlsClientTicket), nil } // Build builds a Ticket by running all the builders. -func (t TicketDockerTicket) Build(_ *TicketContext) (Ticket, error) { +func (t TicketDockerTicket) Build(ctx *TicketContext, _ []Parameter) (Ticket, error) { r := TicketDockerTicket{} if t.Value.TlsCertAuthorityBuilder != nil { var err error - r, err = t.Value.TlsCertAuthorityBuilder.newDockerTicket() + r, err = t.Value.TlsCertAuthorityBuilder.newDockerTicket(ctx) if err != nil { return r, err } t.Value.TlsCertAuthorityBuilder = nil } - return *mergeOrDie(&r, &t).(*TicketDockerTicket), nil + return *mergeOrDie(ctx, &r, &t).(*TicketDockerTicket), nil } // Build builds a Ticket by running all the builders. 
-func (t TicketDockerServerTicket) Build(_ *TicketContext) (Ticket, error) { +func (t TicketDockerServerTicket) Build(ctx *TicketContext, _ []Parameter) (Ticket, error) { r := TicketDockerServerTicket{} if t.Value.TlsCertAuthorityBuilder != nil { var err error - r, err = t.Value.TlsCertAuthorityBuilder.newDockerServerTicket() + r, err = t.Value.TlsCertAuthorityBuilder.newDockerServerTicket(ctx) if err != nil { return r, err } t.Value.TlsCertAuthorityBuilder = nil } - return *mergeOrDie(&r, &t).(*TicketDockerServerTicket), nil + return *mergeOrDie(ctx, &r, &t).(*TicketDockerServerTicket), nil } // Build builds a Ticket by running all the builders. -func (t TicketDockerClientTicket) Build(_ *TicketContext) (Ticket, error) { +func (t TicketDockerClientTicket) Build(ctx *TicketContext, _ []Parameter) (Ticket, error) { r := TicketDockerClientTicket{} if t.Value.TlsCertAuthorityBuilder != nil { var err error - r, err = t.Value.TlsCertAuthorityBuilder.newDockerClientTicket() + r, err = t.Value.TlsCertAuthorityBuilder.newDockerClientTicket(ctx) if err != nil { return r, err } t.Value.TlsCertAuthorityBuilder = nil } - return *mergeOrDie(&r, &t).(*TicketDockerClientTicket), nil + return *mergeOrDie(ctx, &r, &t).(*TicketDockerClientTicket), nil } // Build builds a Ticket by running all the builders. -func (t TicketB2Ticket) Build(_ *TicketContext) (Ticket, error) { +func (t TicketB2Ticket) Build(ctx *TicketContext, _ []Parameter) (Ticket, error) { r := TicketB2Ticket{} if t.Value.B2AccountAuthorizationBuilder != nil { var err error - r, err = t.Value.B2AccountAuthorizationBuilder.newB2Ticket() + r, err = t.Value.B2AccountAuthorizationBuilder.newB2Ticket(ctx) if err != nil { return r, err } t.Value.B2AccountAuthorizationBuilder = nil } - return *mergeOrDie(&r, &t).(*TicketB2Ticket), nil + return *mergeOrDie(ctx, &r, &t).(*TicketB2Ticket), nil } // Build builds a Ticket by running all the builders. 
-func (t TicketVanadiumTicket) Build(ctx *TicketContext) (Ticket, error) { +func (t TicketVanadiumTicket) Build(ctx *TicketContext, _ []Parameter) (Ticket, error) { r := TicketVanadiumTicket{} if t.Value.VanadiumBuilder != nil { var err error @@ -223,13 +266,13 @@ func (t TicketVanadiumTicket) Build(ctx *TicketContext) (Ticket, error) { } t.Value.VanadiumBuilder = nil } - return *mergeOrDie(&r, &t).(*TicketVanadiumTicket), nil + return *mergeOrDie(ctx, &r, &t).(*TicketVanadiumTicket), nil } // Build builds a Ticket. -func (t TicketGenericTicket) Build(ctx *TicketContext) (Ticket, error) { +func (t TicketGenericTicket) Build(ctx *TicketContext, _ []Parameter) (Ticket, error) { r := TicketGenericTicket{} - r = *mergeOrDie(&r, &t).(*TicketGenericTicket) + r = *mergeOrDie(ctx, &r, &t).(*TicketGenericTicket) var err error r.Value.Data, err = kmsInterpolationBytes(r.Value.Data) return r, err @@ -238,14 +281,15 @@ func (t TicketGenericTicket) Build(ctx *TicketContext) (Ticket, error) { // merge i2 in i1 by overwriting in i1 all the non-zero fields in i2. The i1 // and i2 needs to be references to the same type. Only simple types (bool, // numeric, string) and string are supported. 
-func mergeOrDie(i1, i2 interface{}) interface{} { +func mergeOrDie(ctx *TicketContext, i1, i2 interface{}) interface{} { if reflect.DeepEqual(i1, i2) { return i1 } v1, v2 := reflect.ValueOf(i1).Elem(), reflect.ValueOf(i2).Elem() k1, k2 := v1.Kind(), v2.Kind() if k1 != k2 { - vlog.Fatalf("different types in merge: %+v (%s) vs %v (%s)", v1, v1.Kind(), v2, v2.Kind()) + log.Error(ctx.ctx, "different types in merge: %+v (%s) vs %v (%s)", v1, v1.Kind(), v2, v2.Kind()) + os.Exit(255) } switch k1 { case reflect.Struct: @@ -254,7 +298,7 @@ func mergeOrDie(i1, i2 interface{}) interface{} { if !f1.CanSet() { continue } - v := mergeOrDie(f1.Addr().Interface(), f2.Addr().Interface()) + v := mergeOrDie(ctx, f1.Addr().Interface(), f2.Addr().Interface()) f1.Set(reflect.Indirect(reflect.ValueOf(v))) } case reflect.Map: diff --git a/security/ticket/ticket.vdl b/security/ticket/ticket.vdl index 16018c27..9645a09c 100644 --- a/security/ticket/ticket.vdl +++ b/security/ticket/ticket.vdl @@ -2,6 +2,13 @@ package ticket import "v.io/v23/security/access" +// TicketConfig Controls fields +type Control enum { + PagerDutyId + Rationale + TicketId +} + // AwsCredentials describes a set of (potentially temporary) AWS credentials. type AwsCredentials struct { Region string @@ -45,13 +52,44 @@ type TlsCertAuthorityBuilder struct { // Common Name of the generated cert. CommonName string - + // Subject Alternate Name list. - // Note: x509 spec says if SAN is set, CN is usually ignored. + // Note: x509 spec says if SAN is set, CN is usually ignored. // Include CN in SAN list if you want the CN to be verified. San []string } +type SshCertAuthorityBuilder struct { + // ssh-encoded private key of the Certificate Authority. + CaPrivateKey string + + // ssh-encoded Certificate + CaCertificate string + + // ssh-encoded Public key that will be signed to create the certificate. 
+ PublicKey string + + // Additional SSH Cert options like + // permit-X11-forwarding + // permit-agent-forwarding + // permit-port-forwarding + // permit-pty + // permit-user-rc + ExtensionsOptions []string + + // Additional SSH Options that are required to be valid/accepted + CriticalOptions []string + + // The Usernames that this key can connect as - defaults as + // ubuntu + // core + // ec2-user + Principals []string + + // TTL for the generated cert - user cert < 60 ; host cert < 2628000 (5 years) + TtlMin int32 +} + // B2AccountAuthorizationBuilder describes the information required to // obtain a B2 account authorization. type B2AccountAuthorizationBuilder struct { @@ -103,7 +141,13 @@ type EcrTicket struct { Endpoint string } -// TlsCredentials describes a generic set of TLS credentials that include: +// SshCert describes a ssh public Certifcate +type SshCert struct { + // ssh-encoded certificate (host or user). + Cert string +} + +// TlsCredentials describes a generic set of Tls credentials that include: // the CA that accepted by the client (only peers that present a certificate // sign by this CA are accepted), the client certificate and the client // private key. @@ -154,7 +198,7 @@ type DockerServerTicket struct { Credentials TlsCredentials } -// DockerClientTicket instance represents the TLS certificate material required +// DockerClientTicket instance represents the TLS certificate material required // for clients to authenticate against a specific DockerServer. type DockerClientTicket struct { TlsCertAuthorityBuilder ?TlsCertAuthorityBuilder @@ -165,6 +209,38 @@ type DockerClientTicket struct { Url string } +// SshCertificateTicket describes a SSH Signed Certificate. +// SSH Certificates are essentially a version of TLS certs but they have additional +// optional parameters and can take a public key as part of their signing request. 
+type SshCertificateTicket struct { + SshCertAuthorityBuilder ?SshCertAuthorityBuilder + AwsComputeInstancesBuilder ?AwsComputeInstancesBuilder + + ComputeInstances []ComputeInstance + Credentials SshCert + // Recommended username to use + Username string +} + +type AwsComputeInstancesBuilder struct { + // Instance Filters that will produce a list of instance IDs and related information + // https://docs.aws.amazon.com/AWSEC2/latest/APIReference/API_DescribeInstances.html + InstanceFilters []Parameter + + // AWS ARN for a role that should be used to perform the instance lookups + AwsAccountLookupRole string + + // AWS region to use for the lookup + Region string +} + +// Simplification of describeInstance data to provide to ticket-server users +type ComputeInstance struct { + PublicIp string + PrivateIp string + InstanceId string + Tags []Parameter +} // B2Ticket instance contains a B2 account level authorization // token plus URLs and configuration values for the account. @@ -199,6 +275,7 @@ type GenericTicket struct { type Ticket union { AwsTicket AwsTicket S3Ticket S3Ticket + SshCertificateTicket SshCertificateTicket EcrTicket EcrTicket TlsServerTicket TlsServerTicket TlsClientTicket TlsClientTicket @@ -214,6 +291,7 @@ type Ticket union { type TicketConfig struct { Ticket Ticket Permissions access.Permissions + Controls map[Control]bool } type Config struct { @@ -221,10 +299,23 @@ type Config struct { Permissions access.Permissions } +// Key/Value pair that can be passed into the GET request. +type Parameter struct { + Key string + Value string +} + + // TicketService provides a way to obtain a ticket. The access can be // restricted by setting the permissions appropriately. 
type TicketService interface { GetPermissions() (perms access.Permissions, version string | error) {access.Read} SetPermissions(perms access.Permissions, version string) error {access.Admin} Get() (Ticket | error) {access.Read} + GetWithParameters(parameters []Parameter) (Ticket | error) {access.Read} + GetWithArgs(args map[string]string) (Ticket | error) {access.Read} } + +type ListService interface { + List() ([]string | error) {access.Read} +} \ No newline at end of file diff --git a/security/ticket/ticket.vdl.go b/security/ticket/ticket.vdl.go index c75062e9..38bebac4 100644 --- a/security/ticket/ticket.vdl.go +++ b/security/ticket/ticket.vdl.go @@ -1,25 +1,101 @@ -// Copyright 2018 GRAIL, Inc. All rights reserved. -// Use of this source code is governed by the Apache-2.0 -// license that can be found in the LICENSE file. - // This file was auto-generated by the vanadium vdl tool. // Package: ticket +//nolint:golint package ticket import ( "fmt" - "v.io/v23" + + v23 "v.io/v23" "v.io/v23/context" "v.io/v23/rpc" "v.io/v23/security/access" "v.io/v23/vdl" ) -var _ = __VDLInit() // Must be first; see __VDLInit comments for details. +var _ = initializeVDL() // Must be first; see initializeVDL comments for details. -////////////////////////////////////////////////// // Type definitions +// ================ + +// TicketConfig Controls fields +type Control int + +const ( + ControlPagerDutyId Control = iota + ControlRationale + ControlTicketId +) + +// ControlAll holds all labels for Control. +var ControlAll = [...]Control{ControlPagerDutyId, ControlRationale, ControlTicketId} + +// ControlFromString creates a Control from a string label. +//nolint:deadcode,unused +func ControlFromString(label string) (x Control, err error) { + err = x.Set(label) + return +} + +// Set assigns label to x. 
+func (x *Control) Set(label string) error { + switch label { + case "PagerDutyId", "pagerdutyid": + *x = ControlPagerDutyId + return nil + case "Rationale", "rationale": + *x = ControlRationale + return nil + case "TicketId", "ticketid": + *x = ControlTicketId + return nil + } + *x = -1 + return fmt.Errorf("unknown label %q in ticket.Control", label) +} + +// String returns the string label of x. +func (x Control) String() string { + switch x { + case ControlPagerDutyId: + return "PagerDutyId" + case ControlRationale: + return "Rationale" + case ControlTicketId: + return "TicketId" + } + return "" +} + +func (Control) VDLReflect(struct { + Name string `vdl:"github.com/grailbio/base/security/ticket.Control"` + Enum struct{ PagerDutyId, Rationale, TicketId string } +}) { +} + +func (x Control) VDLIsZero() bool { //nolint:gocyclo + return x == ControlPagerDutyId +} + +func (x Control) VDLWrite(enc vdl.Encoder) error { //nolint:gocyclo + if err := enc.WriteValueString(vdlTypeEnum1, x.String()); err != nil { + return err + } + return nil +} + +func (x *Control) VDLRead(dec vdl.Decoder) error { //nolint:gocyclo + switch value, err := dec.ReadValueString(); { + case err != nil: + return err + default: + if err := x.Set(value); err != nil { + return err + } + } + return nil +} // AwsCredentials describes a set of (potentially temporary) AWS credentials. 
type AwsCredentials struct { @@ -36,12 +112,12 @@ func (AwsCredentials) VDLReflect(struct { }) { } -func (x AwsCredentials) VDLIsZero() bool { +func (x AwsCredentials) VDLIsZero() bool { //nolint:gocyclo return x == AwsCredentials{} } -func (x AwsCredentials) VDLWrite(enc vdl.Encoder) error { - if err := enc.StartValue(__VDLType_struct_1); err != nil { +func (x AwsCredentials) VDLWrite(enc vdl.Encoder) error { //nolint:gocyclo + if err := enc.StartValue(vdlTypeStruct2); err != nil { return err } if x.Region != "" { @@ -75,9 +151,9 @@ func (x AwsCredentials) VDLWrite(enc vdl.Encoder) error { return enc.FinishValue() } -func (x *AwsCredentials) VDLRead(dec vdl.Decoder) error { +func (x *AwsCredentials) VDLRead(dec vdl.Decoder) error { //nolint:gocyclo *x = AwsCredentials{} - if err := dec.StartValue(__VDLType_struct_1); err != nil { + if err := dec.StartValue(vdlTypeStruct2); err != nil { return err } decType := dec.Type() @@ -89,8 +165,8 @@ func (x *AwsCredentials) VDLRead(dec vdl.Decoder) error { case index == -1: return dec.FinishValue() } - if decType != __VDLType_struct_1 { - index = __VDLType_struct_1.FieldIndexByName(decType.Field(index).Name) + if decType != vdlTypeStruct2 { + index = vdlTypeStruct2.FieldIndexByName(decType.Field(index).Name) if index == -1 { if err := dec.SkipValue(); err != nil { return err @@ -152,12 +228,12 @@ func (AwsAssumeRoleBuilder) VDLReflect(struct { }) { } -func (x AwsAssumeRoleBuilder) VDLIsZero() bool { +func (x AwsAssumeRoleBuilder) VDLIsZero() bool { //nolint:gocyclo return x == AwsAssumeRoleBuilder{} } -func (x AwsAssumeRoleBuilder) VDLWrite(enc vdl.Encoder) error { - if err := enc.StartValue(__VDLType_struct_2); err != nil { +func (x AwsAssumeRoleBuilder) VDLWrite(enc vdl.Encoder) error { //nolint:gocyclo + if err := enc.StartValue(vdlTypeStruct3); err != nil { return err } if x.Region != "" { @@ -181,9 +257,9 @@ func (x AwsAssumeRoleBuilder) VDLWrite(enc vdl.Encoder) error { return enc.FinishValue() } -func (x 
*AwsAssumeRoleBuilder) VDLRead(dec vdl.Decoder) error { +func (x *AwsAssumeRoleBuilder) VDLRead(dec vdl.Decoder) error { //nolint:gocyclo *x = AwsAssumeRoleBuilder{} - if err := dec.StartValue(__VDLType_struct_2); err != nil { + if err := dec.StartValue(vdlTypeStruct3); err != nil { return err } decType := dec.Type() @@ -195,8 +271,8 @@ func (x *AwsAssumeRoleBuilder) VDLRead(dec vdl.Decoder) error { case index == -1: return dec.FinishValue() } - if decType != __VDLType_struct_2 { - index = __VDLType_struct_2.FieldIndexByName(decType.Field(index).Name) + if decType != vdlTypeStruct3 { + index = vdlTypeStruct3.FieldIndexByName(decType.Field(index).Name) if index == -1 { if err := dec.SkipValue(); err != nil { return err @@ -244,12 +320,12 @@ func (AwsSessionBuilder) VDLReflect(struct { }) { } -func (x AwsSessionBuilder) VDLIsZero() bool { +func (x AwsSessionBuilder) VDLIsZero() bool { //nolint:gocyclo return x == AwsSessionBuilder{} } -func (x AwsSessionBuilder) VDLWrite(enc vdl.Encoder) error { - if err := enc.StartValue(__VDLType_struct_3); err != nil { +func (x AwsSessionBuilder) VDLWrite(enc vdl.Encoder) error { //nolint:gocyclo + if err := enc.StartValue(vdlTypeStruct4); err != nil { return err } if x.AwsCredentials != (AwsCredentials{}) { @@ -271,9 +347,9 @@ func (x AwsSessionBuilder) VDLWrite(enc vdl.Encoder) error { return enc.FinishValue() } -func (x *AwsSessionBuilder) VDLRead(dec vdl.Decoder) error { +func (x *AwsSessionBuilder) VDLRead(dec vdl.Decoder) error { //nolint:gocyclo *x = AwsSessionBuilder{} - if err := dec.StartValue(__VDLType_struct_3); err != nil { + if err := dec.StartValue(vdlTypeStruct4); err != nil { return err } decType := dec.Type() @@ -285,8 +361,8 @@ func (x *AwsSessionBuilder) VDLRead(dec vdl.Decoder) error { case index == -1: return dec.FinishValue() } - if decType != __VDLType_struct_3 { - index = __VDLType_struct_3.FieldIndexByName(decType.Field(index).Name) + if decType != vdlTypeStruct4 { + index = 
vdlTypeStruct4.FieldIndexByName(decType.Field(index).Name) if index == -1 { if err := dec.SkipValue(); err != nil { return err @@ -330,7 +406,7 @@ func (TlsCertAuthorityBuilder) VDLReflect(struct { }) { } -func (x TlsCertAuthorityBuilder) VDLIsZero() bool { +func (x TlsCertAuthorityBuilder) VDLIsZero() bool { //nolint:gocyclo if x.Authority != "" { return false } @@ -346,8 +422,8 @@ func (x TlsCertAuthorityBuilder) VDLIsZero() bool { return true } -func (x TlsCertAuthorityBuilder) VDLWrite(enc vdl.Encoder) error { - if err := enc.StartValue(__VDLType_struct_4); err != nil { +func (x TlsCertAuthorityBuilder) VDLWrite(enc vdl.Encoder) error { //nolint:gocyclo + if err := enc.StartValue(vdlTypeStruct5); err != nil { return err } if x.Authority != "" { @@ -369,7 +445,7 @@ func (x TlsCertAuthorityBuilder) VDLWrite(enc vdl.Encoder) error { if err := enc.NextField(3); err != nil { return err } - if err := __VDLWriteAnon_list_1(enc, x.San); err != nil { + if err := vdlWriteAnonList1(enc, x.San); err != nil { return err } } @@ -379,8 +455,8 @@ func (x TlsCertAuthorityBuilder) VDLWrite(enc vdl.Encoder) error { return enc.FinishValue() } -func __VDLWriteAnon_list_1(enc vdl.Encoder, x []string) error { - if err := enc.StartValue(__VDLType_list_5); err != nil { +func vdlWriteAnonList1(enc vdl.Encoder, x []string) error { + if err := enc.StartValue(vdlTypeList6); err != nil { return err } if err := enc.SetLenHint(len(x)); err != nil { @@ -397,9 +473,9 @@ func __VDLWriteAnon_list_1(enc vdl.Encoder, x []string) error { return enc.FinishValue() } -func (x *TlsCertAuthorityBuilder) VDLRead(dec vdl.Decoder) error { +func (x *TlsCertAuthorityBuilder) VDLRead(dec vdl.Decoder) error { //nolint:gocyclo *x = TlsCertAuthorityBuilder{} - if err := dec.StartValue(__VDLType_struct_4); err != nil { + if err := dec.StartValue(vdlTypeStruct5); err != nil { return err } decType := dec.Type() @@ -411,8 +487,8 @@ func (x *TlsCertAuthorityBuilder) VDLRead(dec vdl.Decoder) error { case index == -1: 
return dec.FinishValue() } - if decType != __VDLType_struct_4 { - index = __VDLType_struct_4.FieldIndexByName(decType.Field(index).Name) + if decType != vdlTypeStruct5 { + index = vdlTypeStruct5.FieldIndexByName(decType.Field(index).Name) if index == -1 { if err := dec.SkipValue(); err != nil { return err @@ -443,15 +519,15 @@ func (x *TlsCertAuthorityBuilder) VDLRead(dec vdl.Decoder) error { x.CommonName = value } case 3: - if err := __VDLReadAnon_list_1(dec, &x.San); err != nil { + if err := vdlReadAnonList1(dec, &x.San); err != nil { return err } } } } -func __VDLReadAnon_list_1(dec vdl.Decoder, x *[]string) error { - if err := dec.StartValue(__VDLType_list_5); err != nil { +func vdlReadAnonList1(dec vdl.Decoder, x *[]string) error { + if err := dec.StartValue(vdlTypeList6); err != nil { return err } if len := dec.LenHint(); len > 0 { @@ -471,6 +547,183 @@ func __VDLReadAnon_list_1(dec vdl.Decoder, x *[]string) error { } } +type SshCertAuthorityBuilder struct { + // ssh-encoded private key of the Certificate Authority. + CaPrivateKey string + // ssh-encoded Certificate + CaCertificate string + // ssh-encoded Public key that will be signed to create the certificate. 
+ PublicKey string + // Additional SSH Cert options like + // permit-X11-forwarding + // permit-agent-forwarding + // permit-port-forwarding + // permit-pty + // permit-user-rc + ExtensionsOptions []string + // Additional SSH Options that are required to be valid/accepted + CriticalOptions []string + // The Usernames that this key can connect as - defaults as + // ubuntu + // core + // ec2-user + Principals []string + // TTL for the generated cert - user cert < 60 ; host cert < 2628000 (5 years) + TtlMin int32 +} + +func (SshCertAuthorityBuilder) VDLReflect(struct { + Name string `vdl:"github.com/grailbio/base/security/ticket.SshCertAuthorityBuilder"` +}) { +} + +func (x SshCertAuthorityBuilder) VDLIsZero() bool { //nolint:gocyclo + if x.CaPrivateKey != "" { + return false + } + if x.CaCertificate != "" { + return false + } + if x.PublicKey != "" { + return false + } + if len(x.ExtensionsOptions) != 0 { + return false + } + if len(x.CriticalOptions) != 0 { + return false + } + if len(x.Principals) != 0 { + return false + } + if x.TtlMin != 0 { + return false + } + return true +} + +func (x SshCertAuthorityBuilder) VDLWrite(enc vdl.Encoder) error { //nolint:gocyclo + if err := enc.StartValue(vdlTypeStruct7); err != nil { + return err + } + if x.CaPrivateKey != "" { + if err := enc.NextFieldValueString(0, vdl.StringType, x.CaPrivateKey); err != nil { + return err + } + } + if x.CaCertificate != "" { + if err := enc.NextFieldValueString(1, vdl.StringType, x.CaCertificate); err != nil { + return err + } + } + if x.PublicKey != "" { + if err := enc.NextFieldValueString(2, vdl.StringType, x.PublicKey); err != nil { + return err + } + } + if len(x.ExtensionsOptions) != 0 { + if err := enc.NextField(3); err != nil { + return err + } + if err := vdlWriteAnonList1(enc, x.ExtensionsOptions); err != nil { + return err + } + } + if len(x.CriticalOptions) != 0 { + if err := enc.NextField(4); err != nil { + return err + } + if err := vdlWriteAnonList1(enc, x.CriticalOptions); err 
!= nil { + return err + } + } + if len(x.Principals) != 0 { + if err := enc.NextField(5); err != nil { + return err + } + if err := vdlWriteAnonList1(enc, x.Principals); err != nil { + return err + } + } + if x.TtlMin != 0 { + if err := enc.NextFieldValueInt(6, vdl.Int32Type, int64(x.TtlMin)); err != nil { + return err + } + } + if err := enc.NextField(-1); err != nil { + return err + } + return enc.FinishValue() +} + +func (x *SshCertAuthorityBuilder) VDLRead(dec vdl.Decoder) error { //nolint:gocyclo + *x = SshCertAuthorityBuilder{} + if err := dec.StartValue(vdlTypeStruct7); err != nil { + return err + } + decType := dec.Type() + for { + index, err := dec.NextField() + switch { + case err != nil: + return err + case index == -1: + return dec.FinishValue() + } + if decType != vdlTypeStruct7 { + index = vdlTypeStruct7.FieldIndexByName(decType.Field(index).Name) + if index == -1 { + if err := dec.SkipValue(); err != nil { + return err + } + continue + } + } + switch index { + case 0: + switch value, err := dec.ReadValueString(); { + case err != nil: + return err + default: + x.CaPrivateKey = value + } + case 1: + switch value, err := dec.ReadValueString(); { + case err != nil: + return err + default: + x.CaCertificate = value + } + case 2: + switch value, err := dec.ReadValueString(); { + case err != nil: + return err + default: + x.PublicKey = value + } + case 3: + if err := vdlReadAnonList1(dec, &x.ExtensionsOptions); err != nil { + return err + } + case 4: + if err := vdlReadAnonList1(dec, &x.CriticalOptions); err != nil { + return err + } + case 5: + if err := vdlReadAnonList1(dec, &x.Principals); err != nil { + return err + } + case 6: + switch value, err := dec.ReadValueInt(32); { + case err != nil: + return err + default: + x.TtlMin = int32(value) + } + } + } +} + // B2AccountAuthorizationBuilder describes the information required to // obtain a B2 account authorization. 
type B2AccountAuthorizationBuilder struct { @@ -483,12 +736,12 @@ func (B2AccountAuthorizationBuilder) VDLReflect(struct { }) { } -func (x B2AccountAuthorizationBuilder) VDLIsZero() bool { +func (x B2AccountAuthorizationBuilder) VDLIsZero() bool { //nolint:gocyclo return x == B2AccountAuthorizationBuilder{} } -func (x B2AccountAuthorizationBuilder) VDLWrite(enc vdl.Encoder) error { - if err := enc.StartValue(__VDLType_struct_6); err != nil { +func (x B2AccountAuthorizationBuilder) VDLWrite(enc vdl.Encoder) error { //nolint:gocyclo + if err := enc.StartValue(vdlTypeStruct8); err != nil { return err } if x.AccountId != "" { @@ -507,9 +760,9 @@ func (x B2AccountAuthorizationBuilder) VDLWrite(enc vdl.Encoder) error { return enc.FinishValue() } -func (x *B2AccountAuthorizationBuilder) VDLRead(dec vdl.Decoder) error { +func (x *B2AccountAuthorizationBuilder) VDLRead(dec vdl.Decoder) error { //nolint:gocyclo *x = B2AccountAuthorizationBuilder{} - if err := dec.StartValue(__VDLType_struct_6); err != nil { + if err := dec.StartValue(vdlTypeStruct8); err != nil { return err } decType := dec.Type() @@ -521,8 +774,8 @@ func (x *B2AccountAuthorizationBuilder) VDLRead(dec vdl.Decoder) error { case index == -1: return dec.FinishValue() } - if decType != __VDLType_struct_6 { - index = __VDLType_struct_6.FieldIndexByName(decType.Field(index).Name) + if decType != vdlTypeStruct8 { + index = vdlTypeStruct8.FieldIndexByName(decType.Field(index).Name) if index == -1 { if err := dec.SkipValue(); err != nil { return err @@ -560,12 +813,12 @@ func (VanadiumBuilder) VDLReflect(struct { }) { } -func (x VanadiumBuilder) VDLIsZero() bool { +func (x VanadiumBuilder) VDLIsZero() bool { //nolint:gocyclo return x == VanadiumBuilder{} } -func (x VanadiumBuilder) VDLWrite(enc vdl.Encoder) error { - if err := enc.StartValue(__VDLType_struct_7); err != nil { +func (x VanadiumBuilder) VDLWrite(enc vdl.Encoder) error { //nolint:gocyclo + if err := enc.StartValue(vdlTypeStruct9); err != nil { return err 
} if x.BlessingName != "" { @@ -579,9 +832,9 @@ func (x VanadiumBuilder) VDLWrite(enc vdl.Encoder) error { return enc.FinishValue() } -func (x *VanadiumBuilder) VDLRead(dec vdl.Decoder) error { +func (x *VanadiumBuilder) VDLRead(dec vdl.Decoder) error { //nolint:gocyclo *x = VanadiumBuilder{} - if err := dec.StartValue(__VDLType_struct_7); err != nil { + if err := dec.StartValue(vdlTypeStruct9); err != nil { return err } decType := dec.Type() @@ -593,8 +846,8 @@ func (x *VanadiumBuilder) VDLRead(dec vdl.Decoder) error { case index == -1: return dec.FinishValue() } - if decType != __VDLType_struct_7 { - index = __VDLType_struct_7.FieldIndexByName(decType.Field(index).Name) + if decType != vdlTypeStruct9 { + index = vdlTypeStruct9.FieldIndexByName(decType.Field(index).Name) if index == -1 { if err := dec.SkipValue(); err != nil { return err @@ -602,8 +855,8 @@ func (x *VanadiumBuilder) VDLRead(dec vdl.Decoder) error { continue } } - switch index { - case 0: + if index == 0 { + switch value, err := dec.ReadValueString(); { case err != nil: return err @@ -627,12 +880,12 @@ func (AwsTicket) VDLReflect(struct { }) { } -func (x AwsTicket) VDLIsZero() bool { +func (x AwsTicket) VDLIsZero() bool { //nolint:gocyclo return x == AwsTicket{} } -func (x AwsTicket) VDLWrite(enc vdl.Encoder) error { - if err := enc.StartValue(__VDLType_struct_8); err != nil { +func (x AwsTicket) VDLWrite(enc vdl.Encoder) error { //nolint:gocyclo + if err := enc.StartValue(vdlTypeStruct10); err != nil { return err } if x.AwsAssumeRoleBuilder != nil { @@ -667,9 +920,9 @@ func (x AwsTicket) VDLWrite(enc vdl.Encoder) error { return enc.FinishValue() } -func (x *AwsTicket) VDLRead(dec vdl.Decoder) error { +func (x *AwsTicket) VDLRead(dec vdl.Decoder) error { //nolint:gocyclo *x = AwsTicket{} - if err := dec.StartValue(__VDLType_struct_8); err != nil { + if err := dec.StartValue(vdlTypeStruct10); err != nil { return err } decType := dec.Type() @@ -681,8 +934,8 @@ func (x *AwsTicket) VDLRead(dec 
vdl.Decoder) error { case index == -1: return dec.FinishValue() } - if decType != __VDLType_struct_8 { - index = __VDLType_struct_8.FieldIndexByName(decType.Field(index).Name) + if decType != vdlTypeStruct10 { + index = vdlTypeStruct10.FieldIndexByName(decType.Field(index).Name) if index == -1 { if err := dec.SkipValue(); err != nil { return err @@ -692,7 +945,7 @@ func (x *AwsTicket) VDLRead(dec vdl.Decoder) error { } switch index { case 0: - if err := dec.StartValue(__VDLType_optional_9); err != nil { + if err := dec.StartValue(vdlTypeOptional11); err != nil { return err } if dec.IsNil() { @@ -708,7 +961,7 @@ func (x *AwsTicket) VDLRead(dec vdl.Decoder) error { } } case 1: - if err := dec.StartValue(__VDLType_optional_10); err != nil { + if err := dec.StartValue(vdlTypeOptional12); err != nil { return err } if dec.IsNil() { @@ -746,12 +999,12 @@ func (S3Ticket) VDLReflect(struct { }) { } -func (x S3Ticket) VDLIsZero() bool { +func (x S3Ticket) VDLIsZero() bool { //nolint:gocyclo return x == S3Ticket{} } -func (x S3Ticket) VDLWrite(enc vdl.Encoder) error { - if err := enc.StartValue(__VDLType_struct_11); err != nil { +func (x S3Ticket) VDLWrite(enc vdl.Encoder) error { //nolint:gocyclo + if err := enc.StartValue(vdlTypeStruct13); err != nil { return err } if x.AwsAssumeRoleBuilder != nil { @@ -801,9 +1054,9 @@ func (x S3Ticket) VDLWrite(enc vdl.Encoder) error { return enc.FinishValue() } -func (x *S3Ticket) VDLRead(dec vdl.Decoder) error { +func (x *S3Ticket) VDLRead(dec vdl.Decoder) error { //nolint:gocyclo *x = S3Ticket{} - if err := dec.StartValue(__VDLType_struct_11); err != nil { + if err := dec.StartValue(vdlTypeStruct13); err != nil { return err } decType := dec.Type() @@ -815,8 +1068,8 @@ func (x *S3Ticket) VDLRead(dec vdl.Decoder) error { case index == -1: return dec.FinishValue() } - if decType != __VDLType_struct_11 { - index = __VDLType_struct_11.FieldIndexByName(decType.Field(index).Name) + if decType != vdlTypeStruct13 { + index = 
vdlTypeStruct13.FieldIndexByName(decType.Field(index).Name) if index == -1 { if err := dec.SkipValue(); err != nil { return err @@ -826,7 +1079,7 @@ func (x *S3Ticket) VDLRead(dec vdl.Decoder) error { } switch index { case 0: - if err := dec.StartValue(__VDLType_optional_9); err != nil { + if err := dec.StartValue(vdlTypeOptional11); err != nil { return err } if dec.IsNil() { @@ -842,7 +1095,7 @@ func (x *S3Ticket) VDLRead(dec vdl.Decoder) error { } } case 1: - if err := dec.StartValue(__VDLType_optional_10); err != nil { + if err := dec.StartValue(vdlTypeOptional12); err != nil { return err } if dec.IsNil() { @@ -904,12 +1157,12 @@ func (EcrTicket) VDLReflect(struct { }) { } -func (x EcrTicket) VDLIsZero() bool { +func (x EcrTicket) VDLIsZero() bool { //nolint:gocyclo return x == EcrTicket{} } -func (x EcrTicket) VDLWrite(enc vdl.Encoder) error { - if err := enc.StartValue(__VDLType_struct_12); err != nil { +func (x EcrTicket) VDLWrite(enc vdl.Encoder) error { //nolint:gocyclo + if err := enc.StartValue(vdlTypeStruct14); err != nil { return err } if x.AwsAssumeRoleBuilder != nil { @@ -942,9 +1195,9 @@ func (x EcrTicket) VDLWrite(enc vdl.Encoder) error { return enc.FinishValue() } -func (x *EcrTicket) VDLRead(dec vdl.Decoder) error { +func (x *EcrTicket) VDLRead(dec vdl.Decoder) error { //nolint:gocyclo *x = EcrTicket{} - if err := dec.StartValue(__VDLType_struct_12); err != nil { + if err := dec.StartValue(vdlTypeStruct14); err != nil { return err } decType := dec.Type() @@ -956,8 +1209,8 @@ func (x *EcrTicket) VDLRead(dec vdl.Decoder) error { case index == -1: return dec.FinishValue() } - if decType != __VDLType_struct_12 { - index = __VDLType_struct_12.FieldIndexByName(decType.Field(index).Name) + if decType != vdlTypeStruct14 { + index = vdlTypeStruct14.FieldIndexByName(decType.Field(index).Name) if index == -1 { if err := dec.SkipValue(); err != nil { return err @@ -967,7 +1220,7 @@ func (x *EcrTicket) VDLRead(dec vdl.Decoder) error { } switch index { case 0: 
- if err := dec.StartValue(__VDLType_optional_9); err != nil { + if err := dec.StartValue(vdlTypeOptional11); err != nil { return err } if dec.IsNil() { @@ -1007,7 +1260,72 @@ func (x *EcrTicket) VDLRead(dec vdl.Decoder) error { } } -// TlsCredentials describes a generic set of TLS credentials that include: +// SshCert describes a ssh public Certifcate +type SshCert struct { + // ssh-encoded certificate (host or user). + Cert string +} + +func (SshCert) VDLReflect(struct { + Name string `vdl:"github.com/grailbio/base/security/ticket.SshCert"` +}) { +} + +func (x SshCert) VDLIsZero() bool { //nolint:gocyclo + return x == SshCert{} +} + +func (x SshCert) VDLWrite(enc vdl.Encoder) error { //nolint:gocyclo + if err := enc.StartValue(vdlTypeStruct15); err != nil { + return err + } + if x.Cert != "" { + if err := enc.NextFieldValueString(0, vdl.StringType, x.Cert); err != nil { + return err + } + } + if err := enc.NextField(-1); err != nil { + return err + } + return enc.FinishValue() +} + +func (x *SshCert) VDLRead(dec vdl.Decoder) error { //nolint:gocyclo + *x = SshCert{} + if err := dec.StartValue(vdlTypeStruct15); err != nil { + return err + } + decType := dec.Type() + for { + index, err := dec.NextField() + switch { + case err != nil: + return err + case index == -1: + return dec.FinishValue() + } + if decType != vdlTypeStruct15 { + index = vdlTypeStruct15.FieldIndexByName(decType.Field(index).Name) + if index == -1 { + if err := dec.SkipValue(); err != nil { + return err + } + continue + } + } + if index == 0 { + + switch value, err := dec.ReadValueString(); { + case err != nil: + return err + default: + x.Cert = value + } + } + } +} + +// TlsCredentials describes a generic set of Tls credentials that include: // the CA that accepted by the client (only peers that present a certificate // sign by this CA are accepted), the client certificate and the client // private key. 
@@ -1025,12 +1343,12 @@ func (TlsCredentials) VDLReflect(struct { }) { } -func (x TlsCredentials) VDLIsZero() bool { +func (x TlsCredentials) VDLIsZero() bool { //nolint:gocyclo return x == TlsCredentials{} } -func (x TlsCredentials) VDLWrite(enc vdl.Encoder) error { - if err := enc.StartValue(__VDLType_struct_13); err != nil { +func (x TlsCredentials) VDLWrite(enc vdl.Encoder) error { //nolint:gocyclo + if err := enc.StartValue(vdlTypeStruct16); err != nil { return err } if x.AuthorityCert != "" { @@ -1054,9 +1372,9 @@ func (x TlsCredentials) VDLWrite(enc vdl.Encoder) error { return enc.FinishValue() } -func (x *TlsCredentials) VDLRead(dec vdl.Decoder) error { +func (x *TlsCredentials) VDLRead(dec vdl.Decoder) error { //nolint:gocyclo *x = TlsCredentials{} - if err := dec.StartValue(__VDLType_struct_13); err != nil { + if err := dec.StartValue(vdlTypeStruct16); err != nil { return err } decType := dec.Type() @@ -1068,8 +1386,8 @@ func (x *TlsCredentials) VDLRead(dec vdl.Decoder) error { case index == -1: return dec.FinishValue() } - if decType != __VDLType_struct_13 { - index = __VDLType_struct_13.FieldIndexByName(decType.Field(index).Name) + if decType != vdlTypeStruct16 { + index = vdlTypeStruct16.FieldIndexByName(decType.Field(index).Name) if index == -1 { if err := dec.SkipValue(); err != nil { return err @@ -1114,12 +1432,12 @@ func (TlsServerTicket) VDLReflect(struct { }) { } -func (x TlsServerTicket) VDLIsZero() bool { +func (x TlsServerTicket) VDLIsZero() bool { //nolint:gocyclo return x == TlsServerTicket{} } -func (x TlsServerTicket) VDLWrite(enc vdl.Encoder) error { - if err := enc.StartValue(__VDLType_struct_14); err != nil { +func (x TlsServerTicket) VDLWrite(enc vdl.Encoder) error { //nolint:gocyclo + if err := enc.StartValue(vdlTypeStruct17); err != nil { return err } if x.TlsCertAuthorityBuilder != nil { @@ -1145,9 +1463,9 @@ func (x TlsServerTicket) VDLWrite(enc vdl.Encoder) error { return enc.FinishValue() } -func (x *TlsServerTicket) VDLRead(dec 
vdl.Decoder) error { +func (x *TlsServerTicket) VDLRead(dec vdl.Decoder) error { //nolint:gocyclo *x = TlsServerTicket{} - if err := dec.StartValue(__VDLType_struct_14); err != nil { + if err := dec.StartValue(vdlTypeStruct17); err != nil { return err } decType := dec.Type() @@ -1159,8 +1477,8 @@ func (x *TlsServerTicket) VDLRead(dec vdl.Decoder) error { case index == -1: return dec.FinishValue() } - if decType != __VDLType_struct_14 { - index = __VDLType_struct_14.FieldIndexByName(decType.Field(index).Name) + if decType != vdlTypeStruct17 { + index = vdlTypeStruct17.FieldIndexByName(decType.Field(index).Name) if index == -1 { if err := dec.SkipValue(); err != nil { return err @@ -1170,7 +1488,7 @@ func (x *TlsServerTicket) VDLRead(dec vdl.Decoder) error { } switch index { case 0: - if err := dec.StartValue(__VDLType_optional_15); err != nil { + if err := dec.StartValue(vdlTypeOptional18); err != nil { return err } if dec.IsNil() { @@ -1197,7 +1515,8 @@ func (x *TlsServerTicket) VDLRead(dec vdl.Decoder) error { type TlsClientTicket struct { TlsCertAuthorityBuilder *TlsCertAuthorityBuilder Credentials TlsCredentials - Endpoints []string + // Endpoints indicate the servers the client can connect to. 
+ Endpoints []string } func (TlsClientTicket) VDLReflect(struct { @@ -1205,7 +1524,7 @@ func (TlsClientTicket) VDLReflect(struct { }) { } -func (x TlsClientTicket) VDLIsZero() bool { +func (x TlsClientTicket) VDLIsZero() bool { //nolint:gocyclo if x.TlsCertAuthorityBuilder != nil { return false } @@ -1218,8 +1537,8 @@ func (x TlsClientTicket) VDLIsZero() bool { return true } -func (x TlsClientTicket) VDLWrite(enc vdl.Encoder) error { - if err := enc.StartValue(__VDLType_struct_16); err != nil { +func (x TlsClientTicket) VDLWrite(enc vdl.Encoder) error { //nolint:gocyclo + if err := enc.StartValue(vdlTypeStruct19); err != nil { return err } if x.TlsCertAuthorityBuilder != nil { @@ -1243,7 +1562,7 @@ func (x TlsClientTicket) VDLWrite(enc vdl.Encoder) error { if err := enc.NextField(2); err != nil { return err } - if err := __VDLWriteAnon_list_1(enc, x.Endpoints); err != nil { + if err := vdlWriteAnonList1(enc, x.Endpoints); err != nil { return err } } @@ -1253,9 +1572,9 @@ func (x TlsClientTicket) VDLWrite(enc vdl.Encoder) error { return enc.FinishValue() } -func (x *TlsClientTicket) VDLRead(dec vdl.Decoder) error { +func (x *TlsClientTicket) VDLRead(dec vdl.Decoder) error { //nolint:gocyclo *x = TlsClientTicket{} - if err := dec.StartValue(__VDLType_struct_16); err != nil { + if err := dec.StartValue(vdlTypeStruct19); err != nil { return err } decType := dec.Type() @@ -1267,8 +1586,8 @@ func (x *TlsClientTicket) VDLRead(dec vdl.Decoder) error { case index == -1: return dec.FinishValue() } - if decType != __VDLType_struct_16 { - index = __VDLType_struct_16.FieldIndexByName(decType.Field(index).Name) + if decType != vdlTypeStruct19 { + index = vdlTypeStruct19.FieldIndexByName(decType.Field(index).Name) if index == -1 { if err := dec.SkipValue(); err != nil { return err @@ -1278,7 +1597,7 @@ func (x *TlsClientTicket) VDLRead(dec vdl.Decoder) error { } switch index { case 0: - if err := dec.StartValue(__VDLType_optional_15); err != nil { + if err := 
dec.StartValue(vdlTypeOptional18); err != nil { return err } if dec.IsNil() { @@ -1298,7 +1617,7 @@ func (x *TlsClientTicket) VDLRead(dec vdl.Decoder) error { return err } case 2: - if err := __VDLReadAnon_list_1(dec, &x.Endpoints); err != nil { + if err := vdlReadAnonList1(dec, &x.Endpoints); err != nil { return err } } @@ -1320,12 +1639,12 @@ func (DockerTicket) VDLReflect(struct { }) { } -func (x DockerTicket) VDLIsZero() bool { +func (x DockerTicket) VDLIsZero() bool { //nolint:gocyclo return x == DockerTicket{} } -func (x DockerTicket) VDLWrite(enc vdl.Encoder) error { - if err := enc.StartValue(__VDLType_struct_17); err != nil { +func (x DockerTicket) VDLWrite(enc vdl.Encoder) error { //nolint:gocyclo + if err := enc.StartValue(vdlTypeStruct20); err != nil { return err } if x.TlsCertAuthorityBuilder != nil { @@ -1356,9 +1675,9 @@ func (x DockerTicket) VDLWrite(enc vdl.Encoder) error { return enc.FinishValue() } -func (x *DockerTicket) VDLRead(dec vdl.Decoder) error { +func (x *DockerTicket) VDLRead(dec vdl.Decoder) error { //nolint:gocyclo *x = DockerTicket{} - if err := dec.StartValue(__VDLType_struct_17); err != nil { + if err := dec.StartValue(vdlTypeStruct20); err != nil { return err } decType := dec.Type() @@ -1370,8 +1689,8 @@ func (x *DockerTicket) VDLRead(dec vdl.Decoder) error { case index == -1: return dec.FinishValue() } - if decType != __VDLType_struct_17 { - index = __VDLType_struct_17.FieldIndexByName(decType.Field(index).Name) + if decType != vdlTypeStruct20 { + index = vdlTypeStruct20.FieldIndexByName(decType.Field(index).Name) if index == -1 { if err := dec.SkipValue(); err != nil { return err @@ -1381,7 +1700,7 @@ func (x *DockerTicket) VDLRead(dec vdl.Decoder) error { } switch index { case 0: - if err := dec.StartValue(__VDLType_optional_15); err != nil { + if err := dec.StartValue(vdlTypeOptional18); err != nil { return err } if dec.IsNil() { @@ -1423,12 +1742,12 @@ func (DockerServerTicket) VDLReflect(struct { }) { } -func (x 
DockerServerTicket) VDLIsZero() bool { +func (x DockerServerTicket) VDLIsZero() bool { //nolint:gocyclo return x == DockerServerTicket{} } -func (x DockerServerTicket) VDLWrite(enc vdl.Encoder) error { - if err := enc.StartValue(__VDLType_struct_18); err != nil { +func (x DockerServerTicket) VDLWrite(enc vdl.Encoder) error { //nolint:gocyclo + if err := enc.StartValue(vdlTypeStruct21); err != nil { return err } if x.TlsCertAuthorityBuilder != nil { @@ -1454,9 +1773,9 @@ func (x DockerServerTicket) VDLWrite(enc vdl.Encoder) error { return enc.FinishValue() } -func (x *DockerServerTicket) VDLRead(dec vdl.Decoder) error { +func (x *DockerServerTicket) VDLRead(dec vdl.Decoder) error { //nolint:gocyclo *x = DockerServerTicket{} - if err := dec.StartValue(__VDLType_struct_18); err != nil { + if err := dec.StartValue(vdlTypeStruct21); err != nil { return err } decType := dec.Type() @@ -1468,8 +1787,8 @@ func (x *DockerServerTicket) VDLRead(dec vdl.Decoder) error { case index == -1: return dec.FinishValue() } - if decType != __VDLType_struct_18 { - index = __VDLType_struct_18.FieldIndexByName(decType.Field(index).Name) + if decType != vdlTypeStruct21 { + index = vdlTypeStruct21.FieldIndexByName(decType.Field(index).Name) if index == -1 { if err := dec.SkipValue(); err != nil { return err @@ -1479,7 +1798,7 @@ func (x *DockerServerTicket) VDLRead(dec vdl.Decoder) error { } switch index { case 0: - if err := dec.StartValue(__VDLType_optional_15); err != nil { + if err := dec.StartValue(vdlTypeOptional18); err != nil { return err } if dec.IsNil() { @@ -1507,7 +1826,8 @@ func (x *DockerServerTicket) VDLRead(dec vdl.Decoder) error { type DockerClientTicket struct { TlsCertAuthorityBuilder *TlsCertAuthorityBuilder Credentials TlsCredentials - Url string + // Url indicates the Docker host the client can connect to. 
+ Url string } func (DockerClientTicket) VDLReflect(struct { @@ -1515,12 +1835,12 @@ func (DockerClientTicket) VDLReflect(struct { }) { } -func (x DockerClientTicket) VDLIsZero() bool { +func (x DockerClientTicket) VDLIsZero() bool { //nolint:gocyclo return x == DockerClientTicket{} } -func (x DockerClientTicket) VDLWrite(enc vdl.Encoder) error { - if err := enc.StartValue(__VDLType_struct_19); err != nil { +func (x DockerClientTicket) VDLWrite(enc vdl.Encoder) error { //nolint:gocyclo + if err := enc.StartValue(vdlTypeStruct22); err != nil { return err } if x.TlsCertAuthorityBuilder != nil { @@ -1551,9 +1871,9 @@ func (x DockerClientTicket) VDLWrite(enc vdl.Encoder) error { return enc.FinishValue() } -func (x *DockerClientTicket) VDLRead(dec vdl.Decoder) error { +func (x *DockerClientTicket) VDLRead(dec vdl.Decoder) error { //nolint:gocyclo *x = DockerClientTicket{} - if err := dec.StartValue(__VDLType_struct_19); err != nil { + if err := dec.StartValue(vdlTypeStruct22); err != nil { return err } decType := dec.Type() @@ -1565,8 +1885,8 @@ func (x *DockerClientTicket) VDLRead(dec vdl.Decoder) error { case index == -1: return dec.FinishValue() } - if decType != __VDLType_struct_19 { - index = __VDLType_struct_19.FieldIndexByName(decType.Field(index).Name) + if decType != vdlTypeStruct22 { + index = vdlTypeStruct22.FieldIndexByName(decType.Field(index).Name) if index == -1 { if err := dec.SkipValue(); err != nil { return err @@ -1576,7 +1896,7 @@ func (x *DockerClientTicket) VDLRead(dec vdl.Decoder) error { } switch index { case 0: - if err := dec.StartValue(__VDLType_optional_15); err != nil { + if err := dec.StartValue(vdlTypeOptional18); err != nil { return err } if dec.IsNil() { @@ -1606,67 +1926,32 @@ func (x *DockerClientTicket) VDLRead(dec vdl.Decoder) error { } } -// B2Ticket instance contains a B2 account level authorization -// token plus URLs and configuration values for the account. 
-type B2Ticket struct { - B2AccountAuthorizationBuilder *B2AccountAuthorizationBuilder - AccountId string - AuthorizationToken string - ApiUrl string - DownloadUrl string - RecommendedPartSize int64 - AbsoluteMinimumPartSize int64 +// Key/Value pair that can be passed into the GET request. +type Parameter struct { + Key string + Value string } -func (B2Ticket) VDLReflect(struct { - Name string `vdl:"github.com/grailbio/base/security/ticket.B2Ticket"` +func (Parameter) VDLReflect(struct { + Name string `vdl:"github.com/grailbio/base/security/ticket.Parameter"` }) { } -func (x B2Ticket) VDLIsZero() bool { - return x == B2Ticket{} +func (x Parameter) VDLIsZero() bool { //nolint:gocyclo + return x == Parameter{} } -func (x B2Ticket) VDLWrite(enc vdl.Encoder) error { - if err := enc.StartValue(__VDLType_struct_20); err != nil { +func (x Parameter) VDLWrite(enc vdl.Encoder) error { //nolint:gocyclo + if err := enc.StartValue(vdlTypeStruct23); err != nil { return err } - if x.B2AccountAuthorizationBuilder != nil { - if err := enc.NextField(0); err != nil { - return err - } - enc.SetNextStartValueIsOptional() - if err := x.B2AccountAuthorizationBuilder.VDLWrite(enc); err != nil { - return err - } - } - if x.AccountId != "" { - if err := enc.NextFieldValueString(1, vdl.StringType, x.AccountId); err != nil { - return err - } - } - if x.AuthorizationToken != "" { - if err := enc.NextFieldValueString(2, vdl.StringType, x.AuthorizationToken); err != nil { - return err - } - } - if x.ApiUrl != "" { - if err := enc.NextFieldValueString(3, vdl.StringType, x.ApiUrl); err != nil { - return err - } - } - if x.DownloadUrl != "" { - if err := enc.NextFieldValueString(4, vdl.StringType, x.DownloadUrl); err != nil { - return err - } - } - if x.RecommendedPartSize != 0 { - if err := enc.NextFieldValueInt(5, vdl.Int64Type, x.RecommendedPartSize); err != nil { + if x.Key != "" { + if err := enc.NextFieldValueString(0, vdl.StringType, x.Key); err != nil { return err } } - if 
x.AbsoluteMinimumPartSize != 0 { - if err := enc.NextFieldValueInt(6, vdl.Int64Type, x.AbsoluteMinimumPartSize); err != nil { + if x.Value != "" { + if err := enc.NextFieldValueString(1, vdl.StringType, x.Value); err != nil { return err } } @@ -1676,9 +1961,9 @@ func (x B2Ticket) VDLWrite(enc vdl.Encoder) error { return enc.FinishValue() } -func (x *B2Ticket) VDLRead(dec vdl.Decoder) error { - *x = B2Ticket{} - if err := dec.StartValue(__VDLType_struct_20); err != nil { +func (x *Parameter) VDLRead(dec vdl.Decoder) error { //nolint:gocyclo + *x = Parameter{} + if err := dec.StartValue(vdlTypeStruct23); err != nil { return err } decType := dec.Type() @@ -1690,8 +1975,8 @@ func (x *B2Ticket) VDLRead(dec vdl.Decoder) error { case index == -1: return dec.FinishValue() } - if decType != __VDLType_struct_20 { - index = __VDLType_struct_20.FieldIndexByName(decType.Field(index).Name) + if decType != vdlTypeStruct23 { + index = vdlTypeStruct23.FieldIndexByName(decType.Field(index).Name) if index == -1 { if err := dec.SkipValue(); err != nil { return err @@ -1701,17 +1986,598 @@ func (x *B2Ticket) VDLRead(dec vdl.Decoder) error { } switch index { case 0: - if err := dec.StartValue(__VDLType_optional_21); err != nil { + switch value, err := dec.ReadValueString(); { + case err != nil: return err + default: + x.Key = value } - if dec.IsNil() { - x.B2AccountAuthorizationBuilder = nil - if err := dec.FinishValue(); err != nil { - return err - } - } else { - x.B2AccountAuthorizationBuilder = new(B2AccountAuthorizationBuilder) - dec.IgnoreNextStartValue() + case 1: + switch value, err := dec.ReadValueString(); { + case err != nil: + return err + default: + x.Value = value + } + } + } +} + +type AwsComputeInstancesBuilder struct { + // Instance Filters that will produce a list of instance IDs and related information + // https://docs.aws.amazon.com/AWSEC2/latest/APIReference/API_DescribeInstances.html + InstanceFilters []Parameter + // AWS ARN for a role that should be used to 
perform the instance lookups + AwsAccountLookupRole string + // AWS region to use for the lookup + Region string +} + +func (AwsComputeInstancesBuilder) VDLReflect(struct { + Name string `vdl:"github.com/grailbio/base/security/ticket.AwsComputeInstancesBuilder"` +}) { +} + +func (x AwsComputeInstancesBuilder) VDLIsZero() bool { //nolint:gocyclo + if len(x.InstanceFilters) != 0 { + return false + } + if x.AwsAccountLookupRole != "" { + return false + } + if x.Region != "" { + return false + } + return true +} + +func (x AwsComputeInstancesBuilder) VDLWrite(enc vdl.Encoder) error { //nolint:gocyclo + if err := enc.StartValue(vdlTypeStruct24); err != nil { + return err + } + if len(x.InstanceFilters) != 0 { + if err := enc.NextField(0); err != nil { + return err + } + if err := vdlWriteAnonList2(enc, x.InstanceFilters); err != nil { + return err + } + } + if x.AwsAccountLookupRole != "" { + if err := enc.NextFieldValueString(1, vdl.StringType, x.AwsAccountLookupRole); err != nil { + return err + } + } + if x.Region != "" { + if err := enc.NextFieldValueString(2, vdl.StringType, x.Region); err != nil { + return err + } + } + if err := enc.NextField(-1); err != nil { + return err + } + return enc.FinishValue() +} + +func vdlWriteAnonList2(enc vdl.Encoder, x []Parameter) error { + if err := enc.StartValue(vdlTypeList25); err != nil { + return err + } + if err := enc.SetLenHint(len(x)); err != nil { + return err + } + for _, elem := range x { + if err := enc.NextEntry(false); err != nil { + return err + } + if err := elem.VDLWrite(enc); err != nil { + return err + } + } + if err := enc.NextEntry(true); err != nil { + return err + } + return enc.FinishValue() +} + +func (x *AwsComputeInstancesBuilder) VDLRead(dec vdl.Decoder) error { //nolint:gocyclo + *x = AwsComputeInstancesBuilder{} + if err := dec.StartValue(vdlTypeStruct24); err != nil { + return err + } + decType := dec.Type() + for { + index, err := dec.NextField() + switch { + case err != nil: + return err + case 
index == -1: + return dec.FinishValue() + } + if decType != vdlTypeStruct24 { + index = vdlTypeStruct24.FieldIndexByName(decType.Field(index).Name) + if index == -1 { + if err := dec.SkipValue(); err != nil { + return err + } + continue + } + } + switch index { + case 0: + if err := vdlReadAnonList2(dec, &x.InstanceFilters); err != nil { + return err + } + case 1: + switch value, err := dec.ReadValueString(); { + case err != nil: + return err + default: + x.AwsAccountLookupRole = value + } + case 2: + switch value, err := dec.ReadValueString(); { + case err != nil: + return err + default: + x.Region = value + } + } + } +} + +func vdlReadAnonList2(dec vdl.Decoder, x *[]Parameter) error { + if err := dec.StartValue(vdlTypeList25); err != nil { + return err + } + if len := dec.LenHint(); len > 0 { + *x = make([]Parameter, 0, len) + } else { + *x = nil + } + for { + switch done, err := dec.NextEntry(); { + case err != nil: + return err + case done: + return dec.FinishValue() + default: + var elem Parameter + if err := elem.VDLRead(dec); err != nil { + return err + } + *x = append(*x, elem) + } + } +} + +// Simplification of describeInstance data to provide to ticket-server users +type ComputeInstance struct { + PublicIp string + PrivateIp string + InstanceId string + Tags []Parameter +} + +func (ComputeInstance) VDLReflect(struct { + Name string `vdl:"github.com/grailbio/base/security/ticket.ComputeInstance"` +}) { +} + +func (x ComputeInstance) VDLIsZero() bool { //nolint:gocyclo + if x.PublicIp != "" { + return false + } + if x.PrivateIp != "" { + return false + } + if x.InstanceId != "" { + return false + } + if len(x.Tags) != 0 { + return false + } + return true +} + +func (x ComputeInstance) VDLWrite(enc vdl.Encoder) error { //nolint:gocyclo + if err := enc.StartValue(vdlTypeStruct26); err != nil { + return err + } + if x.PublicIp != "" { + if err := enc.NextFieldValueString(0, vdl.StringType, x.PublicIp); err != nil { + return err + } + } + if x.PrivateIp != "" { 
+ if err := enc.NextFieldValueString(1, vdl.StringType, x.PrivateIp); err != nil { + return err + } + } + if x.InstanceId != "" { + if err := enc.NextFieldValueString(2, vdl.StringType, x.InstanceId); err != nil { + return err + } + } + if len(x.Tags) != 0 { + if err := enc.NextField(3); err != nil { + return err + } + if err := vdlWriteAnonList2(enc, x.Tags); err != nil { + return err + } + } + if err := enc.NextField(-1); err != nil { + return err + } + return enc.FinishValue() +} + +func (x *ComputeInstance) VDLRead(dec vdl.Decoder) error { //nolint:gocyclo + *x = ComputeInstance{} + if err := dec.StartValue(vdlTypeStruct26); err != nil { + return err + } + decType := dec.Type() + for { + index, err := dec.NextField() + switch { + case err != nil: + return err + case index == -1: + return dec.FinishValue() + } + if decType != vdlTypeStruct26 { + index = vdlTypeStruct26.FieldIndexByName(decType.Field(index).Name) + if index == -1 { + if err := dec.SkipValue(); err != nil { + return err + } + continue + } + } + switch index { + case 0: + switch value, err := dec.ReadValueString(); { + case err != nil: + return err + default: + x.PublicIp = value + } + case 1: + switch value, err := dec.ReadValueString(); { + case err != nil: + return err + default: + x.PrivateIp = value + } + case 2: + switch value, err := dec.ReadValueString(); { + case err != nil: + return err + default: + x.InstanceId = value + } + case 3: + if err := vdlReadAnonList2(dec, &x.Tags); err != nil { + return err + } + } + } +} + +// SshCertificateTicket describes a SSH Signed Certificate. +// SSH Certificates are essentially a version of TLS certs but they have additional +// optional parameters and can take a public key as part of their signing request. 
+type SshCertificateTicket struct { + SshCertAuthorityBuilder *SshCertAuthorityBuilder + AwsComputeInstancesBuilder *AwsComputeInstancesBuilder + ComputeInstances []ComputeInstance + Credentials SshCert + // Recommended username to use + Username string +} + +func (SshCertificateTicket) VDLReflect(struct { + Name string `vdl:"github.com/grailbio/base/security/ticket.SshCertificateTicket"` +}) { +} + +func (x SshCertificateTicket) VDLIsZero() bool { //nolint:gocyclo + if x.SshCertAuthorityBuilder != nil { + return false + } + if x.AwsComputeInstancesBuilder != nil { + return false + } + if len(x.ComputeInstances) != 0 { + return false + } + if x.Credentials != (SshCert{}) { + return false + } + if x.Username != "" { + return false + } + return true +} + +func (x SshCertificateTicket) VDLWrite(enc vdl.Encoder) error { //nolint:gocyclo + if err := enc.StartValue(vdlTypeStruct27); err != nil { + return err + } + if x.SshCertAuthorityBuilder != nil { + if err := enc.NextField(0); err != nil { + return err + } + enc.SetNextStartValueIsOptional() + if err := x.SshCertAuthorityBuilder.VDLWrite(enc); err != nil { + return err + } + } + if x.AwsComputeInstancesBuilder != nil { + if err := enc.NextField(1); err != nil { + return err + } + enc.SetNextStartValueIsOptional() + if err := x.AwsComputeInstancesBuilder.VDLWrite(enc); err != nil { + return err + } + } + if len(x.ComputeInstances) != 0 { + if err := enc.NextField(2); err != nil { + return err + } + if err := vdlWriteAnonList3(enc, x.ComputeInstances); err != nil { + return err + } + } + if x.Credentials != (SshCert{}) { + if err := enc.NextField(3); err != nil { + return err + } + if err := x.Credentials.VDLWrite(enc); err != nil { + return err + } + } + if x.Username != "" { + if err := enc.NextFieldValueString(4, vdl.StringType, x.Username); err != nil { + return err + } + } + if err := enc.NextField(-1); err != nil { + return err + } + return enc.FinishValue() +} + +func vdlWriteAnonList3(enc vdl.Encoder, x 
[]ComputeInstance) error { + if err := enc.StartValue(vdlTypeList30); err != nil { + return err + } + if err := enc.SetLenHint(len(x)); err != nil { + return err + } + for _, elem := range x { + if err := enc.NextEntry(false); err != nil { + return err + } + if err := elem.VDLWrite(enc); err != nil { + return err + } + } + if err := enc.NextEntry(true); err != nil { + return err + } + return enc.FinishValue() +} + +func (x *SshCertificateTicket) VDLRead(dec vdl.Decoder) error { //nolint:gocyclo + *x = SshCertificateTicket{} + if err := dec.StartValue(vdlTypeStruct27); err != nil { + return err + } + decType := dec.Type() + for { + index, err := dec.NextField() + switch { + case err != nil: + return err + case index == -1: + return dec.FinishValue() + } + if decType != vdlTypeStruct27 { + index = vdlTypeStruct27.FieldIndexByName(decType.Field(index).Name) + if index == -1 { + if err := dec.SkipValue(); err != nil { + return err + } + continue + } + } + switch index { + case 0: + if err := dec.StartValue(vdlTypeOptional28); err != nil { + return err + } + if dec.IsNil() { + x.SshCertAuthorityBuilder = nil + if err := dec.FinishValue(); err != nil { + return err + } + } else { + x.SshCertAuthorityBuilder = new(SshCertAuthorityBuilder) + dec.IgnoreNextStartValue() + if err := x.SshCertAuthorityBuilder.VDLRead(dec); err != nil { + return err + } + } + case 1: + if err := dec.StartValue(vdlTypeOptional29); err != nil { + return err + } + if dec.IsNil() { + x.AwsComputeInstancesBuilder = nil + if err := dec.FinishValue(); err != nil { + return err + } + } else { + x.AwsComputeInstancesBuilder = new(AwsComputeInstancesBuilder) + dec.IgnoreNextStartValue() + if err := x.AwsComputeInstancesBuilder.VDLRead(dec); err != nil { + return err + } + } + case 2: + if err := vdlReadAnonList3(dec, &x.ComputeInstances); err != nil { + return err + } + case 3: + if err := x.Credentials.VDLRead(dec); err != nil { + return err + } + case 4: + switch value, err := dec.ReadValueString(); { 
+ case err != nil: + return err + default: + x.Username = value + } + } + } +} + +func vdlReadAnonList3(dec vdl.Decoder, x *[]ComputeInstance) error { + if err := dec.StartValue(vdlTypeList30); err != nil { + return err + } + if len := dec.LenHint(); len > 0 { + *x = make([]ComputeInstance, 0, len) + } else { + *x = nil + } + for { + switch done, err := dec.NextEntry(); { + case err != nil: + return err + case done: + return dec.FinishValue() + default: + var elem ComputeInstance + if err := elem.VDLRead(dec); err != nil { + return err + } + *x = append(*x, elem) + } + } +} + +// B2Ticket instance contains a B2 account level authorization +// token plus URLs and configuration values for the account. +type B2Ticket struct { + B2AccountAuthorizationBuilder *B2AccountAuthorizationBuilder + AccountId string + AuthorizationToken string + ApiUrl string + DownloadUrl string + RecommendedPartSize int64 + AbsoluteMinimumPartSize int64 +} + +func (B2Ticket) VDLReflect(struct { + Name string `vdl:"github.com/grailbio/base/security/ticket.B2Ticket"` +}) { +} + +func (x B2Ticket) VDLIsZero() bool { //nolint:gocyclo + return x == B2Ticket{} +} + +func (x B2Ticket) VDLWrite(enc vdl.Encoder) error { //nolint:gocyclo + if err := enc.StartValue(vdlTypeStruct31); err != nil { + return err + } + if x.B2AccountAuthorizationBuilder != nil { + if err := enc.NextField(0); err != nil { + return err + } + enc.SetNextStartValueIsOptional() + if err := x.B2AccountAuthorizationBuilder.VDLWrite(enc); err != nil { + return err + } + } + if x.AccountId != "" { + if err := enc.NextFieldValueString(1, vdl.StringType, x.AccountId); err != nil { + return err + } + } + if x.AuthorizationToken != "" { + if err := enc.NextFieldValueString(2, vdl.StringType, x.AuthorizationToken); err != nil { + return err + } + } + if x.ApiUrl != "" { + if err := enc.NextFieldValueString(3, vdl.StringType, x.ApiUrl); err != nil { + return err + } + } + if x.DownloadUrl != "" { + if err := enc.NextFieldValueString(4, 
vdl.StringType, x.DownloadUrl); err != nil { + return err + } + } + if x.RecommendedPartSize != 0 { + if err := enc.NextFieldValueInt(5, vdl.Int64Type, x.RecommendedPartSize); err != nil { + return err + } + } + if x.AbsoluteMinimumPartSize != 0 { + if err := enc.NextFieldValueInt(6, vdl.Int64Type, x.AbsoluteMinimumPartSize); err != nil { + return err + } + } + if err := enc.NextField(-1); err != nil { + return err + } + return enc.FinishValue() +} + +func (x *B2Ticket) VDLRead(dec vdl.Decoder) error { //nolint:gocyclo + *x = B2Ticket{} + if err := dec.StartValue(vdlTypeStruct31); err != nil { + return err + } + decType := dec.Type() + for { + index, err := dec.NextField() + switch { + case err != nil: + return err + case index == -1: + return dec.FinishValue() + } + if decType != vdlTypeStruct31 { + index = vdlTypeStruct31.FieldIndexByName(decType.Field(index).Name) + if index == -1 { + if err := dec.SkipValue(); err != nil { + return err + } + continue + } + } + switch index { + case 0: + if err := dec.StartValue(vdlTypeOptional32); err != nil { + return err + } + if dec.IsNil() { + x.B2AccountAuthorizationBuilder = nil + if err := dec.FinishValue(); err != nil { + return err + } + } else { + x.B2AccountAuthorizationBuilder = new(B2AccountAuthorizationBuilder) + dec.IgnoreNextStartValue() if err := x.B2AccountAuthorizationBuilder.VDLRead(dec); err != nil { return err } @@ -1776,12 +2642,12 @@ func (VanadiumTicket) VDLReflect(struct { }) { } -func (x VanadiumTicket) VDLIsZero() bool { +func (x VanadiumTicket) VDLIsZero() bool { //nolint:gocyclo return x == VanadiumTicket{} } -func (x VanadiumTicket) VDLWrite(enc vdl.Encoder) error { - if err := enc.StartValue(__VDLType_struct_22); err != nil { +func (x VanadiumTicket) VDLWrite(enc vdl.Encoder) error { //nolint:gocyclo + if err := enc.StartValue(vdlTypeStruct33); err != nil { return err } if x.VanadiumBuilder != nil { @@ -1804,9 +2670,9 @@ func (x VanadiumTicket) VDLWrite(enc vdl.Encoder) error { return 
enc.FinishValue() } -func (x *VanadiumTicket) VDLRead(dec vdl.Decoder) error { +func (x *VanadiumTicket) VDLRead(dec vdl.Decoder) error { //nolint:gocyclo *x = VanadiumTicket{} - if err := dec.StartValue(__VDLType_struct_22); err != nil { + if err := dec.StartValue(vdlTypeStruct33); err != nil { return err } decType := dec.Type() @@ -1818,8 +2684,8 @@ func (x *VanadiumTicket) VDLRead(dec vdl.Decoder) error { case index == -1: return dec.FinishValue() } - if decType != __VDLType_struct_22 { - index = __VDLType_struct_22.FieldIndexByName(decType.Field(index).Name) + if decType != vdlTypeStruct33 { + index = vdlTypeStruct33.FieldIndexByName(decType.Field(index).Name) if index == -1 { if err := dec.SkipValue(); err != nil { return err @@ -1829,7 +2695,7 @@ func (x *VanadiumTicket) VDLRead(dec vdl.Decoder) error { } switch index { case 0: - if err := dec.StartValue(__VDLType_optional_23); err != nil { + if err := dec.StartValue(vdlTypeOptional34); err != nil { return err } if dec.IsNil() { @@ -1866,19 +2732,16 @@ func (GenericTicket) VDLReflect(struct { }) { } -func (x GenericTicket) VDLIsZero() bool { - if len(x.Data) != 0 { - return false - } - return true +func (x GenericTicket) VDLIsZero() bool { //nolint:gocyclo + return len(x.Data) == 0 } -func (x GenericTicket) VDLWrite(enc vdl.Encoder) error { - if err := enc.StartValue(__VDLType_struct_24); err != nil { +func (x GenericTicket) VDLWrite(enc vdl.Encoder) error { //nolint:gocyclo + if err := enc.StartValue(vdlTypeStruct35); err != nil { return err } if len(x.Data) != 0 { - if err := enc.NextFieldValueBytes(0, __VDLType_list_25, x.Data); err != nil { + if err := enc.NextFieldValueBytes(0, vdlTypeList36, x.Data); err != nil { return err } } @@ -1888,9 +2751,9 @@ func (x GenericTicket) VDLWrite(enc vdl.Encoder) error { return enc.FinishValue() } -func (x *GenericTicket) VDLRead(dec vdl.Decoder) error { +func (x *GenericTicket) VDLRead(dec vdl.Decoder) error { //nolint:gocyclo *x = GenericTicket{} - if err := 
dec.StartValue(__VDLType_struct_24); err != nil { + if err := dec.StartValue(vdlTypeStruct35); err != nil { return err } decType := dec.Type() @@ -1902,8 +2765,8 @@ func (x *GenericTicket) VDLRead(dec vdl.Decoder) error { case index == -1: return dec.FinishValue() } - if decType != __VDLType_struct_24 { - index = __VDLType_struct_24.FieldIndexByName(decType.Field(index).Name) + if decType != vdlTypeStruct35 { + index = vdlTypeStruct35.FieldIndexByName(decType.Field(index).Name) if index == -1 { if err := dec.SkipValue(); err != nil { return err @@ -1911,8 +2774,8 @@ func (x *GenericTicket) VDLRead(dec vdl.Decoder) error { continue } } - switch index { - case 0: + if index == 0 { + if err := dec.ReadValueBytes(-1, &x.Data); err != nil { return err } @@ -1932,7 +2795,7 @@ type ( // Name returns the field name. Name() string // VDLReflect describes the Ticket union type. - VDLReflect(__TicketReflect) + VDLReflect(vdlTicketReflect) VDLIsZero() bool VDLWrite(vdl.Encoder) error } @@ -1940,6 +2803,8 @@ type ( TicketAwsTicket struct{ Value AwsTicket } // TicketS3Ticket represents field S3Ticket of the Ticket union type. TicketS3Ticket struct{ Value S3Ticket } + // TicketSshCertificateTicket represents field SshCertificateTicket of the Ticket union type. + TicketSshCertificateTicket struct{ Value SshCertificateTicket } // TicketEcrTicket represents field EcrTicket of the Ticket union type. TicketEcrTicket struct{ Value EcrTicket } // TicketTlsServerTicket represents field TlsServerTicket of the Ticket union type. @@ -1958,82 +2823,88 @@ type ( TicketVanadiumTicket struct{ Value VanadiumTicket } // TicketGenericTicket represents field GenericTicket of the Ticket union type. TicketGenericTicket struct{ Value GenericTicket } - // __TicketReflect describes the Ticket union type. - __TicketReflect struct { + // vdlTicketReflect describes the Ticket union type. 
+ vdlTicketReflect struct { Name string `vdl:"github.com/grailbio/base/security/ticket.Ticket"` Type Ticket Union struct { - AwsTicket TicketAwsTicket - S3Ticket TicketS3Ticket - EcrTicket TicketEcrTicket - TlsServerTicket TicketTlsServerTicket - TlsClientTicket TicketTlsClientTicket - DockerTicket TicketDockerTicket - DockerServerTicket TicketDockerServerTicket - DockerClientTicket TicketDockerClientTicket - B2Ticket TicketB2Ticket - VanadiumTicket TicketVanadiumTicket - GenericTicket TicketGenericTicket + AwsTicket TicketAwsTicket + S3Ticket TicketS3Ticket + SshCertificateTicket TicketSshCertificateTicket + EcrTicket TicketEcrTicket + TlsServerTicket TicketTlsServerTicket + TlsClientTicket TicketTlsClientTicket + DockerTicket TicketDockerTicket + DockerServerTicket TicketDockerServerTicket + DockerClientTicket TicketDockerClientTicket + B2Ticket TicketB2Ticket + VanadiumTicket TicketVanadiumTicket + GenericTicket TicketGenericTicket } } ) -func (x TicketAwsTicket) Index() int { return 0 } -func (x TicketAwsTicket) Interface() interface{} { return x.Value } -func (x TicketAwsTicket) Name() string { return "AwsTicket" } -func (x TicketAwsTicket) VDLReflect(__TicketReflect) {} - -func (x TicketS3Ticket) Index() int { return 1 } -func (x TicketS3Ticket) Interface() interface{} { return x.Value } -func (x TicketS3Ticket) Name() string { return "S3Ticket" } -func (x TicketS3Ticket) VDLReflect(__TicketReflect) {} - -func (x TicketEcrTicket) Index() int { return 2 } -func (x TicketEcrTicket) Interface() interface{} { return x.Value } -func (x TicketEcrTicket) Name() string { return "EcrTicket" } -func (x TicketEcrTicket) VDLReflect(__TicketReflect) {} - -func (x TicketTlsServerTicket) Index() int { return 3 } -func (x TicketTlsServerTicket) Interface() interface{} { return x.Value } -func (x TicketTlsServerTicket) Name() string { return "TlsServerTicket" } -func (x TicketTlsServerTicket) VDLReflect(__TicketReflect) {} - -func (x TicketTlsClientTicket) Index() int { 
return 4 } -func (x TicketTlsClientTicket) Interface() interface{} { return x.Value } -func (x TicketTlsClientTicket) Name() string { return "TlsClientTicket" } -func (x TicketTlsClientTicket) VDLReflect(__TicketReflect) {} - -func (x TicketDockerTicket) Index() int { return 5 } -func (x TicketDockerTicket) Interface() interface{} { return x.Value } -func (x TicketDockerTicket) Name() string { return "DockerTicket" } -func (x TicketDockerTicket) VDLReflect(__TicketReflect) {} - -func (x TicketDockerServerTicket) Index() int { return 6 } -func (x TicketDockerServerTicket) Interface() interface{} { return x.Value } -func (x TicketDockerServerTicket) Name() string { return "DockerServerTicket" } -func (x TicketDockerServerTicket) VDLReflect(__TicketReflect) {} - -func (x TicketDockerClientTicket) Index() int { return 7 } -func (x TicketDockerClientTicket) Interface() interface{} { return x.Value } -func (x TicketDockerClientTicket) Name() string { return "DockerClientTicket" } -func (x TicketDockerClientTicket) VDLReflect(__TicketReflect) {} - -func (x TicketB2Ticket) Index() int { return 8 } -func (x TicketB2Ticket) Interface() interface{} { return x.Value } -func (x TicketB2Ticket) Name() string { return "B2Ticket" } -func (x TicketB2Ticket) VDLReflect(__TicketReflect) {} - -func (x TicketVanadiumTicket) Index() int { return 9 } -func (x TicketVanadiumTicket) Interface() interface{} { return x.Value } -func (x TicketVanadiumTicket) Name() string { return "VanadiumTicket" } -func (x TicketVanadiumTicket) VDLReflect(__TicketReflect) {} - -func (x TicketGenericTicket) Index() int { return 10 } -func (x TicketGenericTicket) Interface() interface{} { return x.Value } -func (x TicketGenericTicket) Name() string { return "GenericTicket" } -func (x TicketGenericTicket) VDLReflect(__TicketReflect) {} - -func (x TicketAwsTicket) VDLIsZero() bool { +func (x TicketAwsTicket) Index() int { return 0 } +func (x TicketAwsTicket) Interface() interface{} { return x.Value } +func (x 
TicketAwsTicket) Name() string { return "AwsTicket" } +func (x TicketAwsTicket) VDLReflect(vdlTicketReflect) {} + +func (x TicketS3Ticket) Index() int { return 1 } +func (x TicketS3Ticket) Interface() interface{} { return x.Value } +func (x TicketS3Ticket) Name() string { return "S3Ticket" } +func (x TicketS3Ticket) VDLReflect(vdlTicketReflect) {} + +func (x TicketSshCertificateTicket) Index() int { return 2 } +func (x TicketSshCertificateTicket) Interface() interface{} { return x.Value } +func (x TicketSshCertificateTicket) Name() string { return "SshCertificateTicket" } +func (x TicketSshCertificateTicket) VDLReflect(vdlTicketReflect) {} + +func (x TicketEcrTicket) Index() int { return 3 } +func (x TicketEcrTicket) Interface() interface{} { return x.Value } +func (x TicketEcrTicket) Name() string { return "EcrTicket" } +func (x TicketEcrTicket) VDLReflect(vdlTicketReflect) {} + +func (x TicketTlsServerTicket) Index() int { return 4 } +func (x TicketTlsServerTicket) Interface() interface{} { return x.Value } +func (x TicketTlsServerTicket) Name() string { return "TlsServerTicket" } +func (x TicketTlsServerTicket) VDLReflect(vdlTicketReflect) {} + +func (x TicketTlsClientTicket) Index() int { return 5 } +func (x TicketTlsClientTicket) Interface() interface{} { return x.Value } +func (x TicketTlsClientTicket) Name() string { return "TlsClientTicket" } +func (x TicketTlsClientTicket) VDLReflect(vdlTicketReflect) {} + +func (x TicketDockerTicket) Index() int { return 6 } +func (x TicketDockerTicket) Interface() interface{} { return x.Value } +func (x TicketDockerTicket) Name() string { return "DockerTicket" } +func (x TicketDockerTicket) VDLReflect(vdlTicketReflect) {} + +func (x TicketDockerServerTicket) Index() int { return 7 } +func (x TicketDockerServerTicket) Interface() interface{} { return x.Value } +func (x TicketDockerServerTicket) Name() string { return "DockerServerTicket" } +func (x TicketDockerServerTicket) VDLReflect(vdlTicketReflect) {} + +func (x 
TicketDockerClientTicket) Index() int { return 8 } +func (x TicketDockerClientTicket) Interface() interface{} { return x.Value } +func (x TicketDockerClientTicket) Name() string { return "DockerClientTicket" } +func (x TicketDockerClientTicket) VDLReflect(vdlTicketReflect) {} + +func (x TicketB2Ticket) Index() int { return 9 } +func (x TicketB2Ticket) Interface() interface{} { return x.Value } +func (x TicketB2Ticket) Name() string { return "B2Ticket" } +func (x TicketB2Ticket) VDLReflect(vdlTicketReflect) {} + +func (x TicketVanadiumTicket) Index() int { return 10 } +func (x TicketVanadiumTicket) Interface() interface{} { return x.Value } +func (x TicketVanadiumTicket) Name() string { return "VanadiumTicket" } +func (x TicketVanadiumTicket) VDLReflect(vdlTicketReflect) {} + +func (x TicketGenericTicket) Index() int { return 11 } +func (x TicketGenericTicket) Interface() interface{} { return x.Value } +func (x TicketGenericTicket) Name() string { return "GenericTicket" } +func (x TicketGenericTicket) VDLReflect(vdlTicketReflect) {} + +func (x TicketAwsTicket) VDLIsZero() bool { //nolint:gocyclo return x.Value == AwsTicket{} } @@ -2041,6 +2912,10 @@ func (x TicketS3Ticket) VDLIsZero() bool { return false } +func (x TicketSshCertificateTicket) VDLIsZero() bool { + return false +} + func (x TicketEcrTicket) VDLIsZero() bool { return false } @@ -2077,8 +2952,8 @@ func (x TicketGenericTicket) VDLIsZero() bool { return false } -func (x TicketAwsTicket) VDLWrite(enc vdl.Encoder) error { - if err := enc.StartValue(__VDLType_union_26); err != nil { +func (x TicketAwsTicket) VDLWrite(enc vdl.Encoder) error { //nolint:gocyclo + if err := enc.StartValue(vdlTypeUnion37); err != nil { return err } if err := enc.NextField(0); err != nil { @@ -2093,8 +2968,8 @@ func (x TicketAwsTicket) VDLWrite(enc vdl.Encoder) error { return enc.FinishValue() } -func (x TicketS3Ticket) VDLWrite(enc vdl.Encoder) error { - if err := enc.StartValue(__VDLType_union_26); err != nil { +func (x 
TicketS3Ticket) VDLWrite(enc vdl.Encoder) error { //nolint:gocyclo + if err := enc.StartValue(vdlTypeUnion37); err != nil { return err } if err := enc.NextField(1); err != nil { @@ -2109,8 +2984,8 @@ func (x TicketS3Ticket) VDLWrite(enc vdl.Encoder) error { return enc.FinishValue() } -func (x TicketEcrTicket) VDLWrite(enc vdl.Encoder) error { - if err := enc.StartValue(__VDLType_union_26); err != nil { +func (x TicketSshCertificateTicket) VDLWrite(enc vdl.Encoder) error { //nolint:gocyclo + if err := enc.StartValue(vdlTypeUnion37); err != nil { return err } if err := enc.NextField(2); err != nil { @@ -2125,8 +3000,8 @@ func (x TicketEcrTicket) VDLWrite(enc vdl.Encoder) error { return enc.FinishValue() } -func (x TicketTlsServerTicket) VDLWrite(enc vdl.Encoder) error { - if err := enc.StartValue(__VDLType_union_26); err != nil { +func (x TicketEcrTicket) VDLWrite(enc vdl.Encoder) error { //nolint:gocyclo + if err := enc.StartValue(vdlTypeUnion37); err != nil { return err } if err := enc.NextField(3); err != nil { @@ -2141,8 +3016,8 @@ func (x TicketTlsServerTicket) VDLWrite(enc vdl.Encoder) error { return enc.FinishValue() } -func (x TicketTlsClientTicket) VDLWrite(enc vdl.Encoder) error { - if err := enc.StartValue(__VDLType_union_26); err != nil { +func (x TicketTlsServerTicket) VDLWrite(enc vdl.Encoder) error { //nolint:gocyclo + if err := enc.StartValue(vdlTypeUnion37); err != nil { return err } if err := enc.NextField(4); err != nil { @@ -2157,8 +3032,8 @@ func (x TicketTlsClientTicket) VDLWrite(enc vdl.Encoder) error { return enc.FinishValue() } -func (x TicketDockerTicket) VDLWrite(enc vdl.Encoder) error { - if err := enc.StartValue(__VDLType_union_26); err != nil { +func (x TicketTlsClientTicket) VDLWrite(enc vdl.Encoder) error { //nolint:gocyclo + if err := enc.StartValue(vdlTypeUnion37); err != nil { return err } if err := enc.NextField(5); err != nil { @@ -2173,8 +3048,8 @@ func (x TicketDockerTicket) VDLWrite(enc vdl.Encoder) error { return 
enc.FinishValue() } -func (x TicketDockerServerTicket) VDLWrite(enc vdl.Encoder) error { - if err := enc.StartValue(__VDLType_union_26); err != nil { +func (x TicketDockerTicket) VDLWrite(enc vdl.Encoder) error { //nolint:gocyclo + if err := enc.StartValue(vdlTypeUnion37); err != nil { return err } if err := enc.NextField(6); err != nil { @@ -2189,8 +3064,8 @@ func (x TicketDockerServerTicket) VDLWrite(enc vdl.Encoder) error { return enc.FinishValue() } -func (x TicketDockerClientTicket) VDLWrite(enc vdl.Encoder) error { - if err := enc.StartValue(__VDLType_union_26); err != nil { +func (x TicketDockerServerTicket) VDLWrite(enc vdl.Encoder) error { //nolint:gocyclo + if err := enc.StartValue(vdlTypeUnion37); err != nil { return err } if err := enc.NextField(7); err != nil { @@ -2205,8 +3080,8 @@ func (x TicketDockerClientTicket) VDLWrite(enc vdl.Encoder) error { return enc.FinishValue() } -func (x TicketB2Ticket) VDLWrite(enc vdl.Encoder) error { - if err := enc.StartValue(__VDLType_union_26); err != nil { +func (x TicketDockerClientTicket) VDLWrite(enc vdl.Encoder) error { //nolint:gocyclo + if err := enc.StartValue(vdlTypeUnion37); err != nil { return err } if err := enc.NextField(8); err != nil { @@ -2221,8 +3096,8 @@ func (x TicketB2Ticket) VDLWrite(enc vdl.Encoder) error { return enc.FinishValue() } -func (x TicketVanadiumTicket) VDLWrite(enc vdl.Encoder) error { - if err := enc.StartValue(__VDLType_union_26); err != nil { +func (x TicketB2Ticket) VDLWrite(enc vdl.Encoder) error { //nolint:gocyclo + if err := enc.StartValue(vdlTypeUnion37); err != nil { return err } if err := enc.NextField(9); err != nil { @@ -2237,8 +3112,8 @@ func (x TicketVanadiumTicket) VDLWrite(enc vdl.Encoder) error { return enc.FinishValue() } -func (x TicketGenericTicket) VDLWrite(enc vdl.Encoder) error { - if err := enc.StartValue(__VDLType_union_26); err != nil { +func (x TicketVanadiumTicket) VDLWrite(enc vdl.Encoder) error { //nolint:gocyclo + if err := 
enc.StartValue(vdlTypeUnion37); err != nil { return err } if err := enc.NextField(10); err != nil { @@ -2253,8 +3128,24 @@ func (x TicketGenericTicket) VDLWrite(enc vdl.Encoder) error { return enc.FinishValue() } -func VDLReadTicket(dec vdl.Decoder, x *Ticket) error { - if err := dec.StartValue(__VDLType_union_26); err != nil { +func (x TicketGenericTicket) VDLWrite(enc vdl.Encoder) error { //nolint:gocyclo + if err := enc.StartValue(vdlTypeUnion37); err != nil { + return err + } + if err := enc.NextField(11); err != nil { + return err + } + if err := x.Value.VDLWrite(enc); err != nil { + return err + } + if err := enc.NextField(-1); err != nil { + return err + } + return enc.FinishValue() +} + +func VDLReadTicket(dec vdl.Decoder, x *Ticket) error { //nolint:gocyclo + if err := dec.StartValue(vdlTypeUnion37); err != nil { return err } decType := dec.Type() @@ -2265,9 +3156,9 @@ func VDLReadTicket(dec vdl.Decoder, x *Ticket) error { case index == -1: return fmt.Errorf("missing field in union %T, from %v", x, decType) } - if decType != __VDLType_union_26 { + if decType != vdlTypeUnion37 { name := decType.Field(index).Name - index = __VDLType_union_26.FieldIndexByName(name) + index = vdlTypeUnion37.FieldIndexByName(name) if index == -1 { return fmt.Errorf("field %q not in union %T, from %v", name, x, decType) } @@ -2286,54 +3177,60 @@ func VDLReadTicket(dec vdl.Decoder, x *Ticket) error { } *x = field case 2: - var field TicketEcrTicket + var field TicketSshCertificateTicket if err := field.Value.VDLRead(dec); err != nil { return err } *x = field case 3: - var field TicketTlsServerTicket + var field TicketEcrTicket if err := field.Value.VDLRead(dec); err != nil { return err } *x = field case 4: - var field TicketTlsClientTicket + var field TicketTlsServerTicket if err := field.Value.VDLRead(dec); err != nil { return err } *x = field case 5: - var field TicketDockerTicket + var field TicketTlsClientTicket if err := field.Value.VDLRead(dec); err != nil { return err } *x 
= field case 6: - var field TicketDockerServerTicket + var field TicketDockerTicket if err := field.Value.VDLRead(dec); err != nil { return err } *x = field case 7: - var field TicketDockerClientTicket + var field TicketDockerServerTicket if err := field.Value.VDLRead(dec); err != nil { return err } *x = field case 8: - var field TicketB2Ticket + var field TicketDockerClientTicket if err := field.Value.VDLRead(dec); err != nil { return err } *x = field case 9: - var field TicketVanadiumTicket + var field TicketB2Ticket if err := field.Value.VDLRead(dec); err != nil { return err } *x = field case 10: + var field TicketVanadiumTicket + if err := field.Value.VDLRead(dec); err != nil { + return err + } + *x = field + case 11: var field TicketGenericTicket if err := field.Value.VDLRead(dec); err != nil { return err @@ -2353,6 +3250,7 @@ func VDLReadTicket(dec vdl.Decoder, x *Ticket) error { type TicketConfig struct { Ticket Ticket Permissions access.Permissions + Controls map[Control]bool } func (TicketConfig) VDLReflect(struct { @@ -2360,18 +3258,21 @@ func (TicketConfig) VDLReflect(struct { }) { } -func (x TicketConfig) VDLIsZero() bool { +func (x TicketConfig) VDLIsZero() bool { //nolint:gocyclo if x.Ticket != nil && !x.Ticket.VDLIsZero() { return false } if len(x.Permissions) != 0 { return false } + if len(x.Controls) != 0 { + return false + } return true } -func (x TicketConfig) VDLWrite(enc vdl.Encoder) error { - if err := enc.StartValue(__VDLType_struct_27); err != nil { +func (x TicketConfig) VDLWrite(enc vdl.Encoder) error { //nolint:gocyclo + if err := enc.StartValue(vdlTypeStruct38); err != nil { return err } if x.Ticket != nil && !x.Ticket.VDLIsZero() { @@ -2390,17 +3291,46 @@ func (x TicketConfig) VDLWrite(enc vdl.Encoder) error { return err } } + if len(x.Controls) != 0 { + if err := enc.NextField(2); err != nil { + return err + } + if err := vdlWriteAnonMap4(enc, x.Controls); err != nil { + return err + } + } if err := enc.NextField(-1); err != nil { 
return err } return enc.FinishValue() } -func (x *TicketConfig) VDLRead(dec vdl.Decoder) error { +func vdlWriteAnonMap4(enc vdl.Encoder, x map[Control]bool) error { + if err := enc.StartValue(vdlTypeMap40); err != nil { + return err + } + if err := enc.SetLenHint(len(x)); err != nil { + return err + } + for key, elem := range x { + if err := enc.NextEntryValueString(vdlTypeEnum1, key.String()); err != nil { + return err + } + if err := enc.WriteValueBool(vdl.BoolType, elem); err != nil { + return err + } + } + if err := enc.NextEntry(true); err != nil { + return err + } + return enc.FinishValue() +} + +func (x *TicketConfig) VDLRead(dec vdl.Decoder) error { //nolint:gocyclo *x = TicketConfig{ Ticket: TicketAwsTicket{}, } - if err := dec.StartValue(__VDLType_struct_27); err != nil { + if err := dec.StartValue(vdlTypeStruct38); err != nil { return err } decType := dec.Type() @@ -2412,8 +3342,8 @@ func (x *TicketConfig) VDLRead(dec vdl.Decoder) error { case index == -1: return dec.FinishValue() } - if decType != __VDLType_struct_27 { - index = __VDLType_struct_27.FieldIndexByName(decType.Field(index).Name) + if decType != vdlTypeStruct38 { + index = vdlTypeStruct38.FieldIndexByName(decType.Field(index).Name) if index == -1 { if err := dec.SkipValue(); err != nil { return err @@ -2430,6 +3360,45 @@ func (x *TicketConfig) VDLRead(dec vdl.Decoder) error { if err := x.Permissions.VDLRead(dec); err != nil { return err } + case 2: + if err := vdlReadAnonMap4(dec, &x.Controls); err != nil { + return err + } + } + } +} + +func vdlReadAnonMap4(dec vdl.Decoder, x *map[Control]bool) error { + if err := dec.StartValue(vdlTypeMap40); err != nil { + return err + } + var tmpMap map[Control]bool + if len := dec.LenHint(); len > 0 { + tmpMap = make(map[Control]bool, len) + } + for { + switch done, key, err := dec.NextEntryValueString(); { + case err != nil: + return err + case done: + *x = tmpMap + return dec.FinishValue() + default: + var keyEnum Control + if err := keyEnum.Set(key); 
err != nil { + return err + } + var elem bool + switch value, err := dec.ReadValueBool(); { + case err != nil: + return err + default: + elem = value + } + if tmpMap == nil { + tmpMap = make(map[Control]bool) + } + tmpMap[keyEnum] = elem } } } @@ -2444,7 +3413,7 @@ func (Config) VDLReflect(struct { }) { } -func (x Config) VDLIsZero() bool { +func (x Config) VDLIsZero() bool { //nolint:gocyclo if len(x.Tickets) != 0 { return false } @@ -2454,15 +3423,15 @@ func (x Config) VDLIsZero() bool { return true } -func (x Config) VDLWrite(enc vdl.Encoder) error { - if err := enc.StartValue(__VDLType_struct_29); err != nil { +func (x Config) VDLWrite(enc vdl.Encoder) error { //nolint:gocyclo + if err := enc.StartValue(vdlTypeStruct41); err != nil { return err } if len(x.Tickets) != 0 { if err := enc.NextField(0); err != nil { return err } - if err := __VDLWriteAnon_map_2(enc, x.Tickets); err != nil { + if err := vdlWriteAnonMap5(enc, x.Tickets); err != nil { return err } } @@ -2480,8 +3449,8 @@ func (x Config) VDLWrite(enc vdl.Encoder) error { return enc.FinishValue() } -func __VDLWriteAnon_map_2(enc vdl.Encoder, x map[string]TicketConfig) error { - if err := enc.StartValue(__VDLType_map_30); err != nil { +func vdlWriteAnonMap5(enc vdl.Encoder, x map[string]TicketConfig) error { + if err := enc.StartValue(vdlTypeMap42); err != nil { return err } if err := enc.SetLenHint(len(x)); err != nil { @@ -2501,9 +3470,9 @@ func __VDLWriteAnon_map_2(enc vdl.Encoder, x map[string]TicketConfig) error { return enc.FinishValue() } -func (x *Config) VDLRead(dec vdl.Decoder) error { +func (x *Config) VDLRead(dec vdl.Decoder) error { //nolint:gocyclo *x = Config{} - if err := dec.StartValue(__VDLType_struct_29); err != nil { + if err := dec.StartValue(vdlTypeStruct41); err != nil { return err } decType := dec.Type() @@ -2515,8 +3484,8 @@ func (x *Config) VDLRead(dec vdl.Decoder) error { case index == -1: return dec.FinishValue() } - if decType != __VDLType_struct_29 { - index = 
__VDLType_struct_29.FieldIndexByName(decType.Field(index).Name) + if decType != vdlTypeStruct41 { + index = vdlTypeStruct41.FieldIndexByName(decType.Field(index).Name) if index == -1 { if err := dec.SkipValue(); err != nil { return err @@ -2526,7 +3495,7 @@ func (x *Config) VDLRead(dec vdl.Decoder) error { } switch index { case 0: - if err := __VDLReadAnon_map_2(dec, &x.Tickets); err != nil { + if err := vdlReadAnonMap5(dec, &x.Tickets); err != nil { return err } case 1: @@ -2537,8 +3506,8 @@ func (x *Config) VDLRead(dec vdl.Decoder) error { } } -func __VDLReadAnon_map_2(dec vdl.Decoder, x *map[string]TicketConfig) error { - if err := dec.StartValue(__VDLType_map_30); err != nil { +func vdlReadAnonMap5(dec vdl.Decoder, x *map[string]TicketConfig) error { + if err := dec.StartValue(vdlTypeMap42); err != nil { return err } var tmpMap map[string]TicketConfig @@ -2565,8 +3534,8 @@ func __VDLReadAnon_map_2(dec vdl.Decoder, x *map[string]TicketConfig) error { } } -////////////////////////////////////////////////// // Interface definitions +// ===================== // TicketServiceClientMethods is the client interface // containing TicketService methods. @@ -2577,12 +3546,14 @@ type TicketServiceClientMethods interface { GetPermissions(*context.T, ...rpc.CallOpt) (perms access.Permissions, version string, _ error) SetPermissions(_ *context.T, perms access.Permissions, version string, _ ...rpc.CallOpt) error Get(*context.T, ...rpc.CallOpt) (Ticket, error) + GetWithParameters(_ *context.T, parameters []Parameter, _ ...rpc.CallOpt) (Ticket, error) + GetWithArgs(_ *context.T, args map[string]string, _ ...rpc.CallOpt) (Ticket, error) } -// TicketServiceClientStub adds universal methods to TicketServiceClientMethods. +// TicketServiceClientStub embeds TicketServiceClientMethods and is a +// placeholder for additional management operations. 
type TicketServiceClientStub interface { TicketServiceClientMethods - rpc.UniversalServiceMethods } // TicketServiceClient returns a client stub for TicketService. @@ -2609,6 +3580,16 @@ func (c implTicketServiceClientStub) Get(ctx *context.T, opts ...rpc.CallOpt) (o return } +func (c implTicketServiceClientStub) GetWithParameters(ctx *context.T, i0 []Parameter, opts ...rpc.CallOpt) (o0 Ticket, err error) { + err = v23.GetClient(ctx).Call(ctx, c.name, "GetWithParameters", []interface{}{i0}, []interface{}{&o0}, opts...) + return +} + +func (c implTicketServiceClientStub) GetWithArgs(ctx *context.T, i0 map[string]string, opts ...rpc.CallOpt) (o0 Ticket, err error) { + err = v23.GetClient(ctx).Call(ctx, c.name, "GetWithArgs", []interface{}{i0}, []interface{}{&o0}, opts...) + return +} + // TicketServiceServerMethods is the interface a server writer // implements for TicketService. // @@ -2618,6 +3599,8 @@ type TicketServiceServerMethods interface { GetPermissions(*context.T, rpc.ServerCall) (perms access.Permissions, version string, _ error) SetPermissions(_ *context.T, _ rpc.ServerCall, perms access.Permissions, version string) error Get(*context.T, rpc.ServerCall) (Ticket, error) + GetWithParameters(_ *context.T, _ rpc.ServerCall, parameters []Parameter) (Ticket, error) + GetWithArgs(_ *context.T, _ rpc.ServerCall, args map[string]string) (Ticket, error) } // TicketServiceServerStubMethods is the server interface containing @@ -2629,7 +3612,7 @@ type TicketServiceServerStubMethods TicketServiceServerMethods // TicketServiceServerStub adds universal methods to TicketServiceServerStubMethods. type TicketServiceServerStub interface { TicketServiceServerStubMethods - // Describe the TicketService interfaces. + // DescribeInterfaces the TicketService interfaces. 
Describe__() []rpc.InterfaceDesc } @@ -2667,6 +3650,14 @@ func (s implTicketServiceServerStub) Get(ctx *context.T, call rpc.ServerCall) (T return s.impl.Get(ctx, call) } +func (s implTicketServiceServerStub) GetWithParameters(ctx *context.T, call rpc.ServerCall, i0 []Parameter) (Ticket, error) { + return s.impl.GetWithParameters(ctx, call, i0) +} + +func (s implTicketServiceServerStub) GetWithArgs(ctx *context.T, call rpc.ServerCall, i0 map[string]string) (Ticket, error) { + return s.impl.GetWithArgs(ctx, call, i0) +} + func (s implTicketServiceServerStub) Globber() *rpc.GlobState { return s.gs } @@ -2687,23 +3678,140 @@ var descTicketService = rpc.InterfaceDesc{ { Name: "GetPermissions", OutArgs: []rpc.ArgDesc{ - {"perms", ``}, // access.Permissions - {"version", ``}, // string + {Name: "perms", Doc: ``}, // access.Permissions + {Name: "version", Doc: ``}, // string }, Tags: []*vdl.Value{vdl.ValueOf(access.Tag("Read"))}, }, { Name: "SetPermissions", InArgs: []rpc.ArgDesc{ - {"perms", ``}, // access.Permissions - {"version", ``}, // string + {Name: "perms", Doc: ``}, // access.Permissions + {Name: "version", Doc: ``}, // string }, Tags: []*vdl.Value{vdl.ValueOf(access.Tag("Admin"))}, }, { Name: "Get", OutArgs: []rpc.ArgDesc{ - {"", ``}, // Ticket + {Name: "", Doc: ``}, // Ticket + }, + Tags: []*vdl.Value{vdl.ValueOf(access.Tag("Read"))}, + }, + { + Name: "GetWithParameters", + InArgs: []rpc.ArgDesc{ + {Name: "parameters", Doc: ``}, // []Parameter + }, + OutArgs: []rpc.ArgDesc{ + {Name: "", Doc: ``}, // Ticket + }, + Tags: []*vdl.Value{vdl.ValueOf(access.Tag("Read"))}, + }, + { + Name: "GetWithArgs", + InArgs: []rpc.ArgDesc{ + {Name: "args", Doc: ``}, // map[string]string + }, + OutArgs: []rpc.ArgDesc{ + {Name: "", Doc: ``}, // Ticket + }, + Tags: []*vdl.Value{vdl.ValueOf(access.Tag("Read"))}, + }, + }, +} + +// ListServiceClientMethods is the client interface +// containing ListService methods. 
+type ListServiceClientMethods interface { + List(*context.T, ...rpc.CallOpt) ([]string, error) +} + +// ListServiceClientStub embeds ListServiceClientMethods and is a +// placeholder for additional management operations. +type ListServiceClientStub interface { + ListServiceClientMethods +} + +// ListServiceClient returns a client stub for ListService. +func ListServiceClient(name string) ListServiceClientStub { + return implListServiceClientStub{name} +} + +type implListServiceClientStub struct { + name string +} + +func (c implListServiceClientStub) List(ctx *context.T, opts ...rpc.CallOpt) (o0 []string, err error) { + err = v23.GetClient(ctx).Call(ctx, c.name, "List", nil, []interface{}{&o0}, opts...) + return +} + +// ListServiceServerMethods is the interface a server writer +// implements for ListService. +type ListServiceServerMethods interface { + List(*context.T, rpc.ServerCall) ([]string, error) +} + +// ListServiceServerStubMethods is the server interface containing +// ListService methods, as expected by rpc.Server. +// There is no difference between this interface and ListServiceServerMethods +// since there are no streaming methods. +type ListServiceServerStubMethods ListServiceServerMethods + +// ListServiceServerStub adds universal methods to ListServiceServerStubMethods. +type ListServiceServerStub interface { + ListServiceServerStubMethods + // DescribeInterfaces the ListService interfaces. + Describe__() []rpc.InterfaceDesc +} + +// ListServiceServer returns a server stub for ListService. +// It converts an implementation of ListServiceServerMethods into +// an object that may be used by rpc.Server. +func ListServiceServer(impl ListServiceServerMethods) ListServiceServerStub { + stub := implListServiceServerStub{ + impl: impl, + } + // Initialize GlobState; always check the stub itself first, to handle the + // case where the user has the Glob method defined in their VDL source. 
+ if gs := rpc.NewGlobState(stub); gs != nil { + stub.gs = gs + } else if gs := rpc.NewGlobState(impl); gs != nil { + stub.gs = gs + } + return stub +} + +type implListServiceServerStub struct { + impl ListServiceServerMethods + gs *rpc.GlobState +} + +func (s implListServiceServerStub) List(ctx *context.T, call rpc.ServerCall) ([]string, error) { + return s.impl.List(ctx, call) +} + +func (s implListServiceServerStub) Globber() *rpc.GlobState { + return s.gs +} + +func (s implListServiceServerStub) Describe__() []rpc.InterfaceDesc { + return []rpc.InterfaceDesc{ListServiceDesc} +} + +// ListServiceDesc describes the ListService interface. +var ListServiceDesc rpc.InterfaceDesc = descListService + +// descListService hides the desc to keep godoc clean. +var descListService = rpc.InterfaceDesc{ + Name: "ListService", + PkgPath: "github.com/grailbio/base/security/ticket", + Methods: []rpc.MethodDesc{ + { + Name: "List", + OutArgs: []rpc.ArgDesc{ + {Name: "", Doc: ``}, // []string }, Tags: []*vdl.Value{vdl.ValueOf(access.Tag("Read"))}, }, @@ -2711,46 +3819,59 @@ var descTicketService = rpc.InterfaceDesc{ } // Hold type definitions in package-level variables, for better performance. 
+//nolint:unused var ( - __VDLType_struct_1 *vdl.Type - __VDLType_struct_2 *vdl.Type - __VDLType_struct_3 *vdl.Type - __VDLType_struct_4 *vdl.Type - __VDLType_list_5 *vdl.Type - __VDLType_struct_6 *vdl.Type - __VDLType_struct_7 *vdl.Type - __VDLType_struct_8 *vdl.Type - __VDLType_optional_9 *vdl.Type - __VDLType_optional_10 *vdl.Type - __VDLType_struct_11 *vdl.Type - __VDLType_struct_12 *vdl.Type - __VDLType_struct_13 *vdl.Type - __VDLType_struct_14 *vdl.Type - __VDLType_optional_15 *vdl.Type - __VDLType_struct_16 *vdl.Type - __VDLType_struct_17 *vdl.Type - __VDLType_struct_18 *vdl.Type - __VDLType_struct_19 *vdl.Type - __VDLType_struct_20 *vdl.Type - __VDLType_optional_21 *vdl.Type - __VDLType_struct_22 *vdl.Type - __VDLType_optional_23 *vdl.Type - __VDLType_struct_24 *vdl.Type - __VDLType_list_25 *vdl.Type - __VDLType_union_26 *vdl.Type - __VDLType_struct_27 *vdl.Type - __VDLType_map_28 *vdl.Type - __VDLType_struct_29 *vdl.Type - __VDLType_map_30 *vdl.Type + vdlTypeEnum1 *vdl.Type + vdlTypeStruct2 *vdl.Type + vdlTypeStruct3 *vdl.Type + vdlTypeStruct4 *vdl.Type + vdlTypeStruct5 *vdl.Type + vdlTypeList6 *vdl.Type + vdlTypeStruct7 *vdl.Type + vdlTypeStruct8 *vdl.Type + vdlTypeStruct9 *vdl.Type + vdlTypeStruct10 *vdl.Type + vdlTypeOptional11 *vdl.Type + vdlTypeOptional12 *vdl.Type + vdlTypeStruct13 *vdl.Type + vdlTypeStruct14 *vdl.Type + vdlTypeStruct15 *vdl.Type + vdlTypeStruct16 *vdl.Type + vdlTypeStruct17 *vdl.Type + vdlTypeOptional18 *vdl.Type + vdlTypeStruct19 *vdl.Type + vdlTypeStruct20 *vdl.Type + vdlTypeStruct21 *vdl.Type + vdlTypeStruct22 *vdl.Type + vdlTypeStruct23 *vdl.Type + vdlTypeStruct24 *vdl.Type + vdlTypeList25 *vdl.Type + vdlTypeStruct26 *vdl.Type + vdlTypeStruct27 *vdl.Type + vdlTypeOptional28 *vdl.Type + vdlTypeOptional29 *vdl.Type + vdlTypeList30 *vdl.Type + vdlTypeStruct31 *vdl.Type + vdlTypeOptional32 *vdl.Type + vdlTypeStruct33 *vdl.Type + vdlTypeOptional34 *vdl.Type + vdlTypeStruct35 *vdl.Type + vdlTypeList36 *vdl.Type + vdlTypeUnion37 
*vdl.Type + vdlTypeStruct38 *vdl.Type + vdlTypeMap39 *vdl.Type + vdlTypeMap40 *vdl.Type + vdlTypeStruct41 *vdl.Type + vdlTypeMap42 *vdl.Type ) -var __VDLInitCalled bool +var initializeVDLCalled bool -// __VDLInit performs vdl initialization. It is safe to call multiple times. +// initializeVDL performs vdl initialization. It is safe to call multiple times. // If you have an init ordering issue, just insert the following line verbatim // into your source files in this package, right after the "package foo" clause: // -// var _ = __VDLInit() +// var _ = initializeVDL() // // The purpose of this function is to ensure that vdl initialization occurs in // the right order, and very early in the init sequence. In particular, vdl @@ -2759,28 +3880,35 @@ var __VDLInitCalled bool // // This function returns a dummy value, so that it can be used to initialize the // first var in the file, to take advantage of Go's defined init order. -func __VDLInit() struct{} { - if __VDLInitCalled { +func initializeVDL() struct{} { + if initializeVDLCalled { return struct{}{} } - __VDLInitCalled = true + initializeVDLCalled = true // Register types. 
+ vdl.Register((*Control)(nil)) vdl.Register((*AwsCredentials)(nil)) vdl.Register((*AwsAssumeRoleBuilder)(nil)) vdl.Register((*AwsSessionBuilder)(nil)) vdl.Register((*TlsCertAuthorityBuilder)(nil)) + vdl.Register((*SshCertAuthorityBuilder)(nil)) vdl.Register((*B2AccountAuthorizationBuilder)(nil)) vdl.Register((*VanadiumBuilder)(nil)) vdl.Register((*AwsTicket)(nil)) vdl.Register((*S3Ticket)(nil)) vdl.Register((*EcrTicket)(nil)) + vdl.Register((*SshCert)(nil)) vdl.Register((*TlsCredentials)(nil)) vdl.Register((*TlsServerTicket)(nil)) vdl.Register((*TlsClientTicket)(nil)) vdl.Register((*DockerTicket)(nil)) vdl.Register((*DockerServerTicket)(nil)) vdl.Register((*DockerClientTicket)(nil)) + vdl.Register((*Parameter)(nil)) + vdl.Register((*AwsComputeInstancesBuilder)(nil)) + vdl.Register((*ComputeInstance)(nil)) + vdl.Register((*SshCertificateTicket)(nil)) vdl.Register((*B2Ticket)(nil)) vdl.Register((*VanadiumTicket)(nil)) vdl.Register((*GenericTicket)(nil)) @@ -2789,36 +3917,48 @@ func __VDLInit() struct{} { vdl.Register((*Config)(nil)) // Initialize type definitions. 
- __VDLType_struct_1 = vdl.TypeOf((*AwsCredentials)(nil)).Elem() - __VDLType_struct_2 = vdl.TypeOf((*AwsAssumeRoleBuilder)(nil)).Elem() - __VDLType_struct_3 = vdl.TypeOf((*AwsSessionBuilder)(nil)).Elem() - __VDLType_struct_4 = vdl.TypeOf((*TlsCertAuthorityBuilder)(nil)).Elem() - __VDLType_list_5 = vdl.TypeOf((*[]string)(nil)) - __VDLType_struct_6 = vdl.TypeOf((*B2AccountAuthorizationBuilder)(nil)).Elem() - __VDLType_struct_7 = vdl.TypeOf((*VanadiumBuilder)(nil)).Elem() - __VDLType_struct_8 = vdl.TypeOf((*AwsTicket)(nil)).Elem() - __VDLType_optional_9 = vdl.TypeOf((*AwsAssumeRoleBuilder)(nil)) - __VDLType_optional_10 = vdl.TypeOf((*AwsSessionBuilder)(nil)) - __VDLType_struct_11 = vdl.TypeOf((*S3Ticket)(nil)).Elem() - __VDLType_struct_12 = vdl.TypeOf((*EcrTicket)(nil)).Elem() - __VDLType_struct_13 = vdl.TypeOf((*TlsCredentials)(nil)).Elem() - __VDLType_struct_14 = vdl.TypeOf((*TlsServerTicket)(nil)).Elem() - __VDLType_optional_15 = vdl.TypeOf((*TlsCertAuthorityBuilder)(nil)) - __VDLType_struct_16 = vdl.TypeOf((*TlsClientTicket)(nil)).Elem() - __VDLType_struct_17 = vdl.TypeOf((*DockerTicket)(nil)).Elem() - __VDLType_struct_18 = vdl.TypeOf((*DockerServerTicket)(nil)).Elem() - __VDLType_struct_19 = vdl.TypeOf((*DockerClientTicket)(nil)).Elem() - __VDLType_struct_20 = vdl.TypeOf((*B2Ticket)(nil)).Elem() - __VDLType_optional_21 = vdl.TypeOf((*B2AccountAuthorizationBuilder)(nil)) - __VDLType_struct_22 = vdl.TypeOf((*VanadiumTicket)(nil)).Elem() - __VDLType_optional_23 = vdl.TypeOf((*VanadiumBuilder)(nil)) - __VDLType_struct_24 = vdl.TypeOf((*GenericTicket)(nil)).Elem() - __VDLType_list_25 = vdl.TypeOf((*[]byte)(nil)) - __VDLType_union_26 = vdl.TypeOf((*Ticket)(nil)) - __VDLType_struct_27 = vdl.TypeOf((*TicketConfig)(nil)).Elem() - __VDLType_map_28 = vdl.TypeOf((*access.Permissions)(nil)) - __VDLType_struct_29 = vdl.TypeOf((*Config)(nil)).Elem() - __VDLType_map_30 = vdl.TypeOf((*map[string]TicketConfig)(nil)) + vdlTypeEnum1 = vdl.TypeOf((*Control)(nil)) + vdlTypeStruct2 = 
vdl.TypeOf((*AwsCredentials)(nil)).Elem() + vdlTypeStruct3 = vdl.TypeOf((*AwsAssumeRoleBuilder)(nil)).Elem() + vdlTypeStruct4 = vdl.TypeOf((*AwsSessionBuilder)(nil)).Elem() + vdlTypeStruct5 = vdl.TypeOf((*TlsCertAuthorityBuilder)(nil)).Elem() + vdlTypeList6 = vdl.TypeOf((*[]string)(nil)) + vdlTypeStruct7 = vdl.TypeOf((*SshCertAuthorityBuilder)(nil)).Elem() + vdlTypeStruct8 = vdl.TypeOf((*B2AccountAuthorizationBuilder)(nil)).Elem() + vdlTypeStruct9 = vdl.TypeOf((*VanadiumBuilder)(nil)).Elem() + vdlTypeStruct10 = vdl.TypeOf((*AwsTicket)(nil)).Elem() + vdlTypeOptional11 = vdl.TypeOf((*AwsAssumeRoleBuilder)(nil)) + vdlTypeOptional12 = vdl.TypeOf((*AwsSessionBuilder)(nil)) + vdlTypeStruct13 = vdl.TypeOf((*S3Ticket)(nil)).Elem() + vdlTypeStruct14 = vdl.TypeOf((*EcrTicket)(nil)).Elem() + vdlTypeStruct15 = vdl.TypeOf((*SshCert)(nil)).Elem() + vdlTypeStruct16 = vdl.TypeOf((*TlsCredentials)(nil)).Elem() + vdlTypeStruct17 = vdl.TypeOf((*TlsServerTicket)(nil)).Elem() + vdlTypeOptional18 = vdl.TypeOf((*TlsCertAuthorityBuilder)(nil)) + vdlTypeStruct19 = vdl.TypeOf((*TlsClientTicket)(nil)).Elem() + vdlTypeStruct20 = vdl.TypeOf((*DockerTicket)(nil)).Elem() + vdlTypeStruct21 = vdl.TypeOf((*DockerServerTicket)(nil)).Elem() + vdlTypeStruct22 = vdl.TypeOf((*DockerClientTicket)(nil)).Elem() + vdlTypeStruct23 = vdl.TypeOf((*Parameter)(nil)).Elem() + vdlTypeStruct24 = vdl.TypeOf((*AwsComputeInstancesBuilder)(nil)).Elem() + vdlTypeList25 = vdl.TypeOf((*[]Parameter)(nil)) + vdlTypeStruct26 = vdl.TypeOf((*ComputeInstance)(nil)).Elem() + vdlTypeStruct27 = vdl.TypeOf((*SshCertificateTicket)(nil)).Elem() + vdlTypeOptional28 = vdl.TypeOf((*SshCertAuthorityBuilder)(nil)) + vdlTypeOptional29 = vdl.TypeOf((*AwsComputeInstancesBuilder)(nil)) + vdlTypeList30 = vdl.TypeOf((*[]ComputeInstance)(nil)) + vdlTypeStruct31 = vdl.TypeOf((*B2Ticket)(nil)).Elem() + vdlTypeOptional32 = vdl.TypeOf((*B2AccountAuthorizationBuilder)(nil)) + vdlTypeStruct33 = vdl.TypeOf((*VanadiumTicket)(nil)).Elem() + 
vdlTypeOptional34 = vdl.TypeOf((*VanadiumBuilder)(nil)) + vdlTypeStruct35 = vdl.TypeOf((*GenericTicket)(nil)).Elem() + vdlTypeList36 = vdl.TypeOf((*[]byte)(nil)) + vdlTypeUnion37 = vdl.TypeOf((*Ticket)(nil)) + vdlTypeStruct38 = vdl.TypeOf((*TicketConfig)(nil)).Elem() + vdlTypeMap39 = vdl.TypeOf((*access.Permissions)(nil)) + vdlTypeMap40 = vdl.TypeOf((*map[Control]bool)(nil)) + vdlTypeStruct41 = vdl.TypeOf((*Config)(nil)).Elem() + vdlTypeMap42 = vdl.TypeOf((*map[string]TicketConfig)(nil)) return struct{}{} } diff --git a/security/ticket/ticket_test.go b/security/ticket/ticket_test.go index ecde5ec0..427d0706 100644 --- a/security/ticket/ticket_test.go +++ b/security/ticket/ticket_test.go @@ -7,10 +7,15 @@ package ticket import ( "reflect" "testing" + + "github.com/grailbio/base/vcontext" ) func TestMerge(t *testing.T) { - got := mergeOrDie(&S3Ticket{Endpoint: "xxx"}, &S3Ticket{Bucket: "yyy"}) + ctx := &TicketContext{ + ctx: vcontext.Background(), + } + got := mergeOrDie(ctx, &S3Ticket{Endpoint: "xxx"}, &S3Ticket{Bucket: "yyy"}) want := &S3Ticket{ Endpoint: "xxx", Bucket: "yyy", diff --git a/security/ticket/tls.go b/security/ticket/tls.go index b86d637d..027b0980 100644 --- a/security/ticket/tls.go +++ b/security/ticket/tls.go @@ -10,15 +10,15 @@ import ( "encoding/pem" "time" + "github.com/grailbio/base/common/log" "github.com/grailbio/base/security/keycrypt" "github.com/grailbio/base/security/tls/certificateauthority" - "v.io/x/lib/vlog" ) -const driftMargin = 10 * time.Minute +const tlsDriftMargin = 10 * time.Minute -func (b *TlsCertAuthorityBuilder) newTlsClientTicket() (TicketTlsClientTicket, error) { - tlsCredentials, err := b.genTlsCredentials() +func (b *TlsCertAuthorityBuilder) newTlsClientTicket(ctx *TicketContext) (TicketTlsClientTicket, error) { + tlsCredentials, err := b.genTlsCredentials(ctx) if err != nil { return TicketTlsClientTicket{}, err @@ -31,8 +31,8 @@ func (b *TlsCertAuthorityBuilder) newTlsClientTicket() (TicketTlsClientTicket, e }, nil } 
-func (b *TlsCertAuthorityBuilder) newTlsServerTicket() (TicketTlsServerTicket, error) { - tlsCredentials, err := b.genTlsCredentials() +func (b *TlsCertAuthorityBuilder) newTlsServerTicket(ctx *TicketContext) (TicketTlsServerTicket, error) { + tlsCredentials, err := b.genTlsCredentials(ctx) if err != nil { return TicketTlsServerTicket{}, err @@ -45,8 +45,8 @@ func (b *TlsCertAuthorityBuilder) newTlsServerTicket() (TicketTlsServerTicket, e }, nil } -func (b *TlsCertAuthorityBuilder) newDockerTicket() (TicketDockerTicket, error) { - tlsCredentials, err := b.genTlsCredentials() +func (b *TlsCertAuthorityBuilder) newDockerTicket(ctx *TicketContext) (TicketDockerTicket, error) { + tlsCredentials, err := b.genTlsCredentials(ctx) if err != nil { return TicketDockerTicket{}, err @@ -59,8 +59,8 @@ func (b *TlsCertAuthorityBuilder) newDockerTicket() (TicketDockerTicket, error) }, nil } -func (b *TlsCertAuthorityBuilder) newDockerServerTicket() (TicketDockerServerTicket, error) { - tlsCredentials, err := b.genTlsCredentialsWithKeyUsage([]x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}) +func (b *TlsCertAuthorityBuilder) newDockerServerTicket(ctx *TicketContext) (TicketDockerServerTicket, error) { + tlsCredentials, err := b.genTlsCredentialsWithKeyUsage(ctx, []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}) if err != nil { return TicketDockerServerTicket{}, err @@ -73,8 +73,8 @@ func (b *TlsCertAuthorityBuilder) newDockerServerTicket() (TicketDockerServerTic }, nil } -func (b *TlsCertAuthorityBuilder) newDockerClientTicket() (TicketDockerClientTicket, error) { - tlsCredentials, err := b.genTlsCredentialsWithKeyUsage([]x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}) +func (b *TlsCertAuthorityBuilder) newDockerClientTicket(ctx *TicketContext) (TicketDockerClientTicket, error) { + tlsCredentials, err := b.genTlsCredentialsWithKeyUsage(ctx, []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}) if err != nil { return TicketDockerClientTicket{}, err @@ -87,24 +87,42 @@ func (b 
*TlsCertAuthorityBuilder) newDockerClientTicket() (TicketDockerClientTic }, nil } -func (b *TlsCertAuthorityBuilder) genTlsCredentials() (TlsCredentials, error) { - return b.genTlsCredentialsWithKeyUsage([]x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}) +func (b *TlsCertAuthorityBuilder) genTlsCredentials(ctx *TicketContext) (TlsCredentials, error) { + log.Info(ctx.ctx, "Generating TLS credentials.", "TlsCertAuthorityBuilder", b) + return b.genTlsCredentialsWithKeyUsage(ctx, []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}) } -func (b *TlsCertAuthorityBuilder) genTlsCredentialsWithKeyUsage(keyUsage []x509.ExtKeyUsage) (TlsCredentials, error) { - vlog.Infof("TlsCertAuthorityBuilder: %+v", b) +func contains(array []string, entry string) bool { + for _, e := range array { + if e == entry { + return true + } + } + + return false +} + +func (b *TlsCertAuthorityBuilder) genTlsCredentialsWithKeyUsage(ctx *TicketContext, keyUsage []x509.ExtKeyUsage) (TlsCredentials, error) { empty := TlsCredentials{} secret, err := keycrypt.Lookup(b.Authority) if err != nil { return empty, err } - authority := certificateauthority.CertificateAuthority{DriftMargin: driftMargin, Signer: secret} + authority := certificateauthority.CertificateAuthority{DriftMargin: tlsDriftMargin, Signer: secret} if err := authority.Init(); err != nil { return empty, err } ttl := time.Duration(b.TtlSec) * time.Second - cert, key, err := authority.IssueWithKeyUsage(b.CommonName, ttl, nil, b.San, keyUsage) + commonName := b.CommonName + if commonName == "" { + commonName = ctx.remoteBlessings.String() + } + updatedSan := b.San + if !contains(updatedSan, commonName) { + updatedSan = append(updatedSan, commonName) + } + cert, key, err := authority.IssueWithKeyUsage(commonName, ttl, nil, updatedSan, keyUsage) if err != nil { return empty, err } diff --git a/security/ticket/vanadium.go b/security/ticket/vanadium.go index 4941c0a6..de400469 100644 --- 
a/security/ticket/vanadium.go +++ b/security/ticket/vanadium.go @@ -11,10 +11,10 @@ import ( "strings" "time" - "v.io/v23" + "github.com/grailbio/base/common/log" + v23 "v.io/v23" "v.io/v23/security" "v.io/v23/vom" - "v.io/x/lib/vlog" ) const requiredSuffix = security.ChainSeparator + "_role" @@ -53,7 +53,7 @@ func (b *VanadiumBuilder) newVanadiumTicket(ctx *TicketContext) (TicketVanadiumT return empty, err } - vlog.VI(1).Infof("resultBlessings: %#v", ctx.remoteBlessings) + log.Infof(ctx.ctx, "resultBlessings: %+v", resultBlessings) s, err := base64urlVomEncode(resultBlessings) if err != nil { diff --git a/shutdown/shutdown.go b/shutdown/shutdown.go new file mode 100644 index 00000000..896bbaae --- /dev/null +++ b/shutdown/shutdown.go @@ -0,0 +1,41 @@ +// Copyright 2019 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. + +// Package shutdown implements a global process shutdown mechanism. +// It is used by package github.com/grailbio/base/grail to perform +// graceful shutdown of software components. It is a separate package +// in order to avoid circular dependencies. +package shutdown + +import "sync" + +// Func is the type of function run on shutdowns. +type Func func() + +var ( + mu sync.Mutex + funcs []Func +) + +// Register registers a function to be run in the Init shutdown +// callback. The callbacks will run in the reverse order of +// registration. + +func Register(f Func) { + mu.Lock() + funcs = append(funcs, f) + mu.Unlock() +} + +// Run run callbacks added by Register. This function is not for +// general use. 
+func Run() { + mu.Lock() + fns := funcs + funcs = nil + mu.Unlock() + for i := len(fns) - 1; i >= 0; i-- { + fns[i]() + } +} diff --git a/simd/add_amd64.go b/simd/add_amd64.go index 1c7372fb..4cb0d8bb 100644 --- a/simd/add_amd64.go +++ b/simd/add_amd64.go @@ -117,8 +117,8 @@ func AddConst8(dst, src []byte, val byte) { addConst8OddSSSE3Asm(unsafe.Pointer(dstHeader.Data), unsafe.Pointer(srcHeader.Data), int(val), srcLen) } -// SubtractFromConst8UnsafeInplace subtracts the given constant from every byte -// of main[], with unsigned underflow. +// SubtractFromConst8UnsafeInplace subtracts every byte of main[] from the +// given constant, with unsigned underflow. // // WARNING: This is a function designed to be used in inner loops, which // assumes without checking that capacity is at least RoundUpPow2(len(main), @@ -138,8 +138,8 @@ func SubtractFromConst8UnsafeInplace(main []byte, val byte) { subtractFromConst8OddInplaceSSSE3Asm(unsafe.Pointer(mainHeader.Data), int(val), mainLen) } -// SubtractFromConst8Inplace subtracts the given constant from every byte of -// main[], with unsigned underflow. +// SubtractFromConst8Inplace subtracts every byte of main[] from the given +// constant, with unsigned underflow. func SubtractFromConst8Inplace(main []byte, val byte) { mainLen := len(main) if mainLen < 16 { diff --git a/simd/add_generic.go b/simd/add_generic.go new file mode 100644 index 00000000..4dcabd0b --- /dev/null +++ b/simd/add_generic.go @@ -0,0 +1,123 @@ +// Copyright 2018 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache-2.0 +// license that can be found in the LICENSE file. + +// +build !amd64 appengine + +package simd + +// AddConst8UnsafeInplace adds the given constant to every byte of main[], with +// unsigned overflow. +// +// WARNING: This is a function designed to be used in inner loops, which +// assumes without checking that capacity is at least RoundUpPow2(len(main), +// bytesPerVec). 
It also assumes that the caller does not care if a few bytes +// past the end of main[] are changed. Use the safe version of this function +// if any of these properties are problematic. +// These assumptions are always satisfied when the last +// potentially-size-increasing operation on main[] is {Re}makeUnsafe(), +// ResizeUnsafe(), or XcapUnsafe(). +func AddConst8UnsafeInplace(main []byte, val byte) { + for i, x := range main { + main[i] = x + val + } +} + +// AddConst8Inplace adds the given constant to every byte of main[], with +// unsigned overflow. +func AddConst8Inplace(main []byte, val byte) { + for i, x := range main { + main[i] = x + val + } +} + +// AddConst8Unsafe sets dst[pos] := src[pos] + val for every byte in src (with +// the usual unsigned overflow). +// +// WARNING: This is a function designed to be used in inner loops, which makes +// assumptions about length and capacity which aren't checked at runtime. Use +// the safe version of this function when that's a problem. +// Assumptions #2-3 are always satisfied when the last +// potentially-size-increasing operation on src[] is {Re}makeUnsafe(), +// ResizeUnsafe() or XcapUnsafe(), and the same is true for dst[]. +// +// 1. len(src) and len(dst) are equal. +// +// 2. Capacities are at least RoundUpPow2(len(src) + 1, bytesPerVec). +// +// 3. The caller does not care if a few bytes past the end of dst[] are +// changed. +func AddConst8Unsafe(dst, src []byte, val byte) { + for i, x := range src { + dst[i] = x + val + } +} + +// AddConst8 sets dst[pos] := src[pos] + val for every byte in src (with the +// usual unsigned overflow). It panics if len(src) != len(dst). +func AddConst8(dst, src []byte, val byte) { + if len(dst) != len(src) { + panic("AddConst8() requires len(src) == len(dst).") + } + for i, x := range src { + dst[i] = x + val + } +} + +// SubtractFromConst8UnsafeInplace subtracts every byte of main[] from the +// given constant, with unsigned underflow. 
+// +// WARNING: This is a function designed to be used in inner loops, which +// assumes without checking that capacity is at least RoundUpPow2(len(main), +// bytesPerVec). It also assumes that the caller does not care if a few bytes +// past the end of main[] are changed. Use the safe version of this function +// if any of these properties are problematic. +// These assumptions are always satisfied when the last +// potentially-size-increasing operation on main[] is {Re}makeUnsafe(), +// ResizeUnsafe(), or XcapUnsafe(). +func SubtractFromConst8UnsafeInplace(main []byte, val byte) { + for i, x := range main { + main[i] = val - x + } +} + +// SubtractFromConst8Inplace subtracts every byte of main[] from the given +// constant, with unsigned underflow. +func SubtractFromConst8Inplace(main []byte, val byte) { + for i, x := range main { + main[i] = val - x + } +} + +// SubtractFromConst8Unsafe sets dst[pos] := val - src[pos] for every byte in +// src (with the usual unsigned overflow). +// +// WARNING: This is a function designed to be used in inner loops, which makes +// assumptions about length and capacity which aren't checked at runtime. Use +// the safe version of this function when that's a problem. +// Assumptions #2-3 are always satisfied when the last +// potentially-size-increasing operation on src[] is {Re}makeUnsafe(), +// ResizeUnsafe() or XcapUnsafe(), and the same is true for dst[]. +// +// 1. len(src) and len(dst) are equal. +// +// 2. Capacities are at least RoundUpPow2(len(src) + 1, bytesPerVec). +// +// 3. The caller does not care if a few bytes past the end of dst[] are +// changed. +func SubtractFromConst8Unsafe(dst, src []byte, val byte) { + for i, x := range src { + dst[i] = val - x + } +} + +// SubtractFromConst8 sets dst[pos] := val - src[pos] for every byte in src +// (with the usual unsigned overflow). It panics if len(src) != len(dst). 
+func SubtractFromConst8(dst, src []byte, val byte) { + if len(dst) != len(src) { + panic("SubtractFromConst8() requires len(src) == len(dst).") + } + for i, x := range src { + dst[i] = val - x + } +} diff --git a/simd/add_test.go b/simd/add_test.go index abf0d2e4..1f99b431 100644 --- a/simd/add_test.go +++ b/simd/add_test.go @@ -7,111 +7,11 @@ package simd_test import ( "bytes" "math/rand" - "runtime" "testing" "github.com/grailbio/base/simd" ) -/* -Initial benchmark results: - MacBook Pro (15-inch, 2016) - 2.7 GHz Intel Core i7, 16 GB 2133 MHz LPDDR3 - -Benchmark_AddConstShort1-8 20 65355577 ns/op -Benchmark_AddConstShort4-8 100 22503200 ns/op -Benchmark_AddConstShortMax-8 100 18402022 ns/op -Benchmark_AddConstLong1-8 1 1314456098 ns/op -Benchmark_AddConstLong4-8 1 1963028701 ns/op -Benchmark_AddConstLongMax-8 1 2640851500 ns/op - -For comparison, addConst8Slow: -Benchmark_AddConstShort1-8 3 394073399 ns/op -Benchmark_AddConstShort4-8 10 112302717 ns/op -Benchmark_AddConstShortMax-8 10 101881678 ns/op -Benchmark_AddConstLong1-8 1 5806941582 ns/op -Benchmark_AddConstLong4-8 1 2451731455 ns/op -Benchmark_AddConstLongMax-8 1 3305500509 ns/op -*/ - -func addConstSubtask(dst []byte, nIter int) int { - for iter := 0; iter < nIter; iter++ { - simd.AddConst8UnsafeInplace(dst, 33) - } - return int(dst[0]) -} - -func addConstSubtaskFuture(dst []byte, nIter int) chan int { - future := make(chan int) - go func() { future <- addConstSubtask(dst, nIter) }() - return future -} - -func multiAddConst(dsts [][]byte, cpus int, nJob int) { - sumFutures := make([]chan int, cpus) - shardSizeBase := nJob / cpus - shardRemainder := nJob - shardSizeBase*cpus - shardSizeP1 := shardSizeBase + 1 - var taskIdx int - for ; taskIdx < shardRemainder; taskIdx++ { - sumFutures[taskIdx] = addConstSubtaskFuture(dsts[taskIdx], shardSizeP1) - } - for ; taskIdx < cpus; taskIdx++ { - sumFutures[taskIdx] = addConstSubtaskFuture(dsts[taskIdx], shardSizeBase) - } - var sum int - for taskIdx = 0; taskIdx < 
cpus; taskIdx++ { - sum += <-sumFutures[taskIdx] - } -} - -func benchmarkAddConst(cpus int, nByte int, nJob int, b *testing.B) { - if cpus > runtime.NumCPU() { - b.Skipf("only have %v cpus", runtime.NumCPU()) - } - - mainSlices := make([][]byte, cpus) - for ii := range mainSlices { - // Add 63 to prevent false sharing. - newArr := simd.MakeUnsafe(nByte + 63) - for jj := 0; jj < nByte; jj++ { - newArr[jj] = byte(jj * 3) - } - mainSlices[ii] = newArr[:nByte] - } - for i := 0; i < b.N; i++ { - multiAddConst(mainSlices, cpus, nJob) - } -} - -// Base sequence in length-150 .bam read occupies 75 bytes, so 75 is a good -// size for the short-array benchmark. -func Benchmark_AddConstShort1(b *testing.B) { - benchmarkAddConst(1, 75, 9999999, b) -} - -func Benchmark_AddConstShort4(b *testing.B) { - benchmarkAddConst(4, 75, 9999999, b) -} - -func Benchmark_AddConstShortMax(b *testing.B) { - benchmarkAddConst(runtime.NumCPU(), 75, 9999999, b) -} - -// GRCh37 chromosome 1 length is 249250621, so that's a plausible long-array -// use case. 
-func Benchmark_AddConstLong1(b *testing.B) { - benchmarkAddConst(1, 249250621, 50, b) -} - -func Benchmark_AddConstLong4(b *testing.B) { - benchmarkAddConst(4, 249250621, 50, b) -} - -func Benchmark_AddConstLongMax(b *testing.B) { - benchmarkAddConst(runtime.NumCPU(), 249250621, 50, b) -} - func addConst8Slow(dst []byte, val byte) { // strangely, this takes ~35% less time than the single-parameter for loop on // the AddConstLong4 benchmark, though performance is usually @@ -180,6 +80,58 @@ func TestAddConst(t *testing.T) { } } +/* +Benchmark results: + MacBook Pro (15-inch, 2016) + 2.7 GHz Intel Core i7, 16 GB 2133 MHz LPDDR3 + +Benchmark_AddConst8Inplace/SIMDShort1Cpu-8 20 94449590 ns/op +Benchmark_AddConst8Inplace/SIMDShortHalfCpu-8 50 28197917 ns/op +Benchmark_AddConst8Inplace/SIMDShortAllCpu-8 50 27452313 ns/op +Benchmark_AddConst8Inplace/SIMDLong1Cpu-8 1 1145256373 ns/op +Benchmark_AddConst8Inplace/SIMDLongHalfCpu-8 2 959236835 ns/op +Benchmark_AddConst8Inplace/SIMDLongAllCpu-8 2 982555560 ns/op +Benchmark_AddConst8Inplace/SlowShort1Cpu-8 2 707287108 ns/op +Benchmark_AddConst8Inplace/SlowShortHalfCpu-8 10 199415710 ns/op +Benchmark_AddConst8Inplace/SlowShortAllCpu-8 5 245220685 ns/op +Benchmark_AddConst8Inplace/SlowLong1Cpu-8 1 5480013373 ns/op +Benchmark_AddConst8Inplace/SlowLongHalfCpu-8 1 1467424090 ns/op +Benchmark_AddConst8Inplace/SlowLongAllCpu-8 1 1554565031 ns/op +*/ + +func addConst8InplaceSimdSubtask(dst, src []byte, nIter int) int { + for iter := 0; iter < nIter; iter++ { + simd.AddConst8Inplace(dst, 33) + } + return int(dst[0]) +} + +func addConst8InplaceSlowSubtask(dst, src []byte, nIter int) int { + for iter := 0; iter < nIter; iter++ { + addConst8Slow(dst, 33) + } + return int(dst[0]) +} + +func Benchmark_AddConst8Inplace(b *testing.B) { + funcs := []taggedMultiBenchFunc{ + { + f: addConst8InplaceSimdSubtask, + tag: "SIMD", + }, + { + f: addConst8InplaceSlowSubtask, + tag: "Slow", + }, + } + for _, f := range funcs { + multiBenchmark(f.f, 
f.tag+"Short", 150, 0, 9999999, b) + // GRCh37 chromosome 1 length is 249250621, so that's a plausible + // long-array use case. + multiBenchmark(f.f, f.tag+"Long", 249250621, 0, 50, b) + } +} + func subtractFromConst8Slow(dst []byte, val byte) { for idx, dstByte := range dst { dst[idx] = val - dstByte diff --git a/simd/and_amd64.go b/simd/and_amd64.go index 6f9003d2..720ad97b 100644 --- a/simd/and_amd64.go +++ b/simd/and_amd64.go @@ -1,8 +1,10 @@ -// Code generated from " ../gtl/generate.py --prefix=And -DOPCHAR=& --package=simd --output=and_amd64.go bitwise_amd64.go.tpl ". DO NOT EDIT. -// Copyright 2018 GRAIL, Inc. All rights reserved. +// Code generated by "../gtl/generate.py --prefix=And -DOPCHAR=& --package=simd --output=and_amd64.go bitwise_amd64.go.tpl". DO NOT EDIT. + +// Copyright 2021 GRAIL, Inc. All rights reserved. // Use of this source code is governed by the Apache-2.0 // license that can be found in the LICENSE file. +//go:build amd64 && !appengine // +build amd64,!appengine package simd @@ -12,7 +14,7 @@ import ( "unsafe" ) -// AndUnsafeInplace sets main[pos] := arg[pos] & main[pos] for every position +// AndUnsafeInplace sets main[pos] := main[pos] & arg[pos] for every position // in main[]. // // WARNING: This is a function designed to be used in inner loops, which makes @@ -30,18 +32,18 @@ import ( // changed. 
func AndUnsafeInplace(main, arg []byte) { mainLen := len(main) - argHeader := (*reflect.SliceHeader)(unsafe.Pointer(&arg)) - mainHeader := (*reflect.SliceHeader)(unsafe.Pointer(&main)) - argWordsIter := unsafe.Pointer(argHeader.Data) - mainWordsIter := unsafe.Pointer(mainHeader.Data) + argData := unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&arg)).Data) + mainData := unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&main)).Data) + argWordsIter := argData + mainWordsIter := mainData if mainLen > 2*BytesPerWord { nWordMinus2 := (mainLen - BytesPerWord - 1) >> Log2BytesPerWord for widx := 0; widx < nWordMinus2; widx++ { mainWord := *((*uintptr)(mainWordsIter)) argWord := *((*uintptr)(argWordsIter)) *((*uintptr)(mainWordsIter)) = mainWord & argWord - mainWordsIter = unsafe.Pointer(uintptr(mainWordsIter) + BytesPerWord) - argWordsIter = unsafe.Pointer(uintptr(argWordsIter) + BytesPerWord) + mainWordsIter = unsafe.Add(mainWordsIter, BytesPerWord) + argWordsIter = unsafe.Add(argWordsIter, BytesPerWord) } } else if mainLen <= BytesPerWord { mainWord := *((*uintptr)(mainWordsIter)) @@ -56,8 +58,8 @@ func AndUnsafeInplace(main, arg []byte) { mainWord1 := *((*uintptr)(mainWordsIter)) argWord1 := *((*uintptr)(argWordsIter)) finalOffset := uintptr(mainLen - BytesPerWord) - mainFinalWordPtr := unsafe.Pointer(mainHeader.Data + finalOffset) - argFinalWordPtr := unsafe.Pointer(argHeader.Data + finalOffset) + mainFinalWordPtr := unsafe.Add(mainData, finalOffset) + argFinalWordPtr := unsafe.Add(argData, finalOffset) mainWord2 := *((*uintptr)(mainFinalWordPtr)) argWord2 := *((*uintptr)(argFinalWordPtr)) *((*uintptr)(mainWordsIter)) = mainWord1 & argWord1 @@ -82,25 +84,25 @@ func AndInplace(main, arg []byte) { } return } - argHeader := (*reflect.SliceHeader)(unsafe.Pointer(&arg)) - mainHeader := (*reflect.SliceHeader)(unsafe.Pointer(&main)) - argWordsIter := unsafe.Pointer(argHeader.Data) - mainWordsIter := unsafe.Pointer(mainHeader.Data) + argData := 
unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&arg)).Data) + mainData := unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&main)).Data) + argWordsIter := argData + mainWordsIter := mainData if mainLen > 2*BytesPerWord { nWordMinus2 := (mainLen - BytesPerWord - 1) >> Log2BytesPerWord for widx := 0; widx < nWordMinus2; widx++ { mainWord := *((*uintptr)(mainWordsIter)) argWord := *((*uintptr)(argWordsIter)) *((*uintptr)(mainWordsIter)) = mainWord & argWord - mainWordsIter = unsafe.Pointer(uintptr(mainWordsIter) + BytesPerWord) - argWordsIter = unsafe.Pointer(uintptr(argWordsIter) + BytesPerWord) + mainWordsIter = unsafe.Add(mainWordsIter, BytesPerWord) + argWordsIter = unsafe.Add(argWordsIter, BytesPerWord) } } mainWord1 := *((*uintptr)(mainWordsIter)) argWord1 := *((*uintptr)(argWordsIter)) finalOffset := uintptr(mainLen - BytesPerWord) - mainFinalWordPtr := unsafe.Pointer(mainHeader.Data + finalOffset) - argFinalWordPtr := unsafe.Pointer(argHeader.Data + finalOffset) + mainFinalWordPtr := unsafe.Add(mainData, finalOffset) + argFinalWordPtr := unsafe.Add(argData, finalOffset) mainWord2 := *((*uintptr)(mainFinalWordPtr)) argWord2 := *((*uintptr)(argFinalWordPtr)) *((*uintptr)(mainWordsIter)) = mainWord1 & argWord1 @@ -135,9 +137,9 @@ func AndUnsafe(dst, src1, src2 []byte) { src1Word := *((*uintptr)(src1Iter)) src2Word := *((*uintptr)(src2Iter)) *((*uintptr)(dstIter)) = src1Word & src2Word - src1Iter = unsafe.Pointer(uintptr(src1Iter) + BytesPerWord) - src2Iter = unsafe.Pointer(uintptr(src2Iter) + BytesPerWord) - dstIter = unsafe.Pointer(uintptr(dstIter) + BytesPerWord) + src1Iter = unsafe.Add(src1Iter, BytesPerWord) + src2Iter = unsafe.Add(src2Iter, BytesPerWord) + dstIter = unsafe.Add(dstIter, BytesPerWord) } } @@ -154,27 +156,27 @@ func And(dst, src1, src2 []byte) { } return } - src1Header := (*reflect.SliceHeader)(unsafe.Pointer(&src1)) - src2Header := (*reflect.SliceHeader)(unsafe.Pointer(&src2)) - dstHeader := 
(*reflect.SliceHeader)(unsafe.Pointer(&dst)) + src1Data := unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&src1)).Data) + src2Data := unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&src2)).Data) + dstData := unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&dst)).Data) nWordMinus1 := (dstLen - 1) >> Log2BytesPerWord - src1Iter := unsafe.Pointer(src1Header.Data) - src2Iter := unsafe.Pointer(src2Header.Data) - dstIter := unsafe.Pointer(dstHeader.Data) + src1Iter := src1Data + src2Iter := src2Data + dstIter := dstData for widx := 0; widx < nWordMinus1; widx++ { src1Word := *((*uintptr)(src1Iter)) src2Word := *((*uintptr)(src2Iter)) *((*uintptr)(dstIter)) = src1Word & src2Word - src1Iter = unsafe.Pointer(uintptr(src1Iter) + BytesPerWord) - src2Iter = unsafe.Pointer(uintptr(src2Iter) + BytesPerWord) - dstIter = unsafe.Pointer(uintptr(dstIter) + BytesPerWord) + src1Iter = unsafe.Add(src1Iter, BytesPerWord) + src2Iter = unsafe.Add(src2Iter, BytesPerWord) + dstIter = unsafe.Add(dstIter, BytesPerWord) } // No store-forwarding problem here. 
finalOffset := uintptr(dstLen - BytesPerWord) - src1Iter = unsafe.Pointer(src1Header.Data + finalOffset) - src2Iter = unsafe.Pointer(src2Header.Data + finalOffset) - dstIter = unsafe.Pointer(dstHeader.Data + finalOffset) + src1Iter = unsafe.Add(src1Data, finalOffset) + src2Iter = unsafe.Add(src2Data, finalOffset) + dstIter = unsafe.Add(dstData, finalOffset) src1Word := *((*uintptr)(src1Iter)) src2Word := *((*uintptr)(src2Iter)) *((*uintptr)(dstIter)) = src1Word & src2Word @@ -197,14 +199,14 @@ func And(dst, src1, src2 []byte) { func AndConst8UnsafeInplace(main []byte, val byte) { mainLen := len(main) argWord := 0x101010101010101 * uintptr(val) - mainHeader := (*reflect.SliceHeader)(unsafe.Pointer(&main)) - mainWordsIter := unsafe.Pointer(mainHeader.Data) + mainData := unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&main)).Data) + mainWordsIter := mainData if mainLen > 2*BytesPerWord { nWordMinus2 := (mainLen - BytesPerWord - 1) >> Log2BytesPerWord for widx := 0; widx < nWordMinus2; widx++ { mainWord := *((*uintptr)(mainWordsIter)) *((*uintptr)(mainWordsIter)) = mainWord & argWord - mainWordsIter = unsafe.Pointer(uintptr(mainWordsIter) + BytesPerWord) + mainWordsIter = unsafe.Add(mainWordsIter, BytesPerWord) } } else if mainLen <= BytesPerWord { mainWord := *((*uintptr)(mainWordsIter)) @@ -213,7 +215,7 @@ func AndConst8UnsafeInplace(main []byte, val byte) { } mainWord1 := *((*uintptr)(mainWordsIter)) finalOffset := uintptr(mainLen - BytesPerWord) - mainFinalWordPtr := unsafe.Pointer(mainHeader.Data + finalOffset) + mainFinalWordPtr := unsafe.Add(mainData, finalOffset) mainWord2 := *((*uintptr)(mainFinalWordPtr)) *((*uintptr)(mainWordsIter)) = mainWord1 & argWord *((*uintptr)(mainFinalWordPtr)) = mainWord2 & argWord @@ -230,19 +232,19 @@ func AndConst8Inplace(main []byte, val byte) { return } argWord := 0x101010101010101 * uintptr(val) - mainHeader := (*reflect.SliceHeader)(unsafe.Pointer(&main)) - mainWordsIter := unsafe.Pointer(mainHeader.Data) + mainData := 
unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&main)).Data) + mainWordsIter := mainData if mainLen > 2*BytesPerWord { nWordMinus2 := (mainLen - BytesPerWord - 1) >> Log2BytesPerWord for widx := 0; widx < nWordMinus2; widx++ { mainWord := *((*uintptr)(mainWordsIter)) *((*uintptr)(mainWordsIter)) = mainWord & argWord - mainWordsIter = unsafe.Pointer(uintptr(mainWordsIter) + BytesPerWord) + mainWordsIter = unsafe.Add(mainWordsIter, BytesPerWord) } } mainWord1 := *((*uintptr)(mainWordsIter)) finalOffset := uintptr(mainLen - BytesPerWord) - mainFinalWordPtr := unsafe.Pointer(mainHeader.Data + finalOffset) + mainFinalWordPtr := unsafe.Add(mainData, finalOffset) mainWord2 := *((*uintptr)(mainFinalWordPtr)) *((*uintptr)(mainWordsIter)) = mainWord1 & argWord *((*uintptr)(mainFinalWordPtr)) = mainWord2 & argWord @@ -274,8 +276,8 @@ func AndConst8Unsafe(dst, src []byte, val byte) { for widx := 0; widx < nWord; widx++ { srcWord := *((*uintptr)(srcIter)) *((*uintptr)(dstIter)) = srcWord & argWord - srcIter = unsafe.Pointer(uintptr(srcIter) + BytesPerWord) - dstIter = unsafe.Pointer(uintptr(dstIter) + BytesPerWord) + srcIter = unsafe.Add(srcIter, BytesPerWord) + dstIter = unsafe.Add(dstIter, BytesPerWord) } } @@ -292,22 +294,22 @@ func AndConst8(dst, src []byte, val byte) { } return } - srcHeader := (*reflect.SliceHeader)(unsafe.Pointer(&src)) - dstHeader := (*reflect.SliceHeader)(unsafe.Pointer(&dst)) + srcData := unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&src)).Data) + dstData := unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&dst)).Data) nWordMinus1 := (dstLen - 1) >> Log2BytesPerWord argWord := 0x101010101010101 * uintptr(val) - srcIter := unsafe.Pointer(srcHeader.Data) - dstIter := unsafe.Pointer(dstHeader.Data) + srcIter := unsafe.Pointer(srcData) + dstIter := unsafe.Pointer(dstData) for widx := 0; widx < nWordMinus1; widx++ { srcWord := *((*uintptr)(srcIter)) *((*uintptr)(dstIter)) = srcWord & argWord - srcIter = unsafe.Pointer(uintptr(srcIter) + 
BytesPerWord) - dstIter = unsafe.Pointer(uintptr(dstIter) + BytesPerWord) + srcIter = unsafe.Add(srcIter, BytesPerWord) + dstIter = unsafe.Add(dstIter, BytesPerWord) } finalOffset := uintptr(dstLen - BytesPerWord) - srcIter = unsafe.Pointer(srcHeader.Data + finalOffset) - dstIter = unsafe.Pointer(dstHeader.Data + finalOffset) + srcIter = unsafe.Add(srcData, finalOffset) + dstIter = unsafe.Add(dstData, finalOffset) srcWord := *((*uintptr)(srcIter)) *((*uintptr)(dstIter)) = srcWord & argWord } diff --git a/simd/and_generic.go b/simd/and_generic.go new file mode 100644 index 00000000..68b99027 --- /dev/null +++ b/simd/and_generic.go @@ -0,0 +1,135 @@ +// Code generated by " ../gtl/generate.py --prefix=And -DOPCHAR=& --package=simd --output=and_generic.go bitwise_generic.go.tpl ". DO NOT EDIT. + +// Copyright 2018 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache-2.0 +// license that can be found in the LICENSE file. + +// +build !amd64 appengine + +package simd + +// AndUnsafeInplace sets main[pos] := main[pos] & arg[pos] for every position +// in main[]. +// +// WARNING: This is a function designed to be used in inner loops, which makes +// assumptions about length and capacity which aren't checked at runtime. Use +// the safe version of this function when that's a problem. +// Assumptions #2-3 are always satisfied when the last +// potentially-size-increasing operation on arg[] is {Re}makeUnsafe(), +// ResizeUnsafe(), or XcapUnsafe(), and the same is true for main[]. +// +// 1. len(arg) and len(main) must be equal. +// +// 2. Capacities are at least RoundUpPow2(len(main) + 1, bytesPerVec). +// +// 3. The caller does not care if a few bytes past the end of main[] are +// changed. +func AndUnsafeInplace(main, arg []byte) { + for i, x := range main { + main[i] = x & arg[i] + } +} + +// AndInplace sets main[pos] := main[pos] & arg[pos] for every position in +// main[]. It panics if slice lengths don't match. 
+func AndInplace(main, arg []byte) { + if len(arg) != len(main) { + panic("AndInplace() requires len(arg) == len(main).") + } + for i, x := range main { + main[i] = x & arg[i] + } +} + +// AndUnsafe sets dst[pos] := src1[pos] & src2[pos] for every position in dst. +// +// WARNING: This is a function designed to be used in inner loops, which makes +// assumptions about length and capacity which aren't checked at runtime. Use +// the safe version of this function when that's a problem. +// Assumptions #2-3 are always satisfied when the last +// potentially-size-increasing operation on src1[] is {Re}makeUnsafe(), +// ResizeUnsafe(), or XcapUnsafe(), and the same is true for src2[] and dst[]. +// +// 1. len(src1), len(src2), and len(dst) must be equal. +// +// 2. Capacities are at least RoundUpPow2(len(dst) + 1, bytesPerVec). +// +// 3. The caller does not care if a few bytes past the end of dst[] are +// changed. +func AndUnsafe(dst, src1, src2 []byte) { + for i, x := range src1 { + dst[i] = x & src2[i] + } +} + +// And sets dst[pos] := src1[pos] & src2[pos] for every position in dst. It +// panics if slice lengths don't match. +func And(dst, src1, src2 []byte) { + dstLen := len(dst) + if (len(src1) != dstLen) || (len(src2) != dstLen) { + panic("And() requires len(src1) == len(src2) == len(dst).") + } + for i, x := range src1 { + dst[i] = x & src2[i] + } +} + +// AndConst8UnsafeInplace sets main[pos] := main[pos] & val for every position +// in main[]. +// +// WARNING: This is a function designed to be used in inner loops, which makes +// assumptions about length and capacity which aren't checked at runtime. Use +// the safe version of this function when that's a problem. +// These assumptions are always satisfied when the last +// potentially-size-increasing operation on main[] is {Re}makeUnsafe(), +// ResizeUnsafe(), or XcapUnsafe(). +// +// 1. cap(main) is at least RoundUpPow2(len(main) + 1, bytesPerVec). +// +// 2. 
The caller does not care if a few bytes past the end of main[] are +// changed. +func AndConst8UnsafeInplace(main []byte, val byte) { + for i, x := range main { + main[i] = x & val + } +} + +// AndConst8Inplace sets main[pos] := main[pos] & val for every position in +// main[]. +func AndConst8Inplace(main []byte, val byte) { + for i, x := range main { + main[i] = x & val + } +} + +// AndConst8Unsafe sets dst[pos] := src[pos] & val for every position in dst. +// +// WARNING: This is a function designed to be used in inner loops, which makes +// assumptions about length and capacity which aren't checked at runtime. Use +// the safe version of this function when that's a problem. +// Assumptions #2-3 are always satisfied when the last +// potentially-size-increasing operation on src[] is {Re}makeUnsafe(), +// ResizeUnsafe(), or XcapUnsafe(), and the same is true for dst[]. +// +// 1. len(src) and len(dst) must be equal. +// +// 2. Capacities are at least RoundUpPow2(len(dst) + 1, bytesPerVec). +// +// 3. The caller does not care if a few bytes past the end of dst[] are +// changed. +func AndConst8Unsafe(dst, src []byte, val byte) { + for i, x := range src { + dst[i] = x & val + } +} + +// AndConst8 sets dst[pos] := src[pos] & val for every position in dst. It +// panics if slice lengths don't match. +func AndConst8(dst, src []byte, val byte) { + if len(src) != len(dst) { + panic("AndConst8() requires len(src) == len(dst).") + } + for i, x := range src { + dst[i] = x & val + } +} diff --git a/simd/bitwise_amd64.go.tpl b/simd/bitwise_amd64.go.tpl index f7eaeef6..e58fbe09 100644 --- a/simd/bitwise_amd64.go.tpl +++ b/simd/bitwise_amd64.go.tpl @@ -1,4 +1,4 @@ -// Copyright 2018 GRAIL, Inc. All rights reserved. +// Copyright 2021 GRAIL, Inc. All rights reserved. // Use of this source code is governed by the Apache-2.0 // license that can be found in the LICENSE file. 
@@ -11,7 +11,7 @@ import ( "unsafe" ) -// ZZUnsafeInplace sets main[pos] := arg[pos] OPCHAR main[pos] for every position +// ZZUnsafeInplace sets main[pos] := main[pos] OPCHAR arg[pos] for every position // in main[]. // // WARNING: This is a function designed to be used in inner loops, which makes @@ -29,18 +29,18 @@ import ( // changed. func ZZUnsafeInplace(main, arg []byte) { mainLen := len(main) - argHeader := (*reflect.SliceHeader)(unsafe.Pointer(&arg)) - mainHeader := (*reflect.SliceHeader)(unsafe.Pointer(&main)) - argWordsIter := unsafe.Pointer(argHeader.Data) - mainWordsIter := unsafe.Pointer(mainHeader.Data) + argData := unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&arg)).Data) + mainData := unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&main)).Data) + argWordsIter := argData + mainWordsIter := mainData if mainLen > 2*BytesPerWord { nWordMinus2 := (mainLen - BytesPerWord - 1) >> Log2BytesPerWord for widx := 0; widx < nWordMinus2; widx++ { mainWord := *((*uintptr)(mainWordsIter)) argWord := *((*uintptr)(argWordsIter)) *((*uintptr)(mainWordsIter)) = mainWord OPCHAR argWord - mainWordsIter = unsafe.Pointer(uintptr(mainWordsIter) + BytesPerWord) - argWordsIter = unsafe.Pointer(uintptr(argWordsIter) + BytesPerWord) + mainWordsIter = unsafe.Add(mainWordsIter, BytesPerWord) + argWordsIter = unsafe.Add(argWordsIter, BytesPerWord) } } else if mainLen <= BytesPerWord { mainWord := *((*uintptr)(mainWordsIter)) @@ -55,8 +55,8 @@ func ZZUnsafeInplace(main, arg []byte) { mainWord1 := *((*uintptr)(mainWordsIter)) argWord1 := *((*uintptr)(argWordsIter)) finalOffset := uintptr(mainLen - BytesPerWord) - mainFinalWordPtr := unsafe.Pointer(mainHeader.Data + finalOffset) - argFinalWordPtr := unsafe.Pointer(argHeader.Data + finalOffset) + mainFinalWordPtr := unsafe.Add(mainData, finalOffset) + argFinalWordPtr := unsafe.Add(argData, finalOffset) mainWord2 := *((*uintptr)(mainFinalWordPtr)) argWord2 := *((*uintptr)(argFinalWordPtr)) *((*uintptr)(mainWordsIter)) = 
mainWord1 OPCHAR argWord1 @@ -81,25 +81,25 @@ func ZZInplace(main, arg []byte) { } return } - argHeader := (*reflect.SliceHeader)(unsafe.Pointer(&arg)) - mainHeader := (*reflect.SliceHeader)(unsafe.Pointer(&main)) - argWordsIter := unsafe.Pointer(argHeader.Data) - mainWordsIter := unsafe.Pointer(mainHeader.Data) + argData := unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&arg)).Data) + mainData := unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&main)).Data) + argWordsIter := argData + mainWordsIter := mainData if mainLen > 2*BytesPerWord { nWordMinus2 := (mainLen - BytesPerWord - 1) >> Log2BytesPerWord for widx := 0; widx < nWordMinus2; widx++ { mainWord := *((*uintptr)(mainWordsIter)) argWord := *((*uintptr)(argWordsIter)) *((*uintptr)(mainWordsIter)) = mainWord OPCHAR argWord - mainWordsIter = unsafe.Pointer(uintptr(mainWordsIter) + BytesPerWord) - argWordsIter = unsafe.Pointer(uintptr(argWordsIter) + BytesPerWord) + mainWordsIter = unsafe.Add(mainWordsIter, BytesPerWord) + argWordsIter = unsafe.Add(argWordsIter, BytesPerWord) } } mainWord1 := *((*uintptr)(mainWordsIter)) argWord1 := *((*uintptr)(argWordsIter)) finalOffset := uintptr(mainLen - BytesPerWord) - mainFinalWordPtr := unsafe.Pointer(mainHeader.Data + finalOffset) - argFinalWordPtr := unsafe.Pointer(argHeader.Data + finalOffset) + mainFinalWordPtr := unsafe.Add(mainData, finalOffset) + argFinalWordPtr := unsafe.Add(argData, finalOffset) mainWord2 := *((*uintptr)(mainFinalWordPtr)) argWord2 := *((*uintptr)(argFinalWordPtr)) *((*uintptr)(mainWordsIter)) = mainWord1 OPCHAR argWord1 @@ -134,9 +134,9 @@ func ZZUnsafe(dst, src1, src2 []byte) { src1Word := *((*uintptr)(src1Iter)) src2Word := *((*uintptr)(src2Iter)) *((*uintptr)(dstIter)) = src1Word OPCHAR src2Word - src1Iter = unsafe.Pointer(uintptr(src1Iter) + BytesPerWord) - src2Iter = unsafe.Pointer(uintptr(src2Iter) + BytesPerWord) - dstIter = unsafe.Pointer(uintptr(dstIter) + BytesPerWord) + src1Iter = unsafe.Add(src1Iter, BytesPerWord) + 
src2Iter = unsafe.Add(src2Iter, BytesPerWord) + dstIter = unsafe.Add(dstIter, BytesPerWord) } } @@ -153,27 +153,27 @@ func ZZ(dst, src1, src2 []byte) { } return } - src1Header := (*reflect.SliceHeader)(unsafe.Pointer(&src1)) - src2Header := (*reflect.SliceHeader)(unsafe.Pointer(&src2)) - dstHeader := (*reflect.SliceHeader)(unsafe.Pointer(&dst)) + src1Data := unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&src1)).Data) + src2Data := unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&src2)).Data) + dstData := unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&dst)).Data) nWordMinus1 := (dstLen - 1) >> Log2BytesPerWord - src1Iter := unsafe.Pointer(src1Header.Data) - src2Iter := unsafe.Pointer(src2Header.Data) - dstIter := unsafe.Pointer(dstHeader.Data) + src1Iter := src1Data + src2Iter := src2Data + dstIter := dstData for widx := 0; widx < nWordMinus1; widx++ { src1Word := *((*uintptr)(src1Iter)) src2Word := *((*uintptr)(src2Iter)) *((*uintptr)(dstIter)) = src1Word OPCHAR src2Word - src1Iter = unsafe.Pointer(uintptr(src1Iter) + BytesPerWord) - src2Iter = unsafe.Pointer(uintptr(src2Iter) + BytesPerWord) - dstIter = unsafe.Pointer(uintptr(dstIter) + BytesPerWord) + src1Iter = unsafe.Add(src1Iter, BytesPerWord) + src2Iter = unsafe.Add(src2Iter, BytesPerWord) + dstIter = unsafe.Add(dstIter, BytesPerWord) } // No store-forwarding problem here. 
finalOffset := uintptr(dstLen - BytesPerWord) - src1Iter = unsafe.Pointer(src1Header.Data + finalOffset) - src2Iter = unsafe.Pointer(src2Header.Data + finalOffset) - dstIter = unsafe.Pointer(dstHeader.Data + finalOffset) + src1Iter = unsafe.Add(src1Data, finalOffset) + src2Iter = unsafe.Add(src2Data, finalOffset) + dstIter = unsafe.Add(dstData, finalOffset) src1Word := *((*uintptr)(src1Iter)) src2Word := *((*uintptr)(src2Iter)) *((*uintptr)(dstIter)) = src1Word OPCHAR src2Word @@ -196,14 +196,14 @@ func ZZ(dst, src1, src2 []byte) { func ZZConst8UnsafeInplace(main []byte, val byte) { mainLen := len(main) argWord := 0x101010101010101 * uintptr(val) - mainHeader := (*reflect.SliceHeader)(unsafe.Pointer(&main)) - mainWordsIter := unsafe.Pointer(mainHeader.Data) + mainData := unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&main)).Data) + mainWordsIter := mainData if mainLen > 2*BytesPerWord { nWordMinus2 := (mainLen - BytesPerWord - 1) >> Log2BytesPerWord for widx := 0; widx < nWordMinus2; widx++ { mainWord := *((*uintptr)(mainWordsIter)) *((*uintptr)(mainWordsIter)) = mainWord OPCHAR argWord - mainWordsIter = unsafe.Pointer(uintptr(mainWordsIter) + BytesPerWord) + mainWordsIter = unsafe.Add(mainWordsIter, BytesPerWord) } } else if mainLen <= BytesPerWord { mainWord := *((*uintptr)(mainWordsIter)) @@ -212,7 +212,7 @@ func ZZConst8UnsafeInplace(main []byte, val byte) { } mainWord1 := *((*uintptr)(mainWordsIter)) finalOffset := uintptr(mainLen - BytesPerWord) - mainFinalWordPtr := unsafe.Pointer(mainHeader.Data + finalOffset) + mainFinalWordPtr := unsafe.Add(mainData, finalOffset) mainWord2 := *((*uintptr)(mainFinalWordPtr)) *((*uintptr)(mainWordsIter)) = mainWord1 OPCHAR argWord *((*uintptr)(mainFinalWordPtr)) = mainWord2 OPCHAR argWord @@ -229,19 +229,19 @@ func ZZConst8Inplace(main []byte, val byte) { return } argWord := 0x101010101010101 * uintptr(val) - mainHeader := (*reflect.SliceHeader)(unsafe.Pointer(&main)) - mainWordsIter := 
unsafe.Pointer(mainHeader.Data) + mainData := unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&main)).Data) + mainWordsIter := mainData if mainLen > 2*BytesPerWord { nWordMinus2 := (mainLen - BytesPerWord - 1) >> Log2BytesPerWord for widx := 0; widx < nWordMinus2; widx++ { mainWord := *((*uintptr)(mainWordsIter)) *((*uintptr)(mainWordsIter)) = mainWord OPCHAR argWord - mainWordsIter = unsafe.Pointer(uintptr(mainWordsIter) + BytesPerWord) + mainWordsIter = unsafe.Add(mainWordsIter, BytesPerWord) } } mainWord1 := *((*uintptr)(mainWordsIter)) finalOffset := uintptr(mainLen - BytesPerWord) - mainFinalWordPtr := unsafe.Pointer(mainHeader.Data + finalOffset) + mainFinalWordPtr := unsafe.Add(mainData, finalOffset) mainWord2 := *((*uintptr)(mainFinalWordPtr)) *((*uintptr)(mainWordsIter)) = mainWord1 OPCHAR argWord *((*uintptr)(mainFinalWordPtr)) = mainWord2 OPCHAR argWord @@ -273,8 +273,8 @@ func ZZConst8Unsafe(dst, src []byte, val byte) { for widx := 0; widx < nWord; widx++ { srcWord := *((*uintptr)(srcIter)) *((*uintptr)(dstIter)) = srcWord OPCHAR argWord - srcIter = unsafe.Pointer(uintptr(srcIter) + BytesPerWord) - dstIter = unsafe.Pointer(uintptr(dstIter) + BytesPerWord) + srcIter = unsafe.Add(srcIter, BytesPerWord) + dstIter = unsafe.Add(dstIter, BytesPerWord) } } @@ -291,22 +291,22 @@ func ZZConst8(dst, src []byte, val byte) { } return } - srcHeader := (*reflect.SliceHeader)(unsafe.Pointer(&src)) - dstHeader := (*reflect.SliceHeader)(unsafe.Pointer(&dst)) + srcData := unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&src)).Data) + dstData := unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&dst)).Data) nWordMinus1 := (dstLen - 1) >> Log2BytesPerWord argWord := 0x101010101010101 * uintptr(val) - srcIter := unsafe.Pointer(srcHeader.Data) - dstIter := unsafe.Pointer(dstHeader.Data) + srcIter := unsafe.Pointer(srcData) + dstIter := unsafe.Pointer(dstData) for widx := 0; widx < nWordMinus1; widx++ { srcWord := *((*uintptr)(srcIter)) *((*uintptr)(dstIter)) = 
srcWord OPCHAR argWord - srcIter = unsafe.Pointer(uintptr(srcIter) + BytesPerWord) - dstIter = unsafe.Pointer(uintptr(dstIter) + BytesPerWord) + srcIter = unsafe.Add(srcIter, BytesPerWord) + dstIter = unsafe.Add(dstIter, BytesPerWord) } finalOffset := uintptr(dstLen - BytesPerWord) - srcIter = unsafe.Pointer(srcHeader.Data + finalOffset) - dstIter = unsafe.Pointer(dstHeader.Data + finalOffset) + srcIter = unsafe.Add(srcData, finalOffset) + dstIter = unsafe.Add(dstData, finalOffset) srcWord := *((*uintptr)(srcIter)) *((*uintptr)(dstIter)) = srcWord OPCHAR argWord } diff --git a/simd/bitwise_generic.go.tpl b/simd/bitwise_generic.go.tpl new file mode 100644 index 00000000..fb9d5f8d --- /dev/null +++ b/simd/bitwise_generic.go.tpl @@ -0,0 +1,133 @@ +// Copyright 2018 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache-2.0 +// license that can be found in the LICENSE file. + +// +build !amd64 appengine + +package PACKAGE + +// ZZUnsafeInplace sets main[pos] := main[pos] OPCHAR arg[pos] for every position +// in main[]. +// +// WARNING: This is a function designed to be used in inner loops, which makes +// assumptions about length and capacity which aren't checked at runtime. Use +// the safe version of this function when that's a problem. +// Assumptions #2-3 are always satisfied when the last +// potentially-size-increasing operation on arg[] is {Re}makeUnsafe(), +// ResizeUnsafe(), or XcapUnsafe(), and the same is true for main[]. +// +// 1. len(arg) and len(main) must be equal. +// +// 2. Capacities are at least RoundUpPow2(len(main) + 1, bytesPerVec). +// +// 3. The caller does not care if a few bytes past the end of main[] are +// changed. +func ZZUnsafeInplace(main, arg []byte) { + for i, x := range main { + main[i] = x OPCHAR arg[i] + } +} + +// ZZInplace sets main[pos] := main[pos] OPCHAR arg[pos] for every position in +// main[]. It panics if slice lengths don't match. 
+func ZZInplace(main, arg []byte) { + if len(arg) != len(main) { + panic("ZZInplace() requires len(arg) == len(main).") + } + for i, x := range main { + main[i] = x OPCHAR arg[i] + } +} + +// ZZUnsafe sets dst[pos] := src1[pos] OPCHAR src2[pos] for every position in dst. +// +// WARNING: This is a function designed to be used in inner loops, which makes +// assumptions about length and capacity which aren't checked at runtime. Use +// the safe version of this function when that's a problem. +// Assumptions #2-3 are always satisfied when the last +// potentially-size-increasing operation on src1[] is {Re}makeUnsafe(), +// ResizeUnsafe(), or XcapUnsafe(), and the same is true for src2[] and dst[]. +// +// 1. len(src1), len(src2), and len(dst) must be equal. +// +// 2. Capacities are at least RoundUpPow2(len(dst) + 1, bytesPerVec). +// +// 3. The caller does not care if a few bytes past the end of dst[] are +// changed. +func ZZUnsafe(dst, src1, src2 []byte) { + for i, x := range src1 { + dst[i] = x OPCHAR src2[i] + } +} + +// ZZ sets dst[pos] := src1[pos] OPCHAR src2[pos] for every position in dst. It +// panics if slice lengths don't match. +func ZZ(dst, src1, src2 []byte) { + dstLen := len(dst) + if (len(src1) != dstLen) || (len(src2) != dstLen) { + panic("ZZ() requires len(src1) == len(src2) == len(dst).") + } + for i, x := range src1 { + dst[i] = x OPCHAR src2[i] + } +} + +// ZZConst8UnsafeInplace sets main[pos] := main[pos] OPCHAR val for every position +// in main[]. +// +// WARNING: This is a function designed to be used in inner loops, which makes +// assumptions about length and capacity which aren't checked at runtime. Use +// the safe version of this function when that's a problem. +// These assumptions are always satisfied when the last +// potentially-size-increasing operation on main[] is {Re}makeUnsafe(), +// ResizeUnsafe(), or XcapUnsafe(). +// +// 1. cap(main) is at least RoundUpPow2(len(main) + 1, bytesPerVec). +// +// 2. 
The caller does not care if a few bytes past the end of main[] are +// changed. +func ZZConst8UnsafeInplace(main []byte, val byte) { + for i, x := range main { + main[i] = x OPCHAR val + } +} + +// ZZConst8Inplace sets main[pos] := main[pos] OPCHAR val for every position in +// main[]. +func ZZConst8Inplace(main []byte, val byte) { + for i, x := range main { + main[i] = x OPCHAR val + } +} + +// ZZConst8Unsafe sets dst[pos] := src[pos] OPCHAR val for every position in dst. +// +// WARNING: This is a function designed to be used in inner loops, which makes +// assumptions about length and capacity which aren't checked at runtime. Use +// the safe version of this function when that's a problem. +// Assumptions #2-3 are always satisfied when the last +// potentially-size-increasing operation on src[] is {Re}makeUnsafe(), +// ResizeUnsafe(), or XcapUnsafe(), and the same is true for dst[]. +// +// 1. len(src) and len(dst) must be equal. +// +// 2. Capacities are at least RoundUpPow2(len(dst) + 1, bytesPerVec). +// +// 3. The caller does not care if a few bytes past the end of dst[] are +// changed. +func ZZConst8Unsafe(dst, src []byte, val byte) { + for i, x := range src { + dst[i] = x OPCHAR val + } +} + +// ZZConst8 sets dst[pos] := src[pos] OPCHAR val for every position in dst. It +// panics if slice lengths don't match. 
+func ZZConst8(dst, src []byte, val byte) { + if len(src) != len(dst) { + panic("ZZConst8() requires len(src) == len(dst).") + } + for i, x := range src { + dst[i] = x OPCHAR val + } +} diff --git a/simd/bitwise_test.go b/simd/bitwise_test.go index c5fbb838..a0b65a77 100644 --- a/simd/bitwise_test.go +++ b/simd/bitwise_test.go @@ -7,133 +7,11 @@ package simd_test import ( "bytes" "math/rand" - "runtime" "testing" "github.com/grailbio/base/simd" ) -/* -Initial benchmark results: - MacBook Pro (15-inch, 2016) - 2.7 GHz Intel Core i7, 16 GB 2133 MHz LPDDR3 - -Benchmark_AndShort1-8 20 79001639 ns/op -Benchmark_AndShort4-8 100 22674670 ns/op -Benchmark_AndShortMax-8 50 31081823 ns/op -Benchmark_AndLong1-8 1 2341570438 ns/op -Benchmark_AndLong4-8 1 2215749534 ns/op -Benchmark_AndLongMax-8 1 3206256001 ns/op - -Benchmark_XorConstShort1-8 20 70140769 ns/op -Benchmark_XorConstShort4-8 100 19764681 ns/op -Benchmark_XorConstShortMax-8 100 18666329 ns/op -Benchmark_XorConstLong1-8 1 1425834375 ns/op -Benchmark_XorConstLong4-8 1 1986577047 ns/op -Benchmark_XorConstLongMax-8 1 2824665438 ns/op - -For reference, andInplaceSlow has the following results: -Benchmark_AndShort1-8 2 568462797 ns/op -Benchmark_AndShort4-8 10 150619381 ns/op -Benchmark_AndShortMax-8 5 228610110 ns/op -Benchmark_AndLong1-8 1 8455390684 ns/op -Benchmark_AndLong4-8 1 5252196746 ns/op -Benchmark_AndLongMax-8 1 4038874956 ns/op - -xorConst8InplaceSlow: -Benchmark_XorConstShort1-8 2 537646968 ns/op -Benchmark_XorConstShort4-8 10 141203425 ns/op -Benchmark_XorConstShortMax-8 10 135202486 ns/op -Benchmark_XorConstLong1-8 1 7982759770 ns/op -Benchmark_XorConstLong4-8 1 4977037903 ns/op -Benchmark_XorConstLongMax-8 1 3903526748 ns/op -*/ - -func andSubtask(main, arg []byte, nIter int) int { - for iter := 0; iter < nIter; iter++ { - simd.AndUnsafeInplace(main, arg) - } - return int(main[0]) -} - -func andSubtaskFuture(main, arg []byte, nIter int) chan int { - future := make(chan int) - go func() { future <- 
andSubtask(main, arg, nIter) }() - return future -} - -func multiAnd(mains [][]byte, arg []byte, cpus int, nJob int) { - sumFutures := make([]chan int, cpus) - shardSizeBase := nJob / cpus - shardRemainder := nJob - shardSizeBase*cpus - shardSizeP1 := shardSizeBase + 1 - var taskIdx int - for ; taskIdx < shardRemainder; taskIdx++ { - sumFutures[taskIdx] = andSubtaskFuture(mains[taskIdx], arg, shardSizeP1) - } - for ; taskIdx < cpus; taskIdx++ { - sumFutures[taskIdx] = andSubtaskFuture(mains[taskIdx], arg, shardSizeBase) - } - var sum int - for taskIdx = 0; taskIdx < cpus; taskIdx++ { - sum += <-sumFutures[taskIdx] - } -} - -func benchmarkAnd(cpus int, nByte int, nJob int, b *testing.B) { - if cpus > runtime.NumCPU() { - b.Skipf("only have %v cpus", runtime.NumCPU()) - } - - argArr := simd.MakeUnsafe(nByte) - for ii := 0; ii < nByte; ii++ { - argArr[ii] = byte(ii * 6) - } - - mainSlices := make([][]byte, cpus) - for ii := range mainSlices { - // Add 63 to prevent false sharing. - newArr := simd.MakeUnsafe(nByte + 63) - for jj := 0; jj < nByte; jj++ { - newArr[jj] = byte(jj * 3) - } - mainSlices[ii] = newArr[:nByte] - } - for i := 0; i < b.N; i++ { - multiAnd(mainSlices, argArr, cpus, nJob) - } -} - -// Base sequence in length-150 .bam read occupies 75 bytes, so 75 is a good -// size for the short-array benchmark. -func Benchmark_AndShort1(b *testing.B) { - benchmarkAnd(1, 75, 9999999, b) -} - -func Benchmark_AndShort4(b *testing.B) { - benchmarkAnd(4, 75, 9999999, b) -} - -func Benchmark_AndShortMax(b *testing.B) { - benchmarkAnd(runtime.NumCPU(), 75, 9999999, b) -} - -// GRCh37 chromosome 1 length is 249250621, so that's a plausible long-array -// use case. 
-func Benchmark_AndLong1(b *testing.B) { - benchmarkAnd(1, 249250621, 50, b) -} - -func Benchmark_AndLong4(b *testing.B) { - benchmarkAnd(4, 249250621, 50, b) -} - -func Benchmark_AndLongMax(b *testing.B) { - benchmarkAnd(runtime.NumCPU(), 249250621, 50, b) -} - -// Don't bother with separate benchmarks for Or/Xor/Invmask. - func andInplaceSlow(main, arg []byte) { // Slow, but straightforward-to-verify implementation. for idx := range main { @@ -198,6 +76,60 @@ func TestAnd(t *testing.T) { } } +/* +Benchmark results: + MacBook Pro (15-inch, 2016) + 2.7 GHz Intel Core i7, 16 GB 2133 MHz LPDDR3 + +Benchmark_AndInplace/SIMDShort1Cpu-8 20 91832421 ns/op +Benchmark_AndInplace/SIMDShortHalfCpu-8 50 25323744 ns/op +Benchmark_AndInplace/SIMDShortAllCpu-8 100 23869031 ns/op +Benchmark_AndInplace/SIMDLong1Cpu-8 1 1715379622 ns/op +Benchmark_AndInplace/SIMDLongHalfCpu-8 1 1372591170 ns/op +Benchmark_AndInplace/SIMDLongAllCpu-8 1 1427476449 ns/op +Benchmark_AndInplace/SlowShort1Cpu-8 2 550667201 ns/op +Benchmark_AndInplace/SlowShortHalfCpu-8 10 145756965 ns/op +Benchmark_AndInplace/SlowShortAllCpu-8 10 135311356 ns/op +Benchmark_AndInplace/SlowLong1Cpu-8 1 7711233274 ns/op +Benchmark_AndInplace/SlowLongHalfCpu-8 1 2144409827 ns/op +Benchmark_AndInplace/SlowLongAllCpu-8 1 2158206158 ns/op +*/ + +func andSimdSubtask(dst, src []byte, nIter int) int { + for iter := 0; iter < nIter; iter++ { + simd.AndInplace(dst, src) + } + return int(dst[0]) +} + +func andSlowSubtask(dst, src []byte, nIter int) int { + for iter := 0; iter < nIter; iter++ { + andInplaceSlow(dst, src) + } + return int(dst[0]) +} + +func Benchmark_AndInplace(b *testing.B) { + funcs := []taggedMultiBenchFunc{ + { + f: andSimdSubtask, + tag: "SIMD", + }, + { + f: andSlowSubtask, + tag: "Slow", + }, + } + for _, f := range funcs { + // This is relevant to .bam reads in packed form, so 150/2=75 is a good + // size for the short-array benchmark. 
+ multiBenchmark(f.f, f.tag+"Short", 75, 75, 9999999, b) + multiBenchmark(f.f, f.tag+"Long", 249250621, 249250621, 50, b) + } +} + +// Don't bother with separate benchmarks for Or/Xor/Invmask. + func orInplaceSlow(main, arg []byte) { for idx := range main { main[idx] = main[idx] | arg[idx] @@ -496,84 +428,6 @@ func TestOrConst8(t *testing.T) { } } -func xorConstSubtask(main []byte, nIter int) int { - for iter := 0; iter < nIter; iter++ { - simd.XorConst8UnsafeInplace(main, 3) - } - return int(main[0]) -} - -func xorConstSubtaskFuture(main []byte, nIter int) chan int { - future := make(chan int) - go func() { future <- xorConstSubtask(main, nIter) }() - return future -} - -func multiXorConst(mains [][]byte, cpus int, nJob int) { - sumFutures := make([]chan int, cpus) - shardSizeBase := nJob / cpus - shardRemainder := nJob - shardSizeBase*cpus - shardSizeP1 := shardSizeBase + 1 - var taskIdx int - for ; taskIdx < shardRemainder; taskIdx++ { - sumFutures[taskIdx] = xorConstSubtaskFuture(mains[taskIdx], shardSizeP1) - } - for ; taskIdx < cpus; taskIdx++ { - sumFutures[taskIdx] = xorConstSubtaskFuture(mains[taskIdx], shardSizeBase) - } - var sum int - for taskIdx = 0; taskIdx < cpus; taskIdx++ { - sum += <-sumFutures[taskIdx] - } -} - -func benchmarkXorConst(cpus int, nByte int, nJob int, b *testing.B) { - if cpus > runtime.NumCPU() { - b.Skipf("only have %v cpus", runtime.NumCPU()) - } - - mainSlices := make([][]byte, cpus) - for ii := range mainSlices { - // Add 63 to prevent false sharing. - newArr := simd.MakeUnsafe(nByte + 63) - for jj := 0; jj < nByte; jj++ { - newArr[jj] = byte(jj * 3) - } - mainSlices[ii] = newArr[:nByte] - } - for i := 0; i < b.N; i++ { - multiXorConst(mainSlices, cpus, nJob) - } -} - -// Base sequence in length-150 .bam read occupies 75 bytes, so 75 is a good -// size for the short-array benchmark. 
-func Benchmark_XorConstShort1(b *testing.B) { - benchmarkXorConst(1, 75, 9999999, b) -} - -func Benchmark_XorConstShort4(b *testing.B) { - benchmarkXorConst(4, 75, 9999999, b) -} - -func Benchmark_XorConstShortMax(b *testing.B) { - benchmarkXorConst(runtime.NumCPU(), 75, 9999999, b) -} - -// GRCh37 chromosome 1 length is 249250621, so that's a plausible long-array -// use case. -func Benchmark_XorConstLong1(b *testing.B) { - benchmarkXorConst(1, 249250621, 50, b) -} - -func Benchmark_XorConstLong4(b *testing.B) { - benchmarkXorConst(4, 249250621, 50, b) -} - -func Benchmark_XorConstLongMax(b *testing.B) { - benchmarkXorConst(runtime.NumCPU(), 249250621, 50, b) -} - func xorConst8InplaceSlow(main []byte, val byte) { for idx, mainByte := range main { main[idx] = mainByte ^ val @@ -637,3 +491,53 @@ func TestXorConst8(t *testing.T) { } } } + +/* +Benchmark results: + MacBook Pro (15-inch, 2016) + 2.7 GHz Intel Core i7, 16 GB 2133 MHz LPDDR3 + +Benchmark_XorConst8Inplace/SIMDShort1Cpu-8 20 79730366 ns/op +Benchmark_XorConst8Inplace/SIMDShortHalfCpu-8 100 21216542 ns/op +Benchmark_XorConst8Inplace/SIMDShortAllCpu-8 100 18902385 ns/op +Benchmark_XorConst8Inplace/SIMDLong1Cpu-8 1 1291770636 ns/op +Benchmark_XorConst8Inplace/SIMDLongHalfCpu-8 2 958003320 ns/op +Benchmark_XorConst8Inplace/SIMDLongAllCpu-8 2 967333286 ns/op +Benchmark_XorConst8Inplace/SlowShort1Cpu-8 3 417781174 ns/op +Benchmark_XorConst8Inplace/SlowShortHalfCpu-8 10 112255124 ns/op +Benchmark_XorConst8Inplace/SlowShortAllCpu-8 10 100138643 ns/op +Benchmark_XorConst8Inplace/SlowLong1Cpu-8 1 5476605564 ns/op +Benchmark_XorConst8Inplace/SlowLongHalfCpu-8 1 1480923705 ns/op +Benchmark_XorConst8Inplace/SlowLongAllCpu-8 1 1588216831 ns/op +*/ + +func xorConst8InplaceSimdSubtask(dst, src []byte, nIter int) int { + for iter := 0; iter < nIter; iter++ { + simd.XorConst8Inplace(dst, 3) + } + return int(dst[0]) +} + +func xorConst8InplaceSlowSubtask(dst, src []byte, nIter int) int { + for iter := 0; iter < nIter; 
iter++ { + xorConst8InplaceSlow(dst, 3) + } + return int(dst[0]) +} + +func Benchmark_XorConst8Inplace(b *testing.B) { + funcs := []taggedMultiBenchFunc{ + { + f: xorConst8InplaceSimdSubtask, + tag: "SIMD", + }, + { + f: xorConst8InplaceSlowSubtask, + tag: "Slow", + }, + } + for _, f := range funcs { + multiBenchmark(f.f, f.tag+"Short", 75, 0, 9999999, b) + multiBenchmark(f.f, f.tag+"Long", 249250621, 0, 50, b) + } +} diff --git a/simd/cmp_amd64.go b/simd/cmp_amd64.go index bd67fb66..f41a708a 100644 --- a/simd/cmp_amd64.go +++ b/simd/cmp_amd64.go @@ -1,7 +1,8 @@ -// Copyright 2018 GRAIL, Inc. All rights reserved. +// Copyright 2021 GRAIL, Inc. All rights reserved. // Use of this source code is governed by the Apache-2.0 // license that can be found in the LICENSE file. +//go:build amd64 && !appengine // +build amd64,!appengine package simd @@ -41,11 +42,11 @@ func FirstUnequal8Unsafe(arg1, arg2 []byte, startPos int) int { if nByte <= 0 { return endPos } - arg1Header := (*reflect.SliceHeader)(unsafe.Pointer(&arg1)) - arg2Header := (*reflect.SliceHeader)(unsafe.Pointer(&arg2)) + arg1Data := unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&arg1)).Data) + arg2Data := unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&arg2)).Data) nWordMinus1 := (nByte - 1) >> Log2BytesPerWord - arg1Iter := unsafe.Pointer(arg1Header.Data + uintptr(startPos)) - arg2Iter := unsafe.Pointer(arg2Header.Data + uintptr(startPos)) + arg1Iter := unsafe.Add(arg1Data, startPos) + arg2Iter := unsafe.Add(arg2Data, startPos) // Tried replacing this with simple (non-unrolled) vector-based loops very // similar to the main runtime's go/src/internal/bytealg/compare_amd64.s, but // they were actually worse than the safe function on the short-array @@ -65,8 +66,8 @@ func FirstUnequal8Unsafe(arg1, arg2 []byte, startPos int) int { // matter at all. 
return startPos + (widx * BytesPerWord) + (bits.TrailingZeros64(uint64(xorWord)) >> 3) } - arg1Iter = unsafe.Pointer(uintptr(arg1Iter) + BytesPerWord) - arg2Iter = unsafe.Pointer(uintptr(arg2Iter) + BytesPerWord) + arg1Iter = unsafe.Add(arg1Iter, BytesPerWord) + arg2Iter = unsafe.Add(arg2Iter, BytesPerWord) } xorWord := (*((*uintptr)(arg1Iter))) ^ (*((*uintptr)(arg2Iter))) if xorWord == 0 { @@ -86,19 +87,15 @@ func FirstUnequal8Unsafe(arg1, arg2 []byte, startPos int) int { // // This is essentially an extension of bytes.Compare(). func FirstUnequal8(arg1, arg2 []byte, startPos int) int { - // This takes ~20-25% longer on the short-array benchmark. + // This takes ~10% longer on the short-array benchmark. endPos := len(arg1) - if endPos != len(arg2) { - panic("FirstUnequal8() requires len(arg1) == len(arg2).") - } - if startPos < 0 { - // This check is kind of paranoid. It's here because - // unsafe.Pointer(arg1Header.Data + uintptr(startPos)) does not - // automatically error out on negative startPos, and it also doesn't hurt - // to protect against (endPos - startPos) integer overflow; but feel free - // to request its removal if you are using this function in a time-critical - // loop. - panic("FirstUnequal8() requires nonnegative startPos.") + if endPos != len(arg2) || (startPos < 0) { + // The startPos < 0 check is kind of paranoid. It's here because + // unsafe.Add(arg1Data, startPos) does not automatically error out on + // negative startPos, and it also doesn't hurt to protect against (endPos - + // startPos) integer overflow; but feel free to request its removal if you + // are using this function in a time-critical loop. 
+ panic("FirstUnequal8() requires len(arg1) == len(arg2) and nonnegative startPos.") } nByte := endPos - startPos if nByte < BytesPerWord { @@ -109,22 +106,22 @@ func FirstUnequal8(arg1, arg2 []byte, startPos int) int { } return endPos } - arg1Header := (*reflect.SliceHeader)(unsafe.Pointer(&arg1)) - arg2Header := (*reflect.SliceHeader)(unsafe.Pointer(&arg2)) + arg1Data := unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&arg1)).Data) + arg2Data := unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&arg2)).Data) nWordMinus1 := (nByte - 1) >> Log2BytesPerWord - arg1Iter := unsafe.Pointer(arg1Header.Data + uintptr(startPos)) - arg2Iter := unsafe.Pointer(arg2Header.Data + uintptr(startPos)) + arg1Iter := unsafe.Add(arg1Data, startPos) + arg2Iter := unsafe.Add(arg2Data, startPos) for widx := 0; widx < nWordMinus1; widx++ { xorWord := (*((*uintptr)(arg1Iter))) ^ (*((*uintptr)(arg2Iter))) if xorWord != 0 { return startPos + (widx * BytesPerWord) + (bits.TrailingZeros64(uint64(xorWord)) >> 3) } - arg1Iter = unsafe.Pointer(uintptr(arg1Iter) + BytesPerWord) - arg2Iter = unsafe.Pointer(uintptr(arg2Iter) + BytesPerWord) + arg1Iter = unsafe.Add(arg1Iter, BytesPerWord) + arg2Iter = unsafe.Add(arg2Iter, BytesPerWord) } finalOffset := uintptr(endPos - BytesPerWord) - arg1FinalPtr := unsafe.Pointer(arg1Header.Data + finalOffset) - arg2FinalPtr := unsafe.Pointer(arg2Header.Data + finalOffset) + arg1FinalPtr := unsafe.Add(arg1Data, finalOffset) + arg2FinalPtr := unsafe.Add(arg2Data, finalOffset) xorWord := (*((*uintptr)(arg1FinalPtr))) ^ (*((*uintptr)(arg2FinalPtr))) if xorWord == 0 { return endPos diff --git a/simd/cmp_generic.go b/simd/cmp_generic.go new file mode 100644 index 00000000..11e34df2 --- /dev/null +++ b/simd/cmp_generic.go @@ -0,0 +1,140 @@ +// Copyright 2018 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache-2.0 +// license that can be found in the LICENSE file. 
+ +// +build !amd64 appengine + +package simd + +// FirstUnequal8Unsafe scans arg1[startPos:] and arg2[startPos:] for the first +// mismatching byte, returning its position if one is found, or the common +// length if all bytes match (or startPos >= len). This has essentially the +// same speed as bytes.Compare(). +// +// WARNING: This is a function designed to be used in inner loops, which makes +// assumptions about length and capacity which aren't checked at runtime. Use +// the safe version of this function when that's a problem. +// The second assumption is always satisfied when the last +// potentially-size-increasing operation on arg1[] is {Re}makeUnsafe(), +// ResizeUnsafe(), or XcapUnsafe(), and the same is true for arg2[]. +// +// 1. len(arg1) == len(arg2). +// +// 2. Capacities are at least RoundUpPow2(len, bytesPerVec). +func FirstUnequal8Unsafe(arg1, arg2 []byte, startPos int) int { + endPos := len(arg1) + for i := startPos; i < endPos; i++ { + if arg1[i] != arg2[i] { + return i + } + } + return endPos +} + +// FirstUnequal8 scans arg1[startPos:] and arg2[startPos:] for the first +// mismatching byte, returning its position if one is found, or the common +// length if all bytes match (or startPos >= len). It panics if the lengths +// are not identical, or startPos is negative. +// +// This is essentially an extension of bytes.Compare(). +func FirstUnequal8(arg1, arg2 []byte, startPos int) int { + endPos := len(arg1) + if endPos != len(arg2) { + panic("FirstUnequal8() requires len(arg1) == len(arg2).") + } + if startPos < 0 { + panic("FirstUnequal8() requires nonnegative startPos.") + } + for pos := startPos; pos < endPos; pos++ { + if arg1[pos] != arg2[pos] { + return pos + } + } + return endPos +} + +// FirstGreater8Unsafe scans arg[startPos:] for the first value larger than the +// given constant, returning its position if one is found, or len(arg) if all +// bytes are <= (or startPos >= len). 
+// +// This should only be used when greater values are usually present at ~5% or +// lower frequency. Above that, use a simple for loop. +// +// WARNING: This is a function designed to be used in inner loops, which makes +// assumptions about length and capacity which aren't checked at runtime. Use +// the safe version of this function when that's a problem. +// The second assumption is always satisfied when the last +// potentially-size-increasing operation on arg[] is {Re}makeUnsafe(), +// ResizeUnsafe(), or XcapUnsafe(). +// +// 1. startPos is nonnegative. +// +// 2. cap(arg) >= RoundUpPow2(len, bytesPerVec). +func FirstGreater8Unsafe(arg []byte, val byte, startPos int) int { + endPos := len(arg) + for pos := startPos; pos < endPos; pos++ { + if arg[pos] > val { + return pos + } + } + return endPos +} + +// FirstGreater8 scans arg[startPos:] for the first value larger than the given +// constant, returning its position if one is found, or len(arg) if all bytes +// are <= (or startPos >= len). +// +// This should only be used when greater values are usually present at ~5% or +// lower frequency. Above that, use a simple for loop. +func FirstGreater8(arg []byte, val byte, startPos int) int { + if startPos < 0 { + panic("FirstGreater8() requires nonnegative startPos.") + } + endPos := len(arg) + for pos := startPos; pos < endPos; pos++ { + if arg[pos] > val { + return pos + } + } + return endPos +} + +// FirstLeq8Unsafe scans arg[startPos:] for the first value <= the given +// constant, returning its position if one is found, or len(arg) if all bytes +// are greater (or startPos >= len). +// +// This should only be used when <= values are usually present at ~5% or +// lower frequency. Above that, use a simple for loop. +// +// See warning for FirstGreater8Unsafe. 
+func FirstLeq8Unsafe(arg []byte, val byte, startPos int) int { + endPos := len(arg) + for pos := startPos; pos < endPos; pos++ { + if arg[pos] <= val { + return pos + } + } + return endPos +} + +// FirstLeq8 scans arg[startPos:] for the first value <= the given constant, +// returning its position if one is found, or len(arg) if all bytes are greater +// (or startPos >= len). +// +// This should only be used when <= values are usually present at ~5% or lower +// frequency. Above that, use a simple for loop. +func FirstLeq8(arg []byte, val byte, startPos int) int { + // This currently has practically no performance penalty relative to the + // Unsafe version, since the implementation is identical except for the + // startPos check. + if startPos < 0 { + panic("FirstLeq8() requires nonnegative startPos.") + } + endPos := len(arg) + for pos := startPos; pos < endPos; pos++ { + if arg[pos] <= val { + return pos + } + } + return endPos +} diff --git a/simd/cmp_test.go b/simd/cmp_test.go index 3f258a17..07b58a8a 100644 --- a/simd/cmp_test.go +++ b/simd/cmp_test.go @@ -6,133 +6,11 @@ package simd_test import ( "math/rand" - "runtime" "testing" "github.com/grailbio/base/simd" ) -/* -Initial benchmark results: - MacBook Pro (15-inch, 2016) - 2.7 GHz Intel Core i7, 16 GB 2133 MHz LPDDR3 - -Benchmark_FirstUnequalShort1-8 20 63531902 ns/op -Benchmark_FirstUnequalShort4-8 100 17527367 ns/op -Benchmark_FirstUnequalShortMax-8 100 16960390 ns/op -Benchmark_FirstUnequalLong1-8 2 730374209 ns/op -Benchmark_FirstUnequalLong4-8 3 334514352 ns/op -Benchmark_FirstUnequalLongMax-8 5 296666922 ns/op - -(bytes.Compare()'s speed is essentially identical.) 
- -Benchmark_FirstLeqShort1-8 20 66917028 ns/op -Benchmark_FirstLeqShort4-8 100 18748334 ns/op -Benchmark_FirstLeqShortMax-8 100 18819918 ns/op -Benchmark_FirstLeqLong1-8 3 402510849 ns/op -Benchmark_FirstLeqLong4-8 10 118810967 ns/op -Benchmark_FirstLeqLongMax-8 10 122304803 ns/op - -For reference, firstUnequal8Slow has the following results: -Benchmark_FirstUnequalShort1-8 5 255419211 ns/op -Benchmark_FirstUnequalShort4-8 20 72590461 ns/op -Benchmark_FirstUnequalShortMax-8 20 68392202 ns/op -Benchmark_FirstUnequalLong1-8 1 4258976363 ns/op -Benchmark_FirstUnequalLong4-8 1 1088713962 ns/op -Benchmark_FirstUnequalLongMax-8 1 1326682888 ns/op - -firstLeq8Slow: -Benchmark_FirstLeqShort1-8 5 248776883 ns/op -Benchmark_FirstLeqShort4-8 20 67078584 ns/op -Benchmark_FirstLeqShortMax-8 20 65117954 ns/op -Benchmark_FirstLeqLong1-8 1 3972184399 ns/op -Benchmark_FirstLeqLong4-8 1 1069477371 ns/op -Benchmark_FirstLeqLongMax-8 1 1238397122 ns/op -*/ - -func firstUnequalSubtask(arg1, arg2 []byte, nIter int) int { - curPos := 0 - endPos := len(arg1) - for iter := 0; iter < nIter; iter++ { - if curPos >= endPos { - curPos = 0 - } - curPos = simd.FirstUnequal8Unsafe(arg1, arg2, curPos) - curPos++ - } - return curPos -} - -func firstUnequalSubtaskFuture(arg1, arg2 []byte, nIter int) chan int { - future := make(chan int) - go func() { future <- firstUnequalSubtask(arg1, arg2, nIter) }() - return future -} - -func multiFirstUnequal(arg1s, arg2s [][]byte, cpus int, nJob int) { - sumFutures := make([]chan int, cpus) - shardSizeBase := nJob / cpus - shardRemainder := nJob - shardSizeBase*cpus - shardSizeP1 := shardSizeBase + 1 - var taskIdx int - for ; taskIdx < shardRemainder; taskIdx++ { - sumFutures[taskIdx] = firstUnequalSubtaskFuture(arg1s[0], arg2s[0], shardSizeP1) - } - for ; taskIdx < cpus; taskIdx++ { - sumFutures[taskIdx] = firstUnequalSubtaskFuture(arg1s[0], arg2s[0], shardSizeBase) - } - var sum int - for taskIdx = 0; taskIdx < cpus; taskIdx++ { - sum += 
<-sumFutures[taskIdx] - } -} - -func benchmarkFirstUnequal(cpus int, nByte int, nJob int, b *testing.B) { - if cpus > runtime.NumCPU() { - b.Skipf("only have %v cpus", runtime.NumCPU()) - } - - arg1Slices := make([][]byte, 1) - for ii := range arg1Slices { - // Add 63 to prevent false sharing. - newArr := simd.MakeUnsafe(nByte + 63) - arg1Slices[ii] = newArr[:nByte] - } - arg2Slices := make([][]byte, 1) - for ii := range arg2Slices { - newArr := simd.MakeUnsafe(nByte + 63) - arg2Slices[ii] = newArr[:nByte] - arg2Slices[ii][nByte/2] = 128 - } - for i := 0; i < b.N; i++ { - multiFirstUnequal(arg1Slices, arg2Slices, cpus, nJob) - } -} - -func Benchmark_FirstUnequalShort1(b *testing.B) { - benchmarkFirstUnequal(1, 75, 9999999, b) -} - -func Benchmark_FirstUnequalShort4(b *testing.B) { - benchmarkFirstUnequal(4, 75, 9999999, b) -} - -func Benchmark_FirstUnequalShortMax(b *testing.B) { - benchmarkFirstUnequal(runtime.NumCPU(), 75, 9999999, b) -} - -func Benchmark_FirstUnequalLong1(b *testing.B) { - benchmarkFirstUnequal(1, 249250621, 50, b) -} - -func Benchmark_FirstUnequalLong4(b *testing.B) { - benchmarkFirstUnequal(4, 249250621, 50, b) -} - -func Benchmark_FirstUnequalLongMax(b *testing.B) { - benchmarkFirstUnequal(runtime.NumCPU(), 249250621, 50, b) -} - func firstUnequal8Slow(arg1, arg2 []byte, startPos int) int { // Slow, but straightforward-to-verify implementation. 
endPos := len(arg1) @@ -189,6 +67,102 @@ func TestFirstUnequal(t *testing.T) { } } +/* +Benchmark results: + MacBook Pro (15-inch, 2016) + 2.7 GHz Intel Core i7, 16 GB 2133 MHz LPDDR3 + +Benchmark_FirstUnequal8/UnsafeShort1Cpu-8 10 104339029 ns/op +Benchmark_FirstUnequal8/UnsafeShortHalfCpu-8 50 28360826 ns/op +Benchmark_FirstUnequal8/UnsafeShortAllCpu-8 100 24272646 ns/op +Benchmark_FirstUnequal8/UnsafeLong1Cpu-8 2 654616638 ns/op +Benchmark_FirstUnequal8/UnsafeLongHalfCpu-8 3 499705618 ns/op +Benchmark_FirstUnequal8/UnsafeLongAllCpu-8 3 477807746 ns/op +Benchmark_FirstUnequal8/SIMDShort1Cpu-8 10 114335599 ns/op +Benchmark_FirstUnequal8/SIMDShortHalfCpu-8 50 30189426 ns/op +Benchmark_FirstUnequal8/SIMDShortAllCpu-8 50 26847829 ns/op +Benchmark_FirstUnequal8/SIMDLong1Cpu-8 2 735662635 ns/op +Benchmark_FirstUnequal8/SIMDLongHalfCpu-8 3 488191229 ns/op +Benchmark_FirstUnequal8/SIMDLongAllCpu-8 3 480315740 ns/op +Benchmark_FirstUnequal8/SlowShort1Cpu-8 2 608618106 ns/op +Benchmark_FirstUnequal8/SlowShortHalfCpu-8 10 166658947 ns/op +Benchmark_FirstUnequal8/SlowShortAllCpu-8 10 154372585 ns/op +Benchmark_FirstUnequal8/SlowLong1Cpu-8 1 3883830889 ns/op +Benchmark_FirstUnequal8/SlowLongHalfCpu-8 1 1080159614 ns/op +Benchmark_FirstUnequal8/SlowLongAllCpu-8 1 1046794857 ns/op + +Notes: There is practically no speed penalty relative to bytes.Compare(). 
+*/ + +func firstUnequal8UnsafeSubtask(dst, src []byte, nIter int) int { + curPos := 0 + endPos := len(dst) + for iter := 0; iter < nIter; iter++ { + if curPos >= endPos { + curPos = 0 + } + curPos = simd.FirstUnequal8Unsafe(dst, src, curPos) + curPos++ + } + return curPos +} + +func firstUnequal8SimdSubtask(dst, src []byte, nIter int) int { + curPos := 0 + endPos := len(dst) + for iter := 0; iter < nIter; iter++ { + if curPos >= endPos { + curPos = 0 + } + curPos = simd.FirstUnequal8(dst, src, curPos) + curPos++ + } + return curPos +} + +func firstUnequal8SlowSubtask(dst, src []byte, nIter int) int { + curPos := 0 + endPos := len(dst) + for iter := 0; iter < nIter; iter++ { + if curPos >= endPos { + curPos = 0 + } + curPos = firstUnequal8Slow(dst, src, curPos) + curPos++ + } + return curPos +} + +func Benchmark_FirstUnequal8(b *testing.B) { + funcs := []taggedMultiBenchFunc{ + { + f: firstUnequal8UnsafeSubtask, + tag: "Unsafe", + }, + { + f: firstUnequal8SimdSubtask, + tag: "SIMD", + }, + { + f: firstUnequal8SlowSubtask, + tag: "Slow", + }, + } + // Necessary to customize the initialization functions; the default setting + // of src = {0, 3, 6, 9, ...} and dst = {0, 0, 0, 0, ...} results in too many + // mismatches for a realistic benchmark. + opts := multiBenchmarkOpts{ + dstInit: func(src []byte) { + src[len(src)/2] = 128 + }, + srcInit: bytesInit0, + } + for _, f := range funcs { + multiBenchmark(f.f, f.tag+"Short", 150, 150, 9999999, b, opts) + multiBenchmark(f.f, f.tag+"Long", 249250621, 249250621, 50, b, opts) + } +} + func firstGreater8Slow(arg []byte, val byte, startPos int) int { // Slow, but straightforward-to-verify implementation. 
endPos := len(arg) @@ -235,86 +209,6 @@ func TestFirstGreater(t *testing.T) { } } -func firstLeqSubtask(arg []byte, nIter int) int { - curPos := 0 - endPos := len(arg) - for iter := 0; iter < nIter; iter++ { - if curPos >= endPos { - curPos = 0 - } - curPos = simd.FirstLeq8Unsafe(arg, 0, curPos) - curPos++ - } - return curPos -} - -func firstLeqSubtaskFuture(arg []byte, nIter int) chan int { - future := make(chan int) - go func() { future <- firstLeqSubtask(arg, nIter) }() - return future -} - -func multiFirstLeq(args [][]byte, cpus int, nJob int) { - sumFutures := make([]chan int, cpus) - shardSizeBase := nJob / cpus - shardRemainder := nJob - shardSizeBase*cpus - shardSizeP1 := shardSizeBase + 1 - var taskIdx int - for ; taskIdx < shardRemainder; taskIdx++ { - sumFutures[taskIdx] = firstLeqSubtaskFuture(args[0], shardSizeP1) - } - for ; taskIdx < cpus; taskIdx++ { - sumFutures[taskIdx] = firstLeqSubtaskFuture(args[0], shardSizeBase) - } - var sum int - for taskIdx = 0; taskIdx < cpus; taskIdx++ { - sum += <-sumFutures[taskIdx] - } -} - -func benchmarkFirstLeq(cpus int, nByte int, nJob int, b *testing.B) { - if cpus > runtime.NumCPU() { - b.Skipf("only have %v cpus", runtime.NumCPU()) - } - - argSlices := make([][]byte, 1) - for ii := range argSlices { - // Add 63 to prevent false sharing. - newArr := simd.MakeUnsafe(nByte + 63) - simd.Memset8(newArr, 255) - // Just change one byte in the middle. 
- newArr[nByte/2] = 0 - argSlices[ii] = newArr[:nByte] - } - for i := 0; i < b.N; i++ { - multiFirstLeq(argSlices, cpus, nJob) - } -} - -func Benchmark_FirstLeqShort1(b *testing.B) { - benchmarkFirstLeq(1, 75, 9999999, b) -} - -func Benchmark_FirstLeqShort4(b *testing.B) { - benchmarkFirstLeq(4, 75, 9999999, b) -} - -func Benchmark_FirstLeqShortMax(b *testing.B) { - benchmarkFirstLeq(runtime.NumCPU(), 75, 9999999, b) -} - -func Benchmark_FirstLeqLong1(b *testing.B) { - benchmarkFirstLeq(1, 249250621, 50, b) -} - -func Benchmark_FirstLeqLong4(b *testing.B) { - benchmarkFirstLeq(4, 249250621, 50, b) -} - -func Benchmark_FirstLeqLongMax(b *testing.B) { - benchmarkFirstLeq(runtime.NumCPU(), 249250621, 50, b) -} - func firstLeq8Slow(arg []byte, val byte, startPos int) int { // Slow, but straightforward-to-verify implementation. endPos := len(arg) @@ -360,3 +254,72 @@ func TestFirstLeq8(t *testing.T) { } } } + +/* +Benchmark results: + MacBook Pro (15-inch, 2016) + 2.7 GHz Intel Core i7, 16 GB 2133 MHz LPDDR3 + +Benchmark_FirstLeq8/SIMDShort1Cpu-8 20 87235782 ns/op +Benchmark_FirstLeq8/SIMDShortHalfCpu-8 50 23864936 ns/op +Benchmark_FirstLeq8/SIMDShortAllCpu-8 100 21211734 ns/op +Benchmark_FirstLeq8/SIMDLong1Cpu-8 3 402996726 ns/op +Benchmark_FirstLeq8/SIMDLongHalfCpu-8 5 245066128 ns/op +Benchmark_FirstLeq8/SIMDLongAllCpu-8 5 231557103 ns/op +Benchmark_FirstLeq8/SlowShort1Cpu-8 2 549800977 ns/op +Benchmark_FirstLeq8/SlowShortHalfCpu-8 10 152074140 ns/op +Benchmark_FirstLeq8/SlowShortAllCpu-8 10 142355855 ns/op +Benchmark_FirstLeq8/SlowLong1Cpu-8 1 3687059961 ns/op +Benchmark_FirstLeq8/SlowLongHalfCpu-8 1 1030280464 ns/op +Benchmark_FirstLeq8/SlowLongAllCpu-8 1 1019364554 ns/op +*/ + +func firstLeq8SimdSubtask(dst, src []byte, nIter int) int { + curPos := 0 + endPos := len(src) + for iter := 0; iter < nIter; iter++ { + if curPos >= endPos { + curPos = 0 + } + curPos = simd.FirstLeq8(src, 0, curPos) + curPos++ + } + return curPos +} + +func firstLeq8SlowSubtask(dst, src 
[]byte, nIter int) int { + curPos := 0 + endPos := len(src) + for iter := 0; iter < nIter; iter++ { + if curPos >= endPos { + curPos = 0 + } + curPos = firstLeq8Slow(src, 0, curPos) + curPos++ + } + return curPos +} + +func Benchmark_FirstLeq8(b *testing.B) { + funcs := []taggedMultiBenchFunc{ + { + f: firstLeq8SimdSubtask, + tag: "SIMD", + }, + { + f: firstLeq8SlowSubtask, + tag: "Slow", + }, + } + opts := multiBenchmarkOpts{ + srcInit: func(src []byte) { + simd.Memset8(src, 255) + // Just change one byte in the middle. + src[len(src)/2] = 128 + }, + } + for _, f := range funcs { + multiBenchmark(f.f, f.tag+"Short", 0, 150, 9999999, b, opts) + multiBenchmark(f.f, f.tag+"Long", 0, 249250621, 50, b, opts) + } +} diff --git a/simd/count_amd64.go b/simd/count_amd64.go index be128b5b..97b5201b 100644 --- a/simd/count_amd64.go +++ b/simd/count_amd64.go @@ -1,16 +1,19 @@ -// Copyright 2018 GRAIL, Inc. All rights reserved. +// Copyright 2021 GRAIL, Inc. All rights reserved. // Use of this source code is governed by the Apache-2.0 // license that can be found in the LICENSE file. +//go:build amd64 && !appengine // +build amd64,!appengine // This is derived from github.com/willf/bitset . 
package simd -import "math/bits" -import "reflect" -import "unsafe" +import ( + "math/bits" + "reflect" + "unsafe" +) // *** the following function is defined in count_amd64.s @@ -31,15 +34,15 @@ func count2BytesSSE41Asm(src unsafe.Pointer, val1, val2, nByte int) int func count3BytesSSE41Asm(src unsafe.Pointer, val1, val2, val3, nByte int) int //go:noescape -func countNibblesInSetSSE41Asm(src unsafe.Pointer, tablePtr *[16]byte, nByte int) int +func countNibblesInSetSSE41Asm(src unsafe.Pointer, tablePtr *NibbleLookupTable, nByte int) int -func countNibblesInTwoSetsSSE41Asm(cnt2Ptr *int, src unsafe.Pointer, table1Ptr, table2Ptr *[16]byte, nByte int) int +func countNibblesInTwoSetsSSE41Asm(cnt2Ptr *int, src unsafe.Pointer, table1Ptr, table2Ptr *NibbleLookupTable, nByte int) int //go:noescape -func countUnpackedNibblesInSetSSE41Asm(src unsafe.Pointer, tablePtr *[16]byte, nByte int) int +func countUnpackedNibblesInSetSSE41Asm(src unsafe.Pointer, tablePtr *NibbleLookupTable, nByte int) int //go:noescape -func countUnpackedNibblesInTwoSetsSSE41Asm(cnt2Ptr *int, src unsafe.Pointer, table1Ptr, table2Ptr *[16]byte, nByte int) int +func countUnpackedNibblesInTwoSetsSSE41Asm(cnt2Ptr *int, src unsafe.Pointer, table1Ptr, table2Ptr *NibbleLookupTable, nByte int) int //go:noescape func accumulate8SSE41Asm(src unsafe.Pointer, nByte int) int @@ -89,7 +92,7 @@ func Popcnt(bytes []byte) int { leadingWord := uint64(0) if (nLeadingByte & 1) != 0 { leadingWord = (uint64)(*(*byte)(bytearr)) - bytearr = unsafe.Pointer(uintptr(bytearr) + 1) + bytearr = unsafe.Add(bytearr, 1) } if (nLeadingByte & 2) != 0 { // Note that this does not keep the bytes in the original little-endian @@ -98,12 +101,12 @@ func Popcnt(bytes []byte) int { // plink2_base.h for code which does keep the bytes in order. 
leadingWord <<= 16 leadingWord |= (uint64)(*(*uint16)(bytearr)) - bytearr = unsafe.Pointer(uintptr(bytearr) + 2) + bytearr = unsafe.Add(bytearr, 2) } if (nLeadingByte & 4) != 0 { leadingWord <<= 32 leadingWord |= (uint64)(*(*uint32)(bytearr)) - bytearr = unsafe.Pointer(uintptr(bytearr) + 4) + bytearr = unsafe.Add(bytearr, 4) } tot = bits.OnesCount64(leadingWord) } @@ -209,12 +212,12 @@ func Count3Bytes(src []byte, val1, val2, val3 byte) int { // // WARNING: This function does not validate the table. It may return a garbage // result on invalid input. (However, it won't corrupt memory.) -func CountNibblesInSet(src []byte, tablePtr *[16]byte) int { +func CountNibblesInSet(src []byte, tablePtr *NibbleLookupTable) int { nSrcByte := len(src) if nSrcByte < 16 { cnt := 0 for _, srcByte := range src { - cnt += int(tablePtr[srcByte&15] + tablePtr[srcByte>>4]) + cnt += int(tablePtr.Get(srcByte&15) + tablePtr.Get(srcByte>>4)) } return cnt } @@ -228,7 +231,7 @@ func CountNibblesInSet(src []byte, tablePtr *[16]byte) int { // // WARNING: This function does not validate the tables. It may crash or return // garbage results on invalid input. (However, it won't corrupt memory.) 
-func CountNibblesInTwoSets(src []byte, table1Ptr, table2Ptr *[16]byte) (int, int) { +func CountNibblesInTwoSets(src []byte, table1Ptr, table2Ptr *NibbleLookupTable) (int, int) { nSrcByte := len(src) cnt2 := 0 if nSrcByte < 16 { @@ -236,8 +239,8 @@ func CountNibblesInTwoSets(src []byte, table1Ptr, table2Ptr *[16]byte) (int, int for _, srcByte := range src { lowBits := srcByte & 15 highBits := srcByte >> 4 - cnt1 += int(table1Ptr[lowBits] + table1Ptr[highBits]) - cnt2 += int(table2Ptr[lowBits] + table2Ptr[highBits]) + cnt1 += int(table1Ptr.Get(lowBits) + table1Ptr.Get(highBits)) + cnt2 += int(table2Ptr.Get(lowBits) + table2Ptr.Get(highBits)) } return cnt1, cnt2 } @@ -252,12 +255,12 @@ func CountNibblesInTwoSets(src []byte, table1Ptr, table2Ptr *[16]byte) (int, int // // WARNING: This function does not validate the table. It may crash or return // a garbage result on invalid input. (However, it won't corrupt memory.) -func CountUnpackedNibblesInSet(src []byte, tablePtr *[16]byte) int { +func CountUnpackedNibblesInSet(src []byte, tablePtr *NibbleLookupTable) int { nSrcByte := len(src) if nSrcByte < 16 { cnt := 0 for _, srcByte := range src { - cnt += int(tablePtr[srcByte]) + cnt += int(tablePtr.Get(srcByte)) } return cnt } @@ -272,7 +275,7 @@ func CountUnpackedNibblesInSet(src []byte, tablePtr *[16]byte) int { // // WARNING: This function does not validate the tables. It may crash or return // garbage results on invalid input. (However, it won't corrupt memory.) -func CountUnpackedNibblesInTwoSets(src []byte, table1Ptr, table2Ptr *[16]byte) (int, int) { +func CountUnpackedNibblesInTwoSets(src []byte, table1Ptr, table2Ptr *NibbleLookupTable) (int, int) { // Building this out now so that biosimd.PackedSeqCountTwo is not a valid // reason to stick to packed .bam seq[] representation. 
nSrcByte := len(src) @@ -280,8 +283,8 @@ func CountUnpackedNibblesInTwoSets(src []byte, table1Ptr, table2Ptr *[16]byte) ( if nSrcByte < 16 { cnt1 := 0 for _, srcByte := range src { - cnt1 += int(table1Ptr[srcByte]) - cnt2 += int(table2Ptr[srcByte]) + cnt1 += int(table1Ptr.Get(srcByte)) + cnt2 += int(table2Ptr.Get(srcByte)) } return cnt1, cnt2 } diff --git a/simd/count_generic.go b/simd/count_generic.go new file mode 100644 index 00000000..7f09e119 --- /dev/null +++ b/simd/count_generic.go @@ -0,0 +1,213 @@ +// Copyright 2018 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache-2.0 +// license that can be found in the LICENSE file. + +// +build !amd64 appengine + +// This is derived from github.com/willf/bitset . + +package simd + +import "math/bits" + +// PopcntUnsafe returns the number of set bits in the given []byte, assuming +// that any trailing bytes up to the next multiple of BytesPerWord are zeroed +// out. +func PopcntUnsafe(bytes []byte) int { + // Get the base-pointer for the slice, in a way that doesn't trigger a + // bounds-check and fail when length == 0. (Yes, I found out during testing + // that the &bytes[0] idiom doesn't actually work in the length-0 + // case...) + cnt := 0 + for _, b := range bytes { + cnt += bits.OnesCount8(uint8(b)) + } + return cnt +} + +// Popcnt returns the number of set bits in the given []byte. +// +// Some effort has been made to make this run acceptably fast on relatively +// short arrays, since I expect knowing how to do so to be helpful when working +// with hundreds of millions of .bam reads with ~75 bytes of base data and ~150 +// bytes of quality data. Interestingly, moving the leading-byte handling code +// to assembly didn't make a difference. 
+// +// Some single-threaded benchmark results calling Popcnt 99999999 times on a +// 14-byte unaligned array: +// C implementation: 0.219-0.232s +// This code: 0.606-0.620s +// C implementation using memcpy for trailing bytes: 0.964-0.983s +// So Go's extra looping and function call overhead can almost triple runtime +// in the short-array limit, but that's actually not as bad as the 4.5x +// overhead of trusting memcpy to handle trailing bytes. +func Popcnt(bytes []byte) int { + cnt := 0 + for _, b := range bytes { + cnt += bits.OnesCount8(uint8(b)) + } + return cnt +} + +// We may want a PopcntW function in the future which operates on a []uintptr, +// along with AndW, OrW, XorW, InvmaskW, etc. This would amount to a +// lower-overhead version of willf/bitset (which also uses []uintptr +// internally). +// The main thing I would want to benchmark before making that decision is +// bitset.NextSetMany() vs. a loop of the form +// uidx_base := 0 +// cur_bits := bitarr[0] +// for idx := 0; idx != nSetBit; idx++ { +// // see plink2_base.h BitIter1() +// if cur_bits == 0 { +// widx := uidx_base >> (3 + Log2BytesPerWord) +// for { +// widx++ +// cur_bits = bitarr[widx] +// if cur_bits != 0 { +// break +// } +// } +// uidx_base = widx << (3 + Log2BytesPerWord) +// } +// uidx := uidx_base + bits.TrailingZeros(uint(cur_bits)) +// cur_bits = cur_bits & (cur_bits - 1) +// // (do something with uidx, possibly very simple) +// } +// (Note that there are *hundreds* of loops of this form in plink2.) +// If bitset.NextSetMany() does not impose a large performance penalty, we may +// just want to write a version of it which takes a []byte as input. +// (update: https://go-review.googlesource.com/c/go/+/109716 suggests that +// bitset.NextSetMany() is not good enough.) + +// todo: add ZeroTrailingBits, etc. once we need it + +// MaskThenCountByte counts the number of bytes in src[] satisfying +// src[pos] & mask == val. 
+func MaskThenCountByte(src []byte, mask, val byte) int { + // This is especially useful for CG counting: + // - Count 'C'/'G' ASCII characters: mask = 0xfb (only capital) or 0xdb + // (either capital or lowercase), val = 'C' + // - Count C/G bytes in .bam unpacked seq8 data, assuming '=' is not in + // input: mask = 0x9, val = 0 + // It can also be used to ignore capitalization when counting instances of a + // single letter. + cnt := 0 + for _, srcByte := range src { + if (srcByte & mask) == val { + cnt++ + } + } + return cnt +} + +// Count2Bytes counts the number of bytes in src[] which are equal to either +// val1 or val2. +// (bytes.Count() should be good enough for a single byte.) +func Count2Bytes(src []byte, val1, val2 byte) int { + cnt := 0 + for _, srcByte := range src { + if (srcByte == val1) || (srcByte == val2) { + cnt++ + } + } + return cnt +} + +// Count3Bytes counts the number of bytes in src[] which are equal to val1, +// val2, or val3. +func Count3Bytes(src []byte, val1, val2, val3 byte) int { + cnt := 0 + for _, srcByte := range src { + if (srcByte == val1) || (srcByte == val2) || (srcByte == val3) { + cnt++ + } + } + return cnt +} + +// CountNibblesInSet counts the number of nibbles in src[] which are in the +// given set. The set must be represented as table[x] == 1 when value x is in +// the set, and table[x] == 0 when x isn't. +// +// WARNING: This function does not validate the table. It may return a garbage +// result on invalid input. (However, it won't corrupt memory.) +func CountNibblesInSet(src []byte, tablePtr *NibbleLookupTable) int { + cnt := 0 + for _, srcByte := range src { + cnt += int(tablePtr.Get(srcByte&15) + tablePtr.Get(srcByte>>4)) + } + return cnt +} + +// CountNibblesInTwoSets counts the number of bytes in src[] which are in the +// given two sets, assuming all bytes are <16. The sets must be represented as +// table[x] == 1 when value x is in the set, and table[x] == 0 when x isn't. 
+// +// WARNING: This function does not validate the tables. It may crash or return +// garbage results on invalid input. (However, it won't corrupt memory.) +func CountNibblesInTwoSets(src []byte, table1Ptr, table2Ptr *NibbleLookupTable) (int, int) { + cnt1 := 0 + cnt2 := 0 + for _, srcByte := range src { + lowBits := srcByte & 15 + highBits := srcByte >> 4 + cnt1 += int(table1Ptr.Get(lowBits) + table1Ptr.Get(highBits)) + cnt2 += int(table2Ptr.Get(lowBits) + table2Ptr.Get(highBits)) + } + return cnt1, cnt2 +} + +// CountUnpackedNibblesInSet counts the number of bytes in src[] which are in +// the given set, assuming all bytes are <16. The set must be represented as +// table[x] == 1 when value x is in the set, and table[x] == 0 when x isn't. +// +// WARNING: This function does not validate the table. It may crash or return +// a garbage result on invalid input. (However, it won't corrupt memory.) +func CountUnpackedNibblesInSet(src []byte, tablePtr *NibbleLookupTable) int { + cnt := 0 + for _, srcByte := range src { + cnt += int(tablePtr.Get(srcByte)) + } + return cnt +} + +// CountUnpackedNibblesInTwoSets counts the number of bytes in src[] which are +// in the given two sets, assuming all bytes are <16. The sets must be +// represented as table[x] == 1 when value x is in the set, and table[x] == 0 +// when x isn't. +// +// WARNING: This function does not validate the tables. It may crash or return +// garbage results on invalid input. (However, it won't corrupt memory.) +func CountUnpackedNibblesInTwoSets(src []byte, table1Ptr, table2Ptr *NibbleLookupTable) (int, int) { + cnt1 := 0 + cnt2 := 0 + for _, srcByte := range src { + cnt1 += int(table1Ptr.Get(srcByte)) + cnt2 += int(table2Ptr.Get(srcByte)) + } + return cnt1, cnt2 +} + +// (could rename Popcnt to Accumulate1 for consistency...) + +// Accumulate8 returns the sum of the (unsigned) bytes in src[]. 
+func Accumulate8(src []byte) int { + cnt := 0 + for _, srcByte := range src { + cnt += int(srcByte) + } + return cnt +} + +// Accumulate8Greater returns the sum of all bytes in src[] greater than the +// given value. +func Accumulate8Greater(src []byte, val byte) int { + cnt := 0 + for _, srcByte := range src { + if srcByte > val { + cnt += int(srcByte) + } + } + return cnt +} diff --git a/simd/count_test.go b/simd/count_test.go index 7c585fe8..fdc26ef2 100644 --- a/simd/count_test.go +++ b/simd/count_test.go @@ -1,4 +1,4 @@ -// Copyright 2018 GRAIL, Inc. All rights reserved. +// Copyright 2021 GRAIL, Inc. All rights reserved. // Use of this source code is governed by the Apache-2.0 // license that can be found in the LICENSE file. @@ -9,108 +9,12 @@ import ( "math/bits" "math/rand" "reflect" - "runtime" "testing" "unsafe" "github.com/grailbio/base/simd" ) -/* -Initial benchmark results: - MacBook Pro (15-inch, 2016) - 2.7 GHz Intel Core i7, 16 GB 2133 MHz LPDDR3 - -Benchmark_ByteShortPopcnt1-8 2 880567116 ns/op -Benchmark_ByteShortPopcnt4-8 5 249242503 ns/op -Benchmark_ByteShortPopcntMax-8 10 197222506 ns/op -Benchmark_ByteShortOldPopcnt1-8 2 992499107 ns/op -Benchmark_ByteShortOldPopcnt4-8 5 273206944 ns/op -Benchmark_ByteShortOldPopcntMax-8 5 270066225 ns/op -Benchmark_ByteLongPopcnt1-8 1 1985211630 ns/op -Benchmark_ByteLongPopcnt4-8 2 511618631 ns/op -Benchmark_ByteLongPopcntMax-8 3 506767183 ns/op -Benchmark_ByteLongOldPopcnt1-8 1 2473936090 ns/op -Benchmark_ByteLongOldPopcnt4-8 2 747990164 ns/op -Benchmark_ByteLongOldPopcntMax-8 2 743366724 ns/op - -Only differences between new and old code so far are - (i) 2x loop-unroll, and - (ii) per-call "do we have SSE4.2" flag checks. -I did try implementing the 2x loop-unroll in the math/bits code, but it did not -have much of an effect there. 
- - -Benchmark_CountCGShort1-8 20 86156256 ns/op -Benchmark_CountCGShort4-8 100 23032100 ns/op -Benchmark_CountCGShortMax-8 100 22272840 ns/op -Benchmark_CountCGLong1-8 1 1025820961 ns/op -Benchmark_CountCGLong4-8 1 1461443453 ns/op -Benchmark_CountCGLongMax-8 1 2180675096 ns/op - -Benchmark_Count3BytesShort1-8 10 105075618 ns/op -Benchmark_Count3BytesShort4-8 50 29571207 ns/op -Benchmark_Count3BytesShortMax-8 50 27188249 ns/op -Benchmark_Count3BytesLong1-8 1 1308949833 ns/op -Benchmark_Count3BytesLong4-8 1 1612130605 ns/op -Benchmark_Count3BytesLongMax-8 1 2319478110 ns/op - -Benchmark_Accumulate8Short1-8 20 67181978 ns/op -Benchmark_Accumulate8Short4-8 100 18216505 ns/op -Benchmark_Accumulate8ShortMax-8 100 17359664 ns/op -Benchmark_Accumulate8Long1-8 1 1050107761 ns/op -Benchmark_Accumulate8Long4-8 1 1440863620 ns/op -Benchmark_Accumulate8LongMax-8 1 2202725361 ns/op - -Benchmark_Accumulate8GreaterShort1-8 20 91913187 ns/op -Benchmark_Accumulate8GreaterShort4-8 50 25629176 ns/op -Benchmark_Accumulate8GreaterShortMax-8 100 22020836 ns/op -Benchmark_Accumulate8GreaterLong1-8 1 1166256065 ns/op -Benchmark_Accumulate8GreaterLong4-8 1 1529133163 ns/op -Benchmark_Accumulate8GreaterLongMax-8 1 2447755677 ns/op - -For comparison, countCGStandard: -Benchmark_CountCGShort1-8 5 206159939 ns/op -Benchmark_CountCGShort4-8 20 55653414 ns/op -Benchmark_CountCGShortMax-8 30 49566408 ns/op -Benchmark_CountCGLong1-8 1 1786864086 ns/op -Benchmark_CountCGLong4-8 1 1975270955 ns/op -Benchmark_CountCGLongMax-8 1 2846417721 ns/op - -countCGNaive: -Benchmark_CountCGShort1-8 2 753564012 ns/op -Benchmark_CountCGShort4-8 5 200074546 ns/op -Benchmark_CountCGShortMax-8 10 193392413 ns/op -Benchmark_CountCGLong1-8 1 12838546141 ns/op -Benchmark_CountCGLong4-8 1 4371080727 ns/op -Benchmark_CountCGLongMax-8 1 5023199989 ns/op -(lesson: don't forget to use bytes.Count() when it's applicable!) 
- -count3BytesStandard: -Benchmark_Count3BytesShort1-8 5 288822460 ns/op -Benchmark_Count3BytesShort4-8 20 81116028 ns/op -Benchmark_Count3BytesShortMax-8 20 75587001 ns/op -Benchmark_Count3BytesLong1-8 1 2526123231 ns/op -Benchmark_Count3BytesLong4-8 1 2425857828 ns/op -Benchmark_Count3BytesLongMax-8 1 3235725694 ns/op - -accumulate8Slow: -Benchmark_Accumulate8Short1-8 3 394838027 ns/op -Benchmark_Accumulate8Short4-8 10 105763035 ns/op -Benchmark_Accumulate8ShortMax-8 20 93473300 ns/op -Benchmark_Accumulate8Long1-8 1 5143881564 ns/op -Benchmark_Accumulate8Long4-8 1 3501219437 ns/op -Benchmark_Accumulate8LongMax-8 1 3096559063 ns/op - -accumulate8GreaterSlow: -Benchmark_Accumulate8GreaterShort1-8 3 466978266 ns/op -Benchmark_Accumulate8GreaterShort4-8 10 125637387 ns/op -Benchmark_Accumulate8GreaterShortMax-8 10 117808985 ns/op -Benchmark_Accumulate8GreaterLong1-8 1 9825147670 ns/op -Benchmark_Accumulate8GreaterLong4-8 1 5815093074 ns/op -Benchmark_Accumulate8GreaterLongMax-8 1 4554119137 ns/op -*/ - func init() { if unsafe.Sizeof(uintptr(0)) != 8 { // popcnt_amd64.go shouldn't compile at all in this case, but just in @@ -124,24 +28,24 @@ func popcntBytesNoasm(byteslice []byte) int { ct := uintptr(len(byteslice)) bytearr := unsafe.Pointer(bytesliceHeader.Data) - endptr := unsafe.Pointer(uintptr(bytearr) + ct) + endptr := unsafe.Add(bytearr, ct) tot := 0 nLeadingByte := ct % 8 if nLeadingByte != 0 { leadingWord := uint64(0) if (nLeadingByte & 1) != 0 { leadingWord = (uint64)(*(*byte)(bytearr)) - bytearr = unsafe.Pointer(uintptr(bytearr) + 1) + bytearr = unsafe.Add(bytearr, 1) } if (nLeadingByte & 2) != 0 { leadingWord <<= 16 leadingWord |= (uint64)(*(*uint16)(bytearr)) - bytearr = unsafe.Pointer(uintptr(bytearr) + 2) + bytearr = unsafe.Add(bytearr, 2) } if (nLeadingByte & 4) != 0 { leadingWord <<= 32 leadingWord |= (uint64)(*(*uint32)(bytearr)) - bytearr = unsafe.Pointer(uintptr(bytearr) + 4) + bytearr = unsafe.Add(bytearr, 4) } tot = bits.OnesCount64(leadingWord) } 
@@ -149,139 +53,11 @@ func popcntBytesNoasm(byteslice []byte) int { // depending on which of several equivalent ways I use to write it. for bytearr != endptr { tot += bits.OnesCount64((uint64)(*((*uint64)(bytearr)))) - bytearr = unsafe.Pointer(uintptr(bytearr) + 8) + bytearr = unsafe.Add(bytearr, 8) } return tot } -func popcntSubtaskOld(byteSlice []byte, nIter int) int { - sum := 0 - for iter := 0; iter < nIter; iter++ { - sum += popcntBytesNoasm(byteSlice) - } - return sum -} - -func popcntSubtaskOldFuture(byteSlice []byte, nIter int) chan int { - future := make(chan int) - go func() { future <- popcntSubtaskOld(byteSlice, nIter) }() - return future -} - -func popcntSubtask(byteSlice []byte, nIter int) int { - sum := 0 - for iter := 0; iter < nIter; iter++ { - sum += simd.Popcnt(byteSlice) - } - return sum -} - -func popcntSubtaskFuture(byteSlice []byte, nIter int) chan int { - future := make(chan int) - go func() { future <- popcntSubtask(byteSlice, nIter) }() - return future -} - -func multiBytePopcnt(byteSlice []byte, cpus int, nJob int, useOld bool) { - sumFutures := make([]chan int, cpus) - shardSizeBase := nJob / cpus - shardRemainder := nJob - shardSizeBase*cpus - shardSizeP1 := shardSizeBase + 1 - // Note that this straightforward sharding scheme sometimes doesn't work well - // with hyperthreading: on my Mac, I get more consistent performance dividing - // cpus by two to set it to the actual number of cores. However, this - // doesn't happen on my adhoc instance. - // In any case, I'll experiment with other concurrency patterns soon. 
- var taskIdx int - if useOld { - for ; taskIdx < shardRemainder; taskIdx++ { - sumFutures[taskIdx] = popcntSubtaskOldFuture(byteSlice, shardSizeP1) - } - for ; taskIdx < cpus; taskIdx++ { - sumFutures[taskIdx] = popcntSubtaskOldFuture(byteSlice, shardSizeBase) - } - } else { - for ; taskIdx < shardRemainder; taskIdx++ { - sumFutures[taskIdx] = popcntSubtaskFuture(byteSlice, shardSizeP1) - } - for ; taskIdx < cpus; taskIdx++ { - sumFutures[taskIdx] = popcntSubtaskFuture(byteSlice, shardSizeBase) - } - } - var sum int - for taskIdx = 0; taskIdx < cpus; taskIdx++ { - sum += <-sumFutures[taskIdx] - } - // fmt.Println(sum) -} - -func benchmarkBytePopcnt(cpus int, nByte int, nJob int, useOld bool, b *testing.B) { - if cpus > runtime.NumCPU() { - b.Skipf("only have %v cpus", runtime.NumCPU()) - } - - byteArr := make([]byte, nByte+1) - for uii := uint(0); uii < uint(nByte); uii++ { - byteArr[uii] = (byte)(uii) - } - byteSlice := byteArr[1 : nByte+1] // force unaligned - for i := 0; i < b.N; i++ { - multiBytePopcnt(byteSlice, cpus, nJob, useOld) - } -} - -// Base sequence in length-150 .bam read occupies 75 bytes, so 75 is a good -// size for the short-array benchmark. -func Benchmark_ByteShortPopcnt1(b *testing.B) { - benchmarkBytePopcnt(1, 75, 99999999, false, b) -} - -func Benchmark_ByteShortPopcnt4(b *testing.B) { - benchmarkBytePopcnt(4, 75, 99999999, false, b) -} - -func Benchmark_ByteShortPopcntMax(b *testing.B) { - benchmarkBytePopcnt(runtime.NumCPU(), 75, 99999999, false, b) -} - -func Benchmark_ByteShortOldPopcnt1(b *testing.B) { - benchmarkBytePopcnt(1, 75, 99999999, true, b) -} - -func Benchmark_ByteShortOldPopcnt4(b *testing.B) { - benchmarkBytePopcnt(4, 75, 99999999, true, b) -} - -func Benchmark_ByteShortOldPopcntMax(b *testing.B) { - benchmarkBytePopcnt(runtime.NumCPU(), 75, 99999999, true, b) -} - -// GRCh37 chromosome 1 length is 249250621, so that's a plausible long-array -// use case. 
-func Benchmark_ByteLongPopcnt1(b *testing.B) { - benchmarkBytePopcnt(1, 249250621, 100, false, b) -} - -func Benchmark_ByteLongPopcnt4(b *testing.B) { - benchmarkBytePopcnt(4, 249250621, 100, false, b) -} - -func Benchmark_ByteLongPopcntMax(b *testing.B) { - benchmarkBytePopcnt(runtime.NumCPU(), 249250621, 100, false, b) -} - -func Benchmark_ByteLongOldPopcnt1(b *testing.B) { - benchmarkBytePopcnt(1, 249250621, 100, true, b) -} - -func Benchmark_ByteLongOldPopcnt4(b *testing.B) { - benchmarkBytePopcnt(4, 249250621, 100, true, b) -} - -func Benchmark_ByteLongOldPopcntMax(b *testing.B) { - benchmarkBytePopcnt(runtime.NumCPU(), 249250621, 100, true, b) -} - func popcntBytesSlow(bytes []byte) int { // Slow (factor of 5-8x), but straightforward-to-verify implementation. tot := 0 @@ -323,85 +99,81 @@ func TestBytePopcnt(t *testing.T) { } } -func countCGSubtask(src []byte, nIter int) int { - tot := 0 +/* +Benchmark results: + MacBook Pro (15-inch, 2016) + 2.7 GHz Intel Core i7, 16 GB 2133 MHz LPDDR3 + +Benchmark_Popcnt/SIMDShort1Cpu-8 20 90993141 ns/op +Benchmark_Popcnt/SIMDShortHalfCpu-8 50 24639468 ns/op +Benchmark_Popcnt/SIMDShortAllCpu-8 100 23098747 ns/op +Benchmark_Popcnt/SIMDLong1Cpu-8 2 909927976 ns/op +Benchmark_Popcnt/SIMDLongHalfCpu-8 3 488961048 ns/op +Benchmark_Popcnt/SIMDLongAllCpu-8 3 466249901 ns/op +Benchmark_Popcnt/NoasmShort1Cpu-8 10 106873386 ns/op +Benchmark_Popcnt/NoasmShortHalfCpu-8 50 29290668 ns/op +Benchmark_Popcnt/NoasmShortAllCpu-8 50 29559455 ns/op +Benchmark_Popcnt/NoasmLong1Cpu-8 1 1217844097 ns/op +Benchmark_Popcnt/NoasmLongHalfCpu-8 2 507946501 ns/op +Benchmark_Popcnt/NoasmLongAllCpu-8 3 483458386 ns/op +Benchmark_Popcnt/SlowShort1Cpu-8 2 519449562 ns/op +Benchmark_Popcnt/SlowShortHalfCpu-8 10 139108095 ns/op +Benchmark_Popcnt/SlowShortAllCpu-8 10 143346876 ns/op +Benchmark_Popcnt/SlowLong1Cpu-8 1 7515831696 ns/op +Benchmark_Popcnt/SlowLongHalfCpu-8 1 2083880380 ns/op +Benchmark_Popcnt/SlowLongAllCpu-8 1 2064129411 ns/op + +Notes: The 
current SSE4.2 SIMD implementation just amounts to a 2x-unrolled +OnesCount64 loop without flag-rechecking overhead; they're using the same +underlying instruction. AVX2/AVX-512 allow for faster bulk processing, though; +see e.g. https://github.com/kimwalisch/libpopcnt . +*/ + +func popcntSimdSubtask(dst, src []byte, nIter int) int { + sum := 0 for iter := 0; iter < nIter; iter++ { - tot += simd.MaskThenCountByte(src, 0xfb, 'C') + sum += simd.Popcnt(src) } - return tot -} - -func countCGSubtaskFuture(src []byte, nIter int) chan int { - future := make(chan int) - go func() { future <- countCGSubtask(src, nIter) }() - return future + return sum } -func multiCountCG(srcs [][]byte, cpus int, nJob int) { - sumFutures := make([]chan int, cpus) - shardSizeBase := nJob / cpus - shardRemainder := nJob - shardSizeBase*cpus - shardSizeP1 := shardSizeBase + 1 - var taskIdx int - for ; taskIdx < shardRemainder; taskIdx++ { - sumFutures[taskIdx] = countCGSubtaskFuture(srcs[taskIdx], shardSizeP1) - } - for ; taskIdx < cpus; taskIdx++ { - sumFutures[taskIdx] = countCGSubtaskFuture(srcs[taskIdx], shardSizeBase) - } - var sum int - for taskIdx = 0; taskIdx < cpus; taskIdx++ { - sum += <-sumFutures[taskIdx] +func popcntNoasmSubtask(dst, src []byte, nIter int) int { + sum := 0 + for iter := 0; iter < nIter; iter++ { + sum += popcntBytesNoasm(src) } + return sum } -func benchmarkCountCG(cpus int, nByte int, nJob int, b *testing.B) { - if cpus > runtime.NumCPU() { - b.Skipf("only have %v cpus", runtime.NumCPU()) +func popcntSlowSubtask(dst, src []byte, nIter int) int { + sum := 0 + for iter := 0; iter < nIter; iter++ { + sum += popcntBytesSlow(src) } + return sum +} - mainSlices := make([][]byte, cpus) - for ii := range mainSlices { - // Add 63 to prevent false sharing. 
- newArr := simd.MakeUnsafe(nByte + 63) - for jj := 0; jj < nByte; jj++ { - newArr[jj] = byte(jj * 3) - } - mainSlices[ii] = newArr[1 : nByte+1] +func Benchmark_Popcnt(b *testing.B) { + funcs := []taggedMultiBenchFunc{ + { + f: popcntSimdSubtask, + tag: "SIMD", + }, + { + f: popcntNoasmSubtask, + tag: "Noasm", + }, + { + f: popcntSlowSubtask, + tag: "Slow", + }, } - for i := 0; i < b.N; i++ { - multiCountCG(mainSlices, cpus, nJob) + for _, f := range funcs { + multiBenchmark(f.f, f.tag+"Short", 0, 75, 9999999, b) + multiBenchmark(f.f, f.tag+"Long", 0, 249250621, 50, b) } } -// Base sequence in length-150 .bam read occupies 75 bytes, so 75 is a good -// size for the short-array benchmark. -func Benchmark_CountCGShort1(b *testing.B) { - benchmarkCountCG(1, 75, 9999999, b) -} - -func Benchmark_CountCGShort4(b *testing.B) { - benchmarkCountCG(4, 75, 9999999, b) -} - -func Benchmark_CountCGShortMax(b *testing.B) { - benchmarkCountCG(runtime.NumCPU(), 75, 9999999, b) -} - -// GRCh37 chromosome 1 length is 249250621, so that's a plausible long-array -// use case. 
-func Benchmark_CountCGLong1(b *testing.B) { - benchmarkCountCG(1, 249250621, 50, b) -} - -func Benchmark_CountCGLong4(b *testing.B) { - benchmarkCountCG(4, 249250621, 50, b) -} - -func Benchmark_CountCGLongMax(b *testing.B) { - benchmarkCountCG(runtime.NumCPU(), 249250621, 50, b) -} - var cgArr = [...]byte{'C', 'G'} func countCGStandard(src []byte) int { @@ -442,6 +214,58 @@ func TestCountCG(t *testing.T) { } } +/* +Benchmark results: + MacBook Pro (15-inch, 2016) + 2.7 GHz Intel Core i7, 16 GB 2133 MHz LPDDR3 + +Benchmark_CountCG/SIMDShort1Cpu-8 10 119280079 ns/op +Benchmark_CountCG/SIMDShortHalfCpu-8 50 34743805 ns/op +Benchmark_CountCG/SIMDShortAllCpu-8 50 28507338 ns/op +Benchmark_CountCG/SIMDLong1Cpu-8 2 765099599 ns/op +Benchmark_CountCG/SIMDLongHalfCpu-8 3 491655239 ns/op +Benchmark_CountCG/SIMDLongAllCpu-8 3 452592924 ns/op +Benchmark_CountCG/StandardShort1Cpu-8 5 237081120 ns/op +Benchmark_CountCG/StandardShortHalfCpu-8 20 64949969 ns/op +Benchmark_CountCG/StandardShortAllCpu-8 20 59167932 ns/op +Benchmark_CountCG/StandardLong1Cpu-8 1 1496389230 ns/op +Benchmark_CountCG/StandardLongHalfCpu-8 2 931898463 ns/op +Benchmark_CountCG/StandardLongAllCpu-8 2 980615182 ns/op +*/ + +func countCGSimdSubtask(dst, src []byte, nIter int) int { + tot := 0 + for iter := 0; iter < nIter; iter++ { + tot += simd.MaskThenCountByte(src, 0xfb, 'C') + } + return tot +} + +func countCGStandardSubtask(dst, src []byte, nIter int) int { + tot := 0 + for iter := 0; iter < nIter; iter++ { + tot += countCGStandard(src) + } + return tot +} + +func Benchmark_CountCG(b *testing.B) { + funcs := []taggedMultiBenchFunc{ + { + f: countCGSimdSubtask, + tag: "SIMD", + }, + { + f: countCGStandardSubtask, + tag: "Standard", + }, + } + for _, f := range funcs { + multiBenchmark(f.f, f.tag+"Short", 0, 150, 9999999, b) + multiBenchmark(f.f, f.tag+"Long", 0, 249250621, 50, b) + } +} + func count2BytesStandard(src, vals []byte) int { // Not 'Slow' since bytes.Count is decently optimized for a single 
byte. return bytes.Count(src, vals[:1]) + bytes.Count(src, vals[1:2]) @@ -472,88 +296,6 @@ func TestCount2Bytes(t *testing.T) { } } -func count3BytesSubtask(src []byte, nIter int) int { - tot := 0 - // vals := [...]byte{'A', 'T', 'N'} - // valsSlice := vals[:] - for iter := 0; iter < nIter; iter++ { - tot += simd.Count3Bytes(src, 'A', 'T', 'N') - // tot += count3BytesStandard(src, valsSlice) - } - return tot -} - -func count3BytesSubtaskFuture(src []byte, nIter int) chan int { - future := make(chan int) - go func() { future <- count3BytesSubtask(src, nIter) }() - return future -} - -func multiCount3Bytes(srcs [][]byte, cpus int, nJob int) { - sumFutures := make([]chan int, cpus) - shardSizeBase := nJob / cpus - shardRemainder := nJob - shardSizeBase*cpus - shardSizeP1 := shardSizeBase + 1 - var taskIdx int - for ; taskIdx < shardRemainder; taskIdx++ { - sumFutures[taskIdx] = count3BytesSubtaskFuture(srcs[taskIdx], shardSizeP1) - } - for ; taskIdx < cpus; taskIdx++ { - sumFutures[taskIdx] = count3BytesSubtaskFuture(srcs[taskIdx], shardSizeBase) - } - var sum int - for taskIdx = 0; taskIdx < cpus; taskIdx++ { - sum += <-sumFutures[taskIdx] - } -} - -func benchmarkCount3Bytes(cpus int, nByte int, nJob int, b *testing.B) { - if cpus > runtime.NumCPU() { - b.Skipf("only have %v cpus", runtime.NumCPU()) - } - - mainSlices := make([][]byte, cpus) - for ii := range mainSlices { - // Add 63 to prevent false sharing. - newArr := simd.MakeUnsafe(nByte + 63) - for jj := 0; jj < nByte; jj++ { - newArr[jj] = byte(jj * 3) - } - mainSlices[ii] = newArr[:nByte] - } - for i := 0; i < b.N; i++ { - multiCount3Bytes(mainSlices, cpus, nJob) - } -} - -// Base sequence in length-150 .bam read occupies 75 bytes, so 75 is a good -// size for the short-array benchmark. 
-func Benchmark_Count3BytesShort1(b *testing.B) { - benchmarkCount3Bytes(1, 75, 9999999, b) -} - -func Benchmark_Count3BytesShort4(b *testing.B) { - benchmarkCount3Bytes(4, 75, 9999999, b) -} - -func Benchmark_Count3BytesShortMax(b *testing.B) { - benchmarkCount3Bytes(runtime.NumCPU(), 75, 9999999, b) -} - -// GRCh37 chromosome 1 length is 249250621, so that's a plausible long-array -// use case. -func Benchmark_Count3BytesLong1(b *testing.B) { - benchmarkCount3Bytes(1, 249250621, 50, b) -} - -func Benchmark_Count3BytesLong4(b *testing.B) { - benchmarkCount3Bytes(4, 249250621, 50, b) -} - -func Benchmark_Count3BytesLongMax(b *testing.B) { - benchmarkCount3Bytes(runtime.NumCPU(), 249250621, 50, b) -} - func count3BytesStandard(src, vals []byte) int { return bytes.Count(src, vals[:1]) + bytes.Count(src, vals[1:2]) + bytes.Count(src, vals[2:3]) } @@ -584,10 +326,63 @@ func TestCount3Bytes(t *testing.T) { } } -func countNibblesInSetSlow(src []byte, tablePtr *[16]byte) int { +/* +Benchmark results: + MacBook Pro (15-inch, 2016) + 2.7 GHz Intel Core i7, 16 GB 2133 MHz LPDDR3 + +Benchmark_Count3Bytes/SIMDShort1Cpu-8 10 141085860 ns/op +Benchmark_Count3Bytes/SIMDShortHalfCpu-8 30 40371892 ns/op +Benchmark_Count3Bytes/SIMDShortAllCpu-8 30 37769995 ns/op +Benchmark_Count3Bytes/SIMDLong1Cpu-8 2 945534510 ns/op +Benchmark_Count3Bytes/SIMDLongHalfCpu-8 3 499146889 ns/op +Benchmark_Count3Bytes/SIMDLongAllCpu-8 3 475811932 ns/op +Benchmark_Count3Bytes/StandardShort1Cpu-8 3 346637595 ns/op +Benchmark_Count3Bytes/StandardShortHalfCpu-8 20 96524251 ns/op +Benchmark_Count3Bytes/StandardShortAllCpu-8 20 87056185 ns/op +Benchmark_Count3Bytes/StandardLong1Cpu-8 1 2260954596 ns/op +Benchmark_Count3Bytes/StandardLongHalfCpu-8 1 1518757560 ns/op +Benchmark_Count3Bytes/StandardLongAllCpu-8 1 1468352229 ns/op +*/ + +func count3BytesSimdSubtask(dst, src []byte, nIter int) int { + tot := 0 + for iter := 0; iter < nIter; iter++ { + tot += simd.Count3Bytes(src, 'A', 'T', 'N') + } + return tot +} 
+ +func count3BytesStandardSubtask(dst, src []byte, nIter int) int { + tot := 0 + vals := []byte{'A', 'T', 'N'} + for iter := 0; iter < nIter; iter++ { + tot += count3BytesStandard(src, vals) + } + return tot +} + +func Benchmark_Count3Bytes(b *testing.B) { + funcs := []taggedMultiBenchFunc{ + { + f: count3BytesSimdSubtask, + tag: "SIMD", + }, + { + f: count3BytesStandardSubtask, + tag: "Standard", + }, + } + for _, f := range funcs { + multiBenchmark(f.f, f.tag+"Short", 0, 150, 9999999, b) + multiBenchmark(f.f, f.tag+"Long", 0, 249250621, 50, b) + } +} + +func countNibblesInSetSlow(src []byte, tablePtr *simd.NibbleLookupTable) int { cnt := 0 for _, srcByte := range src { - cnt += int(tablePtr[srcByte&15] + tablePtr[srcByte>>4]) + cnt += int(tablePtr.Get(srcByte&15) + tablePtr.Get(srcByte>>4)) } return cnt } @@ -608,9 +403,10 @@ func TestCountNibblesInSet(t *testing.T) { baseCode2 := baseCode1 + 1 + byte(rand.Intn(int(15-baseCode1))) table[baseCode1] = 1 table[baseCode2] = 1 + nlt := simd.MakeNibbleLookupTable(table) - result1 := countNibblesInSetSlow(srcSlice, &table) - result2 := simd.CountNibblesInSet(srcSlice, &table) + result1 := countNibblesInSetSlow(srcSlice, &nlt) + result2 := simd.CountNibblesInSet(srcSlice, &nlt) if result1 != result2 { t.Fatal("Mismatched CountNibblesInSet result.") } @@ -639,10 +435,12 @@ func TestCountNibblesInTwoSets(t *testing.T) { for ii := 0; ii != 5; ii++ { table2[rand.Intn(16)] = 1 } + nlt1 := simd.MakeNibbleLookupTable(table1) + nlt2 := simd.MakeNibbleLookupTable(table2) - result1a := countNibblesInSetSlow(srcSlice, &table1) - result1b := countNibblesInSetSlow(srcSlice, &table2) - result2a, result2b := simd.CountNibblesInTwoSets(srcSlice, &table1, &table2) + result1a := countNibblesInSetSlow(srcSlice, &nlt1) + result1b := countNibblesInSetSlow(srcSlice, &nlt2) + result2a, result2b := simd.CountNibblesInTwoSets(srcSlice, &nlt1, &nlt2) if (result1a != result2a) || (result1b != result2b) { t.Fatal("Mismatched CountNibblesInTwoSets 
result.") } @@ -654,10 +452,10 @@ func TestCountNibblesInTwoSets(t *testing.T) { } } -func countUnpackedNibblesInSetSlow(src []byte, tablePtr *[16]byte) int { +func countUnpackedNibblesInSetSlow(src []byte, tablePtr *simd.NibbleLookupTable) int { cnt := 0 for _, srcByte := range src { - cnt += int(tablePtr[srcByte]) + cnt += int(tablePtr.Get(srcByte)) } return cnt } @@ -678,9 +476,10 @@ func TestCountUnpackedNibblesInSet(t *testing.T) { baseCode2 := baseCode1 + 1 + byte(rand.Intn(int(15-baseCode1))) table[baseCode1] = 1 table[baseCode2] = 1 + nlt := simd.MakeNibbleLookupTable(table) - result1 := countUnpackedNibblesInSetSlow(srcSlice, &table) - result2 := simd.CountUnpackedNibblesInSet(srcSlice, &table) + result1 := countUnpackedNibblesInSetSlow(srcSlice, &nlt) + result2 := simd.CountUnpackedNibblesInSet(srcSlice, &nlt) if result1 != result2 { t.Fatal("Mismatched CountUnpackedNibblesInSet result.") } @@ -709,10 +508,12 @@ func TestCountUnpackedNibblesInTwoSets(t *testing.T) { for ii := 0; ii != 5; ii++ { table2[rand.Intn(16)] = 1 } + nlt1 := simd.MakeNibbleLookupTable(table1) + nlt2 := simd.MakeNibbleLookupTable(table2) - result1a := countUnpackedNibblesInSetSlow(srcSlice, &table1) - result1b := countUnpackedNibblesInSetSlow(srcSlice, &table2) - result2a, result2b := simd.CountUnpackedNibblesInTwoSets(srcSlice, &table1, &table2) + result1a := countUnpackedNibblesInSetSlow(srcSlice, &nlt1) + result1b := countUnpackedNibblesInSetSlow(srcSlice, &nlt2) + result2a, result2b := simd.CountUnpackedNibblesInTwoSets(srcSlice, &nlt1, &nlt2) if (result1a != result2a) || (result1b != result2b) { t.Fatal("Mismatched CountUnpackedNibblesInTwoSets result.") } @@ -724,86 +525,6 @@ func TestCountUnpackedNibblesInTwoSets(t *testing.T) { } } -func accumulate8Subtask(src []byte, nIter int) int { - tot := 0 - for iter := 0; iter < nIter; iter++ { - tot += simd.Accumulate8(src) - // tot += accumulate8Slow(src) - } - return tot -} - -func accumulate8SubtaskFuture(src []byte, nIter int) 
chan int { - future := make(chan int) - go func() { future <- accumulate8Subtask(src, nIter) }() - return future -} - -func multiAccumulate8(srcs [][]byte, cpus int, nJob int) { - sumFutures := make([]chan int, cpus) - shardSizeBase := nJob / cpus - shardRemainder := nJob - shardSizeBase*cpus - shardSizeP1 := shardSizeBase + 1 - var taskIdx int - for ; taskIdx < shardRemainder; taskIdx++ { - sumFutures[taskIdx] = accumulate8SubtaskFuture(srcs[taskIdx], shardSizeP1) - } - for ; taskIdx < cpus; taskIdx++ { - sumFutures[taskIdx] = accumulate8SubtaskFuture(srcs[taskIdx], shardSizeBase) - } - var sum int - for taskIdx = 0; taskIdx < cpus; taskIdx++ { - sum += <-sumFutures[taskIdx] - } -} - -func benchmarkAccumulate8(cpus int, nByte int, nJob int, b *testing.B) { - if cpus > runtime.NumCPU() { - b.Skipf("only have %v cpus", runtime.NumCPU()) - } - - mainSlices := make([][]byte, cpus) - for ii := range mainSlices { - // Add 63 to prevent false sharing. - newArr := simd.MakeUnsafe(nByte + 63) - for jj := 0; jj < nByte; jj++ { - newArr[jj] = byte(jj * 3) - } - mainSlices[ii] = newArr[:nByte] - } - for i := 0; i < b.N; i++ { - multiAccumulate8(mainSlices, cpus, nJob) - } -} - -// Base sequence in length-150 .bam read occupies 75 bytes, so 75 is a good -// size for the short-array benchmark. -func Benchmark_Accumulate8Short1(b *testing.B) { - benchmarkAccumulate8(1, 75, 9999999, b) -} - -func Benchmark_Accumulate8Short4(b *testing.B) { - benchmarkAccumulate8(4, 75, 9999999, b) -} - -func Benchmark_Accumulate8ShortMax(b *testing.B) { - benchmarkAccumulate8(runtime.NumCPU(), 75, 9999999, b) -} - -// GRCh37 chromosome 1 length is 249250621, so that's a plausible long-array -// use case. 
-func Benchmark_Accumulate8Long1(b *testing.B) { - benchmarkAccumulate8(1, 249250621, 50, b) -} - -func Benchmark_Accumulate8Long4(b *testing.B) { - benchmarkAccumulate8(4, 249250621, 50, b) -} - -func Benchmark_Accumulate8LongMax(b *testing.B) { - benchmarkAccumulate8(runtime.NumCPU(), 249250621, 50, b) -} - func accumulate8Slow(src []byte) int { cnt := 0 for _, srcByte := range src { @@ -832,86 +553,58 @@ func TestAccumulate8(t *testing.T) { } } -func accumulate8GreaterSubtask(src []byte, nIter int) int { +/* +Benchmark results: + MacBook Pro (15-inch, 2016) + 2.7 GHz Intel Core i7, 16 GB 2133 MHz LPDDR3 + +Benchmark_Accumulate8/SIMDShort1Cpu-8 20 92560842 ns/op +Benchmark_Accumulate8/SIMDShortHalfCpu-8 50 24796260 ns/op +Benchmark_Accumulate8/SIMDShortAllCpu-8 100 21541910 ns/op +Benchmark_Accumulate8/SIMDLong1Cpu-8 2 778781187 ns/op +Benchmark_Accumulate8/SIMDLongHalfCpu-8 3 466101270 ns/op +Benchmark_Accumulate8/SIMDLongAllCpu-8 3 472125495 ns/op +Benchmark_Accumulate8/SlowShort1Cpu-8 2 725211331 ns/op +Benchmark_Accumulate8/SlowShortHalfCpu-8 10 192303935 ns/op +Benchmark_Accumulate8/SlowShortAllCpu-8 10 146159760 ns/op +Benchmark_Accumulate8/SlowLong1Cpu-8 1 5371110621 ns/op +Benchmark_Accumulate8/SlowLongHalfCpu-8 1 1473946277 ns/op +Benchmark_Accumulate8/SlowLongAllCpu-8 1 1118962315 ns/op +*/ + +func accumulate8SimdSubtask(dst, src []byte, nIter int) int { tot := 0 for iter := 0; iter < nIter; iter++ { - tot += simd.Accumulate8Greater(src, 14) - // tot += accumulate8GreaterSlow(src, 14) + tot += simd.Accumulate8(src) } return tot } -func accumulate8GreaterSubtaskFuture(src []byte, nIter int) chan int { - future := make(chan int) - go func() { future <- accumulate8GreaterSubtask(src, nIter) }() - return future -} - -func multiAccumulate8Greater(srcs [][]byte, cpus int, nJob int) { - sumFutures := make([]chan int, cpus) - shardSizeBase := nJob / cpus - shardRemainder := nJob - shardSizeBase*cpus - shardSizeP1 := shardSizeBase + 1 - var taskIdx int - for ; 
taskIdx < shardRemainder; taskIdx++ { - sumFutures[taskIdx] = accumulate8GreaterSubtaskFuture(srcs[taskIdx], shardSizeP1) - } - for ; taskIdx < cpus; taskIdx++ { - sumFutures[taskIdx] = accumulate8GreaterSubtaskFuture(srcs[taskIdx], shardSizeBase) - } - var sum int - for taskIdx = 0; taskIdx < cpus; taskIdx++ { - sum += <-sumFutures[taskIdx] +func accumulate8SlowSubtask(dst, src []byte, nIter int) int { + tot := 0 + for iter := 0; iter < nIter; iter++ { + tot += accumulate8Slow(src) } + return tot } -func benchmarkAccumulate8Greater(cpus int, nByte int, nJob int, b *testing.B) { - if cpus > runtime.NumCPU() { - b.Skipf("only have %v cpus", runtime.NumCPU()) - } - - mainSlices := make([][]byte, cpus) - for ii := range mainSlices { - // Add 63 to prevent false sharing. - newArr := simd.MakeUnsafe(nByte + 63) - for jj := 0; jj < nByte; jj++ { - newArr[jj] = byte(jj*3) & 127 - } - mainSlices[ii] = newArr[:nByte] +func Benchmark_Accumulate8(b *testing.B) { + funcs := []taggedMultiBenchFunc{ + { + f: accumulate8SimdSubtask, + tag: "SIMD", + }, + { + f: accumulate8SlowSubtask, + tag: "Slow", + }, } - for i := 0; i < b.N; i++ { - multiAccumulate8Greater(mainSlices, cpus, nJob) + for _, f := range funcs { + multiBenchmark(f.f, f.tag+"Short", 0, 150, 9999999, b) + multiBenchmark(f.f, f.tag+"Long", 0, 249250621, 50, b) } } -// Base sequence in length-150 .bam read occupies 75 bytes, so 75 is a good -// size for the short-array benchmark. -func Benchmark_Accumulate8GreaterShort1(b *testing.B) { - benchmarkAccumulate8Greater(1, 75, 9999999, b) -} - -func Benchmark_Accumulate8GreaterShort4(b *testing.B) { - benchmarkAccumulate8Greater(4, 75, 9999999, b) -} - -func Benchmark_Accumulate8GreaterShortMax(b *testing.B) { - benchmarkAccumulate8Greater(runtime.NumCPU(), 75, 9999999, b) -} - -// GRCh37 chromosome 1 length is 249250621, so that's a plausible long-array -// use case. 
-func Benchmark_Accumulate8GreaterLong1(b *testing.B) { - benchmarkAccumulate8Greater(1, 249250621, 50, b) -} - -func Benchmark_Accumulate8GreaterLong4(b *testing.B) { - benchmarkAccumulate8Greater(4, 249250621, 50, b) -} - -func Benchmark_Accumulate8GreaterLongMax(b *testing.B) { - benchmarkAccumulate8Greater(runtime.NumCPU(), 249250621, 50, b) -} - func accumulate8GreaterSlow(src []byte, val byte) int { cnt := 0 for _, srcByte := range src { @@ -943,3 +636,55 @@ func TestAccumulate8Greater(t *testing.T) { } } } + +/* +Benchmark results: + MacBook Pro (15-inch, 2016) + 2.7 GHz Intel Core i7, 16 GB 2133 MHz LPDDR3 + +Benchmark_Accumulate8Greater/SIMDShort1Cpu-8 10 137436870 ns/op +Benchmark_Accumulate8Greater/SIMDShortHalfCpu-8 50 36257710 ns/op +Benchmark_Accumulate8Greater/SIMDShortAllCpu-8 50 32131334 ns/op +Benchmark_Accumulate8Greater/SIMDLong1Cpu-8 2 895831574 ns/op +Benchmark_Accumulate8Greater/SIMDLongHalfCpu-8 2 501501504 ns/op +Benchmark_Accumulate8Greater/SIMDLongAllCpu-8 3 473122019 ns/op +Benchmark_Accumulate8Greater/SlowShort1Cpu-8 1 1026311714 ns/op +Benchmark_Accumulate8Greater/SlowShortHalfCpu-8 5 270841153 ns/op +Benchmark_Accumulate8Greater/SlowShortAllCpu-8 5 254131935 ns/op +Benchmark_Accumulate8Greater/SlowLong1Cpu-8 1 7651910478 ns/op +Benchmark_Accumulate8Greater/SlowLongHalfCpu-8 1 2113221447 ns/op +Benchmark_Accumulate8Greater/SlowLongAllCpu-8 1 2047822921 ns/op +*/ + +func accumulate8GreaterSimdSubtask(dst, src []byte, nIter int) int { + tot := 0 + for iter := 0; iter < nIter; iter++ { + tot += simd.Accumulate8Greater(src, 14) + } + return tot +} + +func accumulate8GreaterSlowSubtask(dst, src []byte, nIter int) int { + tot := 0 + for iter := 0; iter < nIter; iter++ { + tot += accumulate8GreaterSlow(src, 14) + } + return tot +} + +func Benchmark_Accumulate8Greater(b *testing.B) { + funcs := []taggedMultiBenchFunc{ + { + f: accumulate8GreaterSimdSubtask, + tag: "SIMD", + }, + { + f: accumulate8GreaterSlowSubtask, + tag: "Slow", + }, + } + 
for _, f := range funcs { + multiBenchmark(f.f, f.tag+"Short", 0, 150, 9999999, b) + multiBenchmark(f.f, f.tag+"Long", 0, 249250621, 50, b) + } +} diff --git a/simd/doc.go b/simd/doc.go index 6510ca01..3d6942be 100644 --- a/simd/doc.go +++ b/simd/doc.go @@ -6,11 +6,11 @@ // operations on byte arrays which the compiler cannot be trusted to // autovectorize within the next several years. // -// The backend currently assumes SSE4.2 is available, and does not use anything -// past that. (init() checks for SSE4.2 support, and panics when it isn't -// there.) However, the interface is designed to allow the backend to -// autodetect e.g. AVX2/AVX-512 and opportunistically use those instructions, -// without any changes to properly written higher-level code. +// The backend assumes SSE4.2 is available: init() checks for SSE4.2 support, +// and panics when it isn't there. The interface is designed to allow the +// backend to also autodetect e.g. AVX2/AVX-512 and opportunistically use those +// instructions, without any changes to properly written higher-level code. +// Implementation of the AVX2 part of this is in progress. // // // The central constraint driving this package's design is the standard Go @@ -25,18 +25,16 @@ // // Two classes of functions are exported: // -// - Functions with 'Unsafe' in their names will assume it is safe to use the -// main vectorized loop to process the entire slice; this may involve memory -// accesses a few bytes beyond the end of the slice. MakeUnsafe() and related -// functions can be used to allocate a slice with sufficient capacity for this -// to work (this currently means bytesPerVec extra bytes; simply rounding up to -// a multiple of bytesPerVec is not always enough). They may have other -// preconditions as well, and won't check those, either. 
+// - Functions with 'Unsafe' in their names are very performant, but are +// memory-unsafe, do not validate documented preconditions, and may have the +// unusual property of reading/writing to a few bytes *past* the end of the +// given slices. The MakeUnsafe() function and its relatives allocate +// byte-slices with sufficient extra capacity for all Unsafe functions with the +// latter property to work properly. // // - Their safe analogues work properly on ordinary slices, and often panic // when documented preconditions are not met. When a precondition is not // explicitly checked (due to computational cost), safe functions may return -// garbage values when the condition is not met, but they will not corrupt -// unrelated memory or perform out-of-bounds read operations. (Unsafe -// functions may do either of those things when misused.) +// garbage values when the condition is not met, but they are memory-safe: they +// will not corrupt unrelated memory or perform out-of-bounds read operations. package simd diff --git a/simd/float_amd64.go b/simd/float_amd64.go new file mode 100644 index 00000000..b9349152 --- /dev/null +++ b/simd/float_amd64.go @@ -0,0 +1,48 @@ +// Copyright 2021 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache-2.0 +// license that can be found in the LICENSE file. + +// +build amd64,!appengine + +package simd + +import ( + "math" + "reflect" + "unsafe" + + "golang.org/x/sys/cpu" +) + +//go:noescape +func findNaNOrInf64SSSE3Asm(data unsafe.Pointer, nElem int) int + +//go:noescape +func findNaNOrInf64AVX2Asm(data unsafe.Pointer, nElem int) int + +var avx2Available bool + +func init() { + avx2Available = cpu.X86.HasAVX2 + // possible todo: detect FMA and/or AVX512DQ. +} + +// FindNaNOrInf64 returns the position of the first NaN/inf value if one is +// present, and -1 otherwise. 
+func FindNaNOrInf64(data []float64) int { + nElem := len(data) + if nElem < 16 { + for i, x := range data { + if (math.Float64bits(x) & (0x7ff << 52)) == (0x7ff << 52) { + return i + } + } + return -1 + } + dataHeader := (*reflect.SliceHeader)(unsafe.Pointer(&data)) + if avx2Available { + return findNaNOrInf64AVX2Asm(unsafe.Pointer(dataHeader.Data), nElem) + } else { + return findNaNOrInf64SSSE3Asm(unsafe.Pointer(dataHeader.Data), nElem) + } +} diff --git a/simd/float_amd64.s b/simd/float_amd64.s new file mode 100644 index 00000000..c1797fea --- /dev/null +++ b/simd/float_amd64.s @@ -0,0 +1,234 @@ +// Copyright 2021 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache-2.0 +// license that can be found in the LICENSE file. + +// +build amd64,!appengine + + DATA ·ExponentMask<>+0x00(SB)/8, $0x7ff07ff07ff07ff0 + DATA ·ExponentMask<>+0x08(SB)/8, $0x7ff07ff07ff07ff0 + DATA ·ExponentMask<>+0x10(SB)/8, $0x7ff07ff07ff07ff0 + DATA ·ExponentMask<>+0x18(SB)/8, $0x7ff07ff07ff07ff0 + // NOPTR = 16, RODATA = 8 + GLOBL ·ExponentMask<>(SB), 24, $32 + + DATA ·FirstShuffle<>+0x00(SB)/8, $0xffffffff0f0e0706 + DATA ·FirstShuffle<>+0x08(SB)/8, $0xffffffffffffffff + DATA ·FirstShuffle<>+0x10(SB)/8, $0xffffffff0f0e0706 + DATA ·FirstShuffle<>+0x18(SB)/8, $0xffffffffffffffff + GLOBL ·FirstShuffle<>(SB), 24, $32 + + DATA ·SecondShuffle<>+0x00(SB)/8, $0x0f0e0706ffffffff + DATA ·SecondShuffle<>+0x08(SB)/8, $0xffffffffffffffff + DATA ·SecondShuffle<>+0x10(SB)/8, $0x0f0e0706ffffffff + DATA ·SecondShuffle<>+0x18(SB)/8, $0xffffffffffffffff + GLOBL ·SecondShuffle<>(SB), 24, $32 + + DATA ·ThirdShuffle<>+0x00(SB)/8, $0xffffffffffffffff + DATA ·ThirdShuffle<>+0x08(SB)/8, $0xffffffff0f0e0706 + DATA ·ThirdShuffle<>+0x10(SB)/8, $0xffffffffffffffff + DATA ·ThirdShuffle<>+0x18(SB)/8, $0xffffffff0f0e0706 + GLOBL ·ThirdShuffle<>(SB), 24, $32 + + DATA ·FourthShuffle<>+0x00(SB)/8, $0xffffffffffffffff + DATA ·FourthShuffle<>+0x08(SB)/8, $0x0f0e0706ffffffff + DATA 
·FourthShuffle<>+0x10(SB)/8, $0xffffffffffffffff + DATA ·FourthShuffle<>+0x18(SB)/8, $0x0f0e0706ffffffff + GLOBL ·FourthShuffle<>(SB), 24, $32 + +TEXT ·findNaNOrInf64SSSE3Asm(SB),4,$0-24 + // findNaNOrInf64SSSE3Asm returns x if the first NaN/inf in data is at + // position x, or -1 if no NaN/inf is present. nElem must be at least + // 8. + // + // The implementation exploits the fact that we only need to look at + // the exponent bits to determine NaN/inf status, and these occupy just + // the top two bytes of each 8-byte float. Thus, we can pack the + // exponent-containing-bytes of 8 consecutive float64s into a single + // 16-byte vector, and check them in parallel. + // + // Register allocation: + // AX: data + // BX: nElem - 8 + // CX: current index + // DX: comparison result + // SI: &(data[2]) + // DI: &(data[4]) + // R8: &(data[6]) + // R9: nElem + // X0: exponent mask + // X1: first shuffle mask + // X2: second shuffle mask + // X3: third shuffle mask + // X4: fourth shuffle mask + MOVQ data+0(FP), AX + MOVQ nElem+8(FP), BX + MOVQ BX, R9 + SUBQ $8, BX + XORL CX, CX + MOVQ AX, SI + MOVQ AX, DI + MOVQ AX, R8 + ADDQ $16, SI + ADDQ $32, DI + ADDQ $48, R8 + + MOVOU ·ExponentMask<>(SB), X0 + MOVOU ·FirstShuffle<>(SB), X1 + MOVOU ·SecondShuffle<>(SB), X2 + MOVOU ·ThirdShuffle<>(SB), X3 + MOVOU ·FourthShuffle<>(SB), X4 + +findNaNOrInf64SSSE3AsmLoop: + // Scan 8 float64s, starting from &(data[CX]), into X5..X8. + MOVOU (AX)(CX*8), X5 + MOVOU (SI)(CX*8), X6 + MOVOU (DI)(CX*8), X7 + MOVOU (R8)(CX*8), X8 + + // Extract exponent bytes. + PSHUFB X1, X5 + PSHUFB X2, X6 + PSHUFB X3, X7 + PSHUFB X4, X8 + + // Collect into X5. + POR X6, X5 + POR X8, X7 + POR X7, X5 + + // Mask out non-exponent bits, and then compare 2-byte groups in + // parallel. + PAND X0, X5 + PCMPEQW X0, X5 + + // Check result. + PMOVMSKB X5, DX + TESTQ DX, DX + JNE findNaNOrInf64SSSE3AsmFound + + // Advance loop. 
+ ADDQ $8, CX + CMPQ BX, CX + JGE findNaNOrInf64SSSE3AsmLoop + + // Less than 8 float64s left... + CMPQ R9, CX + JE findNaNOrInf64SSSE3AsmNotFound + + // ...but more than zero. Set CX := nElem - 8, and start one last + // loop iteration. + MOVQ BX, CX + JMP findNaNOrInf64SSSE3AsmLoop + +findNaNOrInf64SSSE3AsmNotFound: + MOVQ $-1, ret+16(FP) + RET + +findNaNOrInf64SSSE3AsmFound: + // Determine the position of the lowest set bit in DX, i.e. the byte + // offset of the first comparison success. + BSFQ DX, BX + // We compared 2-byte groups, so divide by 2 to determine the original + // index. + SHRQ $1, BX + ADDQ CX, BX + MOVQ BX, ret+16(FP) + RET + + +TEXT ·findNaNOrInf64AVX2Asm(SB),4,$0-24 + // findNaNOrInf64AVX2Asm is nearly identical to the SSSE3 version; it + // just compares 16 float64s at a time instead of 8. + MOVQ data+0(FP), AX + MOVQ nElem+8(FP), BX + MOVQ BX, R9 + SUBQ $16, BX + XORL CX, CX + MOVQ AX, SI + MOVQ AX, DI + MOVQ AX, R8 + ADDQ $32, SI + ADDQ $64, DI + ADDQ $96, R8 + + VMOVDQU ·ExponentMask<>(SB), Y0 + VMOVDQU ·FirstShuffle<>(SB), Y1 + VMOVDQU ·SecondShuffle<>(SB), Y2 + VMOVDQU ·ThirdShuffle<>(SB), Y3 + VMOVDQU ·FourthShuffle<>(SB), Y4 + +findNaNOrInf64AVX2AsmLoop: + // Scan 16 float64s, starting from &(data[CX]), into Y5..Y8. + VMOVDQU (AX)(CX*8), Y5 + VMOVDQU (SI)(CX*8), Y6 + VMOVDQU (DI)(CX*8), Y7 + VMOVDQU (R8)(CX*8), Y8 + + // Extract exponent bytes. + VPSHUFB Y1, Y5, Y5 + VPSHUFB Y2, Y6, Y6 + VPSHUFB Y3, Y7, Y7 + VPSHUFB Y4, Y8, Y8 + + // Collect into Y5. + VPOR Y6, Y5, Y5 + VPOR Y8, Y7, Y7 + VPOR Y7, Y5, Y5 + + // Mask out non-exponent bits, and then compare 2-byte groups in + // parallel. + VPAND Y0, Y5, Y5 + VPCMPEQW Y0, Y5, Y5 + + // Check result. + VPMOVMSKB Y5, DX + TESTQ DX, DX + JNE findNaNOrInf64AVX2AsmFound + + // Advance loop. + ADDQ $16, CX + CMPQ BX, CX + JGE findNaNOrInf64AVX2AsmLoop + + // Less than 8 float64s left... + CMPQ R9, CX + JE findNaNOrInf64AVX2AsmNotFound + + // ...but more than zero. 
Set CX := nElem - 8, and start one last + // loop iteration. + MOVQ BX, CX + JMP findNaNOrInf64AVX2AsmLoop + +findNaNOrInf64AVX2AsmNotFound: + MOVQ $-1, ret+16(FP) + RET + +findNaNOrInf64AVX2AsmFound: + // Since the PSHUFB instruction acts separately on the two 16-byte + // "lanes", the 2-byte chunks in Y5, and consequently the 2-bit groups + // in DX here, are drawn from &(data[CX])..&(data[CX+15]) in the + // following order: + // 0 1 4 5 8 9 12 13 2 3 6 7 10 11 14 15 + // We "unscramble" this before grabbing the lowest set bit. + + // Clear odd bits. + ANDQ $0x55555555, DX + + // Rearrange to + // 0 1 * * 4 5 * * 8 9 * * 12 13 * * 2 3 ... + // where the above refers to single bits, and * denotes a cleared bit. + MOVQ DX, BX + SHRQ $1, BX + ORQ BX, DX + ANDQ $0x33333333, DX + + // 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 ... + MOVQ DX, BX + SHRQ $14, BX + ORQ BX, DX + + // Okay, now we're ready. + BSFQ DX, BX + ADDQ CX, BX + MOVQ BX, ret+16(FP) + RET diff --git a/simd/float_generic.go b/simd/float_generic.go new file mode 100644 index 00000000..14c2758a --- /dev/null +++ b/simd/float_generic.go @@ -0,0 +1,24 @@ +// Copyright 2021 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache-2.0 +// license that can be found in the LICENSE file. + +// +build !amd64 appengine + +package simd + +import ( + "math" +) + +// FindNaNOrInf64 returns the position of the first NaN/inf value if one is +// present, and -1 otherwise. +func FindNaNOrInf64(data []float64) int { + for i, x := range data { + // Extract the exponent bits, and check if they're all set: that (and only + // that) corresponds to NaN/inf. + if (math.Float64bits(x) & (0x7ff << 52)) == (0x7ff << 52) { + return i + } + } + return -1 +} diff --git a/simd/float_test.go b/simd/float_test.go new file mode 100644 index 00000000..3b6baa6f --- /dev/null +++ b/simd/float_test.go @@ -0,0 +1,222 @@ +// Copyright 2021 GRAIL, Inc. All rights reserved. 
+// Use of this source code is governed by the Apache-2.0 +// license that can be found in the LICENSE file. + +package simd_test + +import ( + "math" + "math/rand" + "testing" + + "github.com/grailbio/base/simd" + "github.com/grailbio/testutil/expect" +) + +func findNaNOrInf64Standard(data []float64) int { + for i, x := range data { + if math.IsNaN(x) || (x > math.MaxFloat64) || (x < -math.MaxFloat64) { + return i + } + } + return -1 +} + +func getPossiblyNaNOrInfFloat64(rate float64) float64 { + var x float64 + if rand.Float64() < rate { + r := rand.Intn(3) + if r == 0 { + x = math.NaN() + } else { + // -inf if r == 1, +inf if r == 2. + x = math.Inf(r - 2) + } + } else { + // Exponentially-distributed random number in + // [-math.MaxFloat64, math.MaxFloat64]. + x = rand.ExpFloat64() + if rand.Intn(2) != 0 { + x = -x + } + } + return x +} + +func TestFindNaNOrInf(t *testing.T) { + // Exhausively test all first-NaN/inf positions for sizes in 0..32. + for size := 0; size <= 32; size++ { + slice := make([]float64, size) + got := simd.FindNaNOrInf64(slice) + want := findNaNOrInf64Standard(slice) + expect.EQ(t, got, want) + expect.EQ(t, got, -1) + + for target := size - 1; target >= 0; target-- { + slice[target] = math.Inf(1) + // Randomize everything after this position, maximizing entropy. + for i := target + 1; i < size; i++ { + slice[i] = getPossiblyNaNOrInfFloat64(0.5) + } + got = simd.FindNaNOrInf64(slice) + want = findNaNOrInf64Standard(slice) + expect.EQ(t, got, want) + expect.EQ(t, got, target) + } + for i := range slice { + slice[i] = 0.0 + } + for target := size - 1; target >= 0; target-- { + slice[target] = math.NaN() + for i := target + 1; i < size; i++ { + slice[i] = getPossiblyNaNOrInfFloat64(0.5) + } + got = simd.FindNaNOrInf64(slice) + want = findNaNOrInf64Standard(slice) + expect.EQ(t, got, want) + expect.EQ(t, got, target) + } + } + // Random test for larger sizes. 
+ maxSize := 30000 + nIter := 200 + rand.Seed(1) + for iter := 0; iter < nIter; iter++ { + size := 1 + rand.Intn(maxSize) + rate := rand.Float64() + slice := make([]float64, size) + for i := range slice { + slice[i] = getPossiblyNaNOrInfFloat64(rate) + } + + for pos := 0; ; { + got := simd.FindNaNOrInf64(slice[pos:]) + want := findNaNOrInf64Standard(slice[pos:]) + expect.EQ(t, got, want) + if got == -1 { + break + } + pos += got + 1 + } + } +} + +type float64Args struct { + main []float64 +} + +func findNaNOrInfSimdSubtask(args interface{}, nIter int) int { + a := args.(float64Args) + slice := a.main + sum := 0 + pos := 0 + for iter := 0; iter < nIter; iter++ { + got := simd.FindNaNOrInf64(slice[pos:]) + sum += got + if got == -1 { + pos = 0 + } else { + pos += got + 1 + } + } + return sum +} + +func findNaNOrInf64Bitwise(data []float64) int { + for i, x := range data { + // Extract the exponent bits, and check if they're all set: that (and only + // that) corresponds to NaN/inf. + // Interestingly, the performance of this idiom degrades significantly, + // relative to + // "math.IsNaN(x) || x > math.MaxFloat64 || x < -math.MaxFloat64", + // if x is interpreted as a float64 anywhere in this loop. 
+ if (math.Float64bits(x) & (0x7ff << 52)) == (0x7ff << 52) { + return i + } + } + return -1 +} + +func findNaNOrInfBitwiseSubtask(args interface{}, nIter int) int { + a := args.(float64Args) + slice := a.main + sum := 0 + pos := 0 + for iter := 0; iter < nIter; iter++ { + got := findNaNOrInf64Bitwise(slice[pos:]) + sum += got + if got == -1 { + pos = 0 + } else { + pos += got + 1 + } + } + return sum +} + +func findNaNOrInfStandardSubtask(args interface{}, nIter int) int { + a := args.(float64Args) + slice := a.main + sum := 0 + pos := 0 + for iter := 0; iter < nIter; iter++ { + got := findNaNOrInf64Standard(slice[pos:]) + sum += got + if got == -1 { + pos = 0 + } else { + pos += got + 1 + } + } + return sum +} + +// On an m5.16xlarge: +// $ bazel run //go/src/github.com/grailbio/base/simd:go_default_test -- -test.bench=FindNaNOrInf +// ... +// Benchmark_FindNaNOrInf/SIMDLong1Cpu-64 82 14053127 ns/op +// Benchmark_FindNaNOrInf/SIMDLongHalfCpu-64 960 1143599 ns/op +// Benchmark_FindNaNOrInf/SIMDLongAllCpu-64 1143 1018525 ns/op +// Benchmark_FindNaNOrInf/BitwiseLong1Cpu-64 8 126930287 ns/op +// Benchmark_FindNaNOrInf/BitwiseLongHalfCpu-64 253 6668467 ns/op +// Benchmark_FindNaNOrInf/BitwiseLongAllCpu-64 229 4679633 ns/op +// Benchmark_FindNaNOrInf/StandardLong1Cpu-64 7 158318559 ns/op +// Benchmark_FindNaNOrInf/StandardLongHalfCpu-64 190 6223669 ns/op +// Benchmark_FindNaNOrInf/StandardLongAllCpu-64 171 6746008 ns/op +// PASS +func Benchmark_FindNaNOrInf(b *testing.B) { + funcs := []taggedMultiBenchVarargsFunc{ + { + f: findNaNOrInfSimdSubtask, + tag: "SIMD", + }, + { + f: findNaNOrInfBitwiseSubtask, + tag: "Bitwise", + }, + { + f: findNaNOrInfStandardSubtask, + tag: "Standard", + }, + } + rand.Seed(1) + for _, f := range funcs { + multiBenchmarkVarargs(f.f, f.tag+"Long", 100000, func() interface{} { + main := make([]float64, 30000) + // Results were overly influenced by RNG if the number of NaNs/infs in + // the slice was not controlled. 
+ for i := 0; i < 30; i++ { + for { + pos := rand.Intn(len(main)) + if main[pos] != math.Inf(0) { + main[pos] = math.Inf(0) + break + } + } + } + return float64Args{ + main: main, + } + }, b) + } +} diff --git a/simd/invmask_amd64.go b/simd/invmask_amd64.go index f2d94cf4..72994e04 100644 --- a/simd/invmask_amd64.go +++ b/simd/invmask_amd64.go @@ -1,8 +1,10 @@ -// Code generated from " ../gtl/generate.py --prefix=Invmask -DOPCHAR=&^ --package=simd --output=invmask_amd64.go invmask_amd64.go.tpl ". DO NOT EDIT. -// Copyright 2018 GRAIL, Inc. All rights reserved. +// Code generated by "../gtl/generate.py --prefix=Invmask -DOPCHAR=&^ --package=simd --output=invmask_amd64.go bitwise_amd64.go.tpl". DO NOT EDIT. + +// Copyright 2021 GRAIL, Inc. All rights reserved. // Use of this source code is governed by the Apache-2.0 // license that can be found in the LICENSE file. +//go:build amd64 && !appengine // +build amd64,!appengine package simd @@ -12,7 +14,7 @@ import ( "unsafe" ) -// InvmaskUnsafeInplace sets main[pos] := arg[pos] &^ main[pos] for every position +// InvmaskUnsafeInplace sets main[pos] := main[pos] &^ arg[pos] for every position // in main[]. // // WARNING: This is a function designed to be used in inner loops, which makes @@ -30,18 +32,18 @@ import ( // changed. 
func InvmaskUnsafeInplace(main, arg []byte) { mainLen := len(main) - argHeader := (*reflect.SliceHeader)(unsafe.Pointer(&arg)) - mainHeader := (*reflect.SliceHeader)(unsafe.Pointer(&main)) - argWordsIter := unsafe.Pointer(argHeader.Data) - mainWordsIter := unsafe.Pointer(mainHeader.Data) + argData := unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&arg)).Data) + mainData := unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&main)).Data) + argWordsIter := argData + mainWordsIter := mainData if mainLen > 2*BytesPerWord { nWordMinus2 := (mainLen - BytesPerWord - 1) >> Log2BytesPerWord for widx := 0; widx < nWordMinus2; widx++ { mainWord := *((*uintptr)(mainWordsIter)) argWord := *((*uintptr)(argWordsIter)) *((*uintptr)(mainWordsIter)) = mainWord &^ argWord - mainWordsIter = unsafe.Pointer(uintptr(mainWordsIter) + BytesPerWord) - argWordsIter = unsafe.Pointer(uintptr(argWordsIter) + BytesPerWord) + mainWordsIter = unsafe.Add(mainWordsIter, BytesPerWord) + argWordsIter = unsafe.Add(argWordsIter, BytesPerWord) } } else if mainLen <= BytesPerWord { mainWord := *((*uintptr)(mainWordsIter)) @@ -56,8 +58,8 @@ func InvmaskUnsafeInplace(main, arg []byte) { mainWord1 := *((*uintptr)(mainWordsIter)) argWord1 := *((*uintptr)(argWordsIter)) finalOffset := uintptr(mainLen - BytesPerWord) - mainFinalWordPtr := unsafe.Pointer(mainHeader.Data + finalOffset) - argFinalWordPtr := unsafe.Pointer(argHeader.Data + finalOffset) + mainFinalWordPtr := unsafe.Add(mainData, finalOffset) + argFinalWordPtr := unsafe.Add(argData, finalOffset) mainWord2 := *((*uintptr)(mainFinalWordPtr)) argWord2 := *((*uintptr)(argFinalWordPtr)) *((*uintptr)(mainWordsIter)) = mainWord1 &^ argWord1 @@ -82,25 +84,25 @@ func InvmaskInplace(main, arg []byte) { } return } - argHeader := (*reflect.SliceHeader)(unsafe.Pointer(&arg)) - mainHeader := (*reflect.SliceHeader)(unsafe.Pointer(&main)) - argWordsIter := unsafe.Pointer(argHeader.Data) - mainWordsIter := unsafe.Pointer(mainHeader.Data) + argData := 
unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&arg)).Data) + mainData := unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&main)).Data) + argWordsIter := argData + mainWordsIter := mainData if mainLen > 2*BytesPerWord { nWordMinus2 := (mainLen - BytesPerWord - 1) >> Log2BytesPerWord for widx := 0; widx < nWordMinus2; widx++ { mainWord := *((*uintptr)(mainWordsIter)) argWord := *((*uintptr)(argWordsIter)) *((*uintptr)(mainWordsIter)) = mainWord &^ argWord - mainWordsIter = unsafe.Pointer(uintptr(mainWordsIter) + BytesPerWord) - argWordsIter = unsafe.Pointer(uintptr(argWordsIter) + BytesPerWord) + mainWordsIter = unsafe.Add(mainWordsIter, BytesPerWord) + argWordsIter = unsafe.Add(argWordsIter, BytesPerWord) } } mainWord1 := *((*uintptr)(mainWordsIter)) argWord1 := *((*uintptr)(argWordsIter)) finalOffset := uintptr(mainLen - BytesPerWord) - mainFinalWordPtr := unsafe.Pointer(mainHeader.Data + finalOffset) - argFinalWordPtr := unsafe.Pointer(argHeader.Data + finalOffset) + mainFinalWordPtr := unsafe.Add(mainData, finalOffset) + argFinalWordPtr := unsafe.Add(argData, finalOffset) mainWord2 := *((*uintptr)(mainFinalWordPtr)) argWord2 := *((*uintptr)(argFinalWordPtr)) *((*uintptr)(mainWordsIter)) = mainWord1 &^ argWord1 @@ -135,9 +137,9 @@ func InvmaskUnsafe(dst, src1, src2 []byte) { src1Word := *((*uintptr)(src1Iter)) src2Word := *((*uintptr)(src2Iter)) *((*uintptr)(dstIter)) = src1Word &^ src2Word - src1Iter = unsafe.Pointer(uintptr(src1Iter) + BytesPerWord) - src2Iter = unsafe.Pointer(uintptr(src2Iter) + BytesPerWord) - dstIter = unsafe.Pointer(uintptr(dstIter) + BytesPerWord) + src1Iter = unsafe.Add(src1Iter, BytesPerWord) + src2Iter = unsafe.Add(src2Iter, BytesPerWord) + dstIter = unsafe.Add(dstIter, BytesPerWord) } } @@ -154,28 +156,160 @@ func Invmask(dst, src1, src2 []byte) { } return } - src1Header := (*reflect.SliceHeader)(unsafe.Pointer(&src1)) - src2Header := (*reflect.SliceHeader)(unsafe.Pointer(&src2)) - dstHeader := 
(*reflect.SliceHeader)(unsafe.Pointer(&dst)) + src1Data := unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&src1)).Data) + src2Data := unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&src2)).Data) + dstData := unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&dst)).Data) nWordMinus1 := (dstLen - 1) >> Log2BytesPerWord - src1Iter := unsafe.Pointer(src1Header.Data) - src2Iter := unsafe.Pointer(src2Header.Data) - dstIter := unsafe.Pointer(dstHeader.Data) + src1Iter := src1Data + src2Iter := src2Data + dstIter := dstData for widx := 0; widx < nWordMinus1; widx++ { src1Word := *((*uintptr)(src1Iter)) src2Word := *((*uintptr)(src2Iter)) *((*uintptr)(dstIter)) = src1Word &^ src2Word - src1Iter = unsafe.Pointer(uintptr(src1Iter) + BytesPerWord) - src2Iter = unsafe.Pointer(uintptr(src2Iter) + BytesPerWord) - dstIter = unsafe.Pointer(uintptr(dstIter) + BytesPerWord) + src1Iter = unsafe.Add(src1Iter, BytesPerWord) + src2Iter = unsafe.Add(src2Iter, BytesPerWord) + dstIter = unsafe.Add(dstIter, BytesPerWord) } // No store-forwarding problem here. finalOffset := uintptr(dstLen - BytesPerWord) - src1Iter = unsafe.Pointer(src1Header.Data + finalOffset) - src2Iter = unsafe.Pointer(src2Header.Data + finalOffset) - dstIter = unsafe.Pointer(dstHeader.Data + finalOffset) + src1Iter = unsafe.Add(src1Data, finalOffset) + src2Iter = unsafe.Add(src2Data, finalOffset) + dstIter = unsafe.Add(dstData, finalOffset) src1Word := *((*uintptr)(src1Iter)) src2Word := *((*uintptr)(src2Iter)) *((*uintptr)(dstIter)) = src1Word &^ src2Word } + +// InvmaskConst8UnsafeInplace sets main[pos] := main[pos] &^ val for every position +// in main[]. +// +// WARNING: This is a function designed to be used in inner loops, which makes +// assumptions about length and capacity which aren't checked at runtime. Use +// the safe version of this function when that's a problem. 
+// These assumptions are always satisfied when the last +// potentially-size-increasing operation on main[] is {Re}makeUnsafe(), +// ResizeUnsafe(), or XcapUnsafe(). +// +// 1. cap(main) is at least RoundUpPow2(len(main) + 1, bytesPerVec). +// +// 2. The caller does not care if a few bytes past the end of main[] are +// changed. +func InvmaskConst8UnsafeInplace(main []byte, val byte) { + mainLen := len(main) + argWord := 0x101010101010101 * uintptr(val) + mainData := unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&main)).Data) + mainWordsIter := mainData + if mainLen > 2*BytesPerWord { + nWordMinus2 := (mainLen - BytesPerWord - 1) >> Log2BytesPerWord + for widx := 0; widx < nWordMinus2; widx++ { + mainWord := *((*uintptr)(mainWordsIter)) + *((*uintptr)(mainWordsIter)) = mainWord &^ argWord + mainWordsIter = unsafe.Add(mainWordsIter, BytesPerWord) + } + } else if mainLen <= BytesPerWord { + mainWord := *((*uintptr)(mainWordsIter)) + *((*uintptr)(mainWordsIter)) = mainWord &^ argWord + return + } + mainWord1 := *((*uintptr)(mainWordsIter)) + finalOffset := uintptr(mainLen - BytesPerWord) + mainFinalWordPtr := unsafe.Add(mainData, finalOffset) + mainWord2 := *((*uintptr)(mainFinalWordPtr)) + *((*uintptr)(mainWordsIter)) = mainWord1 &^ argWord + *((*uintptr)(mainFinalWordPtr)) = mainWord2 &^ argWord +} + +// InvmaskConst8Inplace sets main[pos] := main[pos] &^ val for every position in +// main[]. 
+func InvmaskConst8Inplace(main []byte, val byte) { + mainLen := len(main) + if mainLen < BytesPerWord { + for pos, mainByte := range main { + main[pos] = mainByte &^ val + } + return + } + argWord := 0x101010101010101 * uintptr(val) + mainData := unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&main)).Data) + mainWordsIter := mainData + if mainLen > 2*BytesPerWord { + nWordMinus2 := (mainLen - BytesPerWord - 1) >> Log2BytesPerWord + for widx := 0; widx < nWordMinus2; widx++ { + mainWord := *((*uintptr)(mainWordsIter)) + *((*uintptr)(mainWordsIter)) = mainWord &^ argWord + mainWordsIter = unsafe.Add(mainWordsIter, BytesPerWord) + } + } + mainWord1 := *((*uintptr)(mainWordsIter)) + finalOffset := uintptr(mainLen - BytesPerWord) + mainFinalWordPtr := unsafe.Add(mainData, finalOffset) + mainWord2 := *((*uintptr)(mainFinalWordPtr)) + *((*uintptr)(mainWordsIter)) = mainWord1 &^ argWord + *((*uintptr)(mainFinalWordPtr)) = mainWord2 &^ argWord +} + +// InvmaskConst8Unsafe sets dst[pos] := src[pos] &^ val for every position in dst. +// +// WARNING: This is a function designed to be used in inner loops, which makes +// assumptions about length and capacity which aren't checked at runtime. Use +// the safe version of this function when that's a problem. +// Assumptions #2-3 are always satisfied when the last +// potentially-size-increasing operation on src[] is {Re}makeUnsafe(), +// ResizeUnsafe(), or XcapUnsafe(), and the same is true for dst[]. +// +// 1. len(src) and len(dst) must be equal. +// +// 2. Capacities are at least RoundUpPow2(len(dst) + 1, bytesPerVec). +// +// 3. The caller does not care if a few bytes past the end of dst[] are +// changed. 
+func InvmaskConst8Unsafe(dst, src []byte, val byte) { + srcHeader := (*reflect.SliceHeader)(unsafe.Pointer(&src)) + dstHeader := (*reflect.SliceHeader)(unsafe.Pointer(&dst)) + nWord := DivUpPow2(len(dst), BytesPerWord, Log2BytesPerWord) + argWord := 0x101010101010101 * uintptr(val) + + srcIter := unsafe.Pointer(srcHeader.Data) + dstIter := unsafe.Pointer(dstHeader.Data) + for widx := 0; widx < nWord; widx++ { + srcWord := *((*uintptr)(srcIter)) + *((*uintptr)(dstIter)) = srcWord &^ argWord + srcIter = unsafe.Add(srcIter, BytesPerWord) + dstIter = unsafe.Add(dstIter, BytesPerWord) + } +} + +// InvmaskConst8 sets dst[pos] := src[pos] &^ val for every position in dst. It +// panics if slice lengths don't match. +func InvmaskConst8(dst, src []byte, val byte) { + dstLen := len(dst) + if len(src) != dstLen { + panic("InvmaskConst8() requires len(src) == len(dst).") + } + if dstLen < BytesPerWord { + for pos, srcByte := range src { + dst[pos] = srcByte &^ val + } + return + } + srcData := unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&src)).Data) + dstData := unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&dst)).Data) + nWordMinus1 := (dstLen - 1) >> Log2BytesPerWord + argWord := 0x101010101010101 * uintptr(val) + + srcIter := unsafe.Pointer(srcData) + dstIter := unsafe.Pointer(dstData) + for widx := 0; widx < nWordMinus1; widx++ { + srcWord := *((*uintptr)(srcIter)) + *((*uintptr)(dstIter)) = srcWord &^ argWord + srcIter = unsafe.Add(srcIter, BytesPerWord) + dstIter = unsafe.Add(dstIter, BytesPerWord) + } + finalOffset := uintptr(dstLen - BytesPerWord) + srcIter = unsafe.Add(srcData, finalOffset) + dstIter = unsafe.Add(dstData, finalOffset) + srcWord := *((*uintptr)(srcIter)) + *((*uintptr)(dstIter)) = srcWord &^ argWord +} diff --git a/simd/invmask_generic.go b/simd/invmask_generic.go new file mode 100644 index 00000000..8d484594 --- /dev/null +++ b/simd/invmask_generic.go @@ -0,0 +1,135 @@ +// Code generated by " ../gtl/generate.py --prefix=Invmask 
-DOPCHAR=&^ --package=simd --output=invmask_generic.go bitwise_generic.go.tpl ". DO NOT EDIT. + +// Copyright 2018 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache-2.0 +// license that can be found in the LICENSE file. + +// +build !amd64 appengine + +package simd + +// InvmaskUnsafeInplace sets main[pos] := main[pos] &^ arg[pos] for every position +// in main[]. +// +// WARNING: This is a function designed to be used in inner loops, which makes +// assumptions about length and capacity which aren't checked at runtime. Use +// the safe version of this function when that's a problem. +// Assumptions #2-3 are always satisfied when the last +// potentially-size-increasing operation on arg[] is {Re}makeUnsafe(), +// ResizeUnsafe(), or XcapUnsafe(), and the same is true for main[]. +// +// 1. len(arg) and len(main) must be equal. +// +// 2. Capacities are at least RoundUpPow2(len(main) + 1, bytesPerVec). +// +// 3. The caller does not care if a few bytes past the end of main[] are +// changed. +func InvmaskUnsafeInplace(main, arg []byte) { + for i, x := range main { + main[i] = x &^ arg[i] + } +} + +// InvmaskInplace sets main[pos] := main[pos] &^ arg[pos] for every position in +// main[]. It panics if slice lengths don't match. +func InvmaskInplace(main, arg []byte) { + if len(arg) != len(main) { + panic("InvmaskInplace() requires len(arg) == len(main).") + } + for i, x := range main { + main[i] = x &^ arg[i] + } +} + +// InvmaskUnsafe sets dst[pos] := src1[pos] &^ src2[pos] for every position in dst. +// +// WARNING: This is a function designed to be used in inner loops, which makes +// assumptions about length and capacity which aren't checked at runtime. Use +// the safe version of this function when that's a problem. +// Assumptions #2-3 are always satisfied when the last +// potentially-size-increasing operation on src1[] is {Re}makeUnsafe(), +// ResizeUnsafe(), or XcapUnsafe(), and the same is true for src2[] and dst[]. 
+// +// 1. len(src1), len(src2), and len(dst) must be equal. +// +// 2. Capacities are at least RoundUpPow2(len(dst) + 1, bytesPerVec). +// +// 3. The caller does not care if a few bytes past the end of dst[] are +// changed. +func InvmaskUnsafe(dst, src1, src2 []byte) { + for i, x := range src1 { + dst[i] = x &^ src2[i] + } +} + +// Invmask sets dst[pos] := src1[pos] &^ src2[pos] for every position in dst. It +// panics if slice lengths don't match. +func Invmask(dst, src1, src2 []byte) { + dstLen := len(dst) + if (len(src1) != dstLen) || (len(src2) != dstLen) { + panic("Invmask() requires len(src1) == len(src2) == len(dst).") + } + for i, x := range src1 { + dst[i] = x &^ src2[i] + } +} + +// InvmaskConst8UnsafeInplace sets main[pos] := main[pos] &^ val for every position +// in main[]. +// +// WARNING: This is a function designed to be used in inner loops, which makes +// assumptions about length and capacity which aren't checked at runtime. Use +// the safe version of this function when that's a problem. +// These assumptions are always satisfied when the last +// potentially-size-increasing operation on main[] is {Re}makeUnsafe(), +// ResizeUnsafe(), or XcapUnsafe(). +// +// 1. cap(main) is at least RoundUpPow2(len(main) + 1, bytesPerVec). +// +// 2. The caller does not care if a few bytes past the end of main[] are +// changed. +func InvmaskConst8UnsafeInplace(main []byte, val byte) { + for i, x := range main { + main[i] = x &^ val + } +} + +// InvmaskConst8Inplace sets main[pos] := main[pos] &^ val for every position in +// main[]. +func InvmaskConst8Inplace(main []byte, val byte) { + for i, x := range main { + main[i] = x &^ val + } +} + +// InvmaskConst8Unsafe sets dst[pos] := src[pos] &^ val for every position in dst. +// +// WARNING: This is a function designed to be used in inner loops, which makes +// assumptions about length and capacity which aren't checked at runtime. Use +// the safe version of this function when that's a problem. 
+// Assumptions #2-3 are always satisfied when the last +// potentially-size-increasing operation on src[] is {Re}makeUnsafe(), +// ResizeUnsafe(), or XcapUnsafe(), and the same is true for dst[]. +// +// 1. len(src) and len(dst) must be equal. +// +// 2. Capacities are at least RoundUpPow2(len(dst) + 1, bytesPerVec). +// +// 3. The caller does not care if a few bytes past the end of dst[] are +// changed. +func InvmaskConst8Unsafe(dst, src []byte, val byte) { + for i, x := range src { + dst[i] = x &^ val + } +} + +// InvmaskConst8 sets dst[pos] := src[pos] &^ val for every position in dst. It +// panics if slice lengths don't match. +func InvmaskConst8(dst, src []byte, val byte) { + if len(src) != len(dst) { + panic("InvmaskConst8() requires len(src) == len(dst).") + } + for i, x := range src { + dst[i] = x &^ val + } +} diff --git a/simd/multi_benchmark_test.go b/simd/multi_benchmark_test.go new file mode 100644 index 00000000..da6a3746 --- /dev/null +++ b/simd/multi_benchmark_test.go @@ -0,0 +1,172 @@ +// Copyright 2019 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache-2.0 +// license that can be found in the LICENSE file. + +package simd_test + +import ( + "runtime" + "testing" + + "github.com/grailbio/base/simd" + "github.com/grailbio/base/traverse" +) + +// Utility functions to assist with benchmarking of embarrassingly parallel +// jobs. It probably makes sense to move this code to a more central location +// at some point. + +type multiBenchFunc func(dst, src []byte, nIter int) int + +type taggedMultiBenchFunc struct { + f multiBenchFunc + tag string +} + +type bytesInitFunc func(src []byte) + +type multiBenchmarkOpts struct { + dstInit bytesInitFunc + srcInit bytesInitFunc +} + +func multiBenchmark(bf multiBenchFunc, benchmarkSubtype string, nDstByte, nSrcByte, nJob int, b *testing.B, opts ...multiBenchmarkOpts) { + // 'bf' is expected to execute the benchmarking target nIter times. 
+ // + // Given that, for each of the 3 nCpu settings below, multiBenchmark launches + // 'parallelism' goroutines, where each goroutine has nIter set to roughly + // (nJob / nCpu), so that the total number of benchmark-target-function + // invocations across all threads is nJob. It is designed to measure how + // effective traverse.Each-style parallelization is at reducing wall-clock + // runtime. + totalCpu := runtime.NumCPU() + cases := []struct { + nCpu int + descrip string + }{ + { + nCpu: 1, + descrip: "1Cpu", + }, + // 'Half' is often the saturation point, due to hyperthreading. + { + nCpu: (totalCpu + 1) / 2, + descrip: "HalfCpu", + }, + { + nCpu: totalCpu, + descrip: "AllCpu", + }, + } + var dstInit bytesInitFunc + var srcInit bytesInitFunc + if len(opts) >= 1 { + dstInit = opts[0].dstInit + srcInit = opts[0].srcInit + } + for _, c := range cases { + success := b.Run(benchmarkSubtype+c.descrip, func(b *testing.B) { + dsts := make([][]byte, c.nCpu) + srcs := make([][]byte, c.nCpu) + for i := 0; i < c.nCpu; i++ { + // Add 63 to prevent false sharing. + newArrDst := simd.MakeUnsafe(nDstByte + 63) + newArrSrc := simd.MakeUnsafe(nSrcByte + 63) + if i == 0 { + if dstInit != nil { + dstInit(newArrDst) + } + if srcInit != nil { + srcInit(newArrSrc) + } else { + for j := 0; j < nSrcByte; j++ { + newArrSrc[j] = byte(j * 3) + } + } + } else { + if dstInit != nil { + copy(newArrDst[:nDstByte], dsts[0]) + } + copy(newArrSrc[:nSrcByte], srcs[0]) + } + dsts[i] = newArrDst[:nDstByte] + srcs[i] = newArrSrc[:nSrcByte] + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + // May want to replace this with something based on testing.B's + // RunParallel method. (Haven't done so yet since I don't see a clean + // way to make that play well with per-core preallocated buffers.) 
+ _ = traverse.Each(c.nCpu, func(threadIdx int) error { + nIter := (((threadIdx + 1) * nJob) / c.nCpu) - ((threadIdx * nJob) / c.nCpu) + _ = bf(dsts[threadIdx], srcs[threadIdx], nIter) + return nil + }) + } + }) + if !success { + panic("benchmark failed") + } + } +} + +func bytesInit0(src []byte) { + // do nothing +} + +func bytesInitMax15(src []byte) { + for i := 0; i < len(src); i++ { + src[i] = byte(i*3) & 15 + } +} + +type multiBenchVarargsFunc func(args interface{}, nIter int) int + +type taggedMultiBenchVarargsFunc struct { + f multiBenchVarargsFunc + tag string +} + +type varargsFactory func() interface{} + +func multiBenchmarkVarargs(bvf multiBenchVarargsFunc, benchmarkSubtype string, nJob int, argsFactory varargsFactory, b *testing.B) { + totalCpu := runtime.NumCPU() + cases := []struct { + nCpu int + descrip string + }{ + { + nCpu: 1, + descrip: "1Cpu", + }, + { + nCpu: (totalCpu + 1) / 2, + descrip: "HalfCpu", + }, + { + nCpu: totalCpu, + descrip: "AllCpu", + }, + } + for _, c := range cases { + success := b.Run(benchmarkSubtype+c.descrip, func(b *testing.B) { + var argSlice []interface{} + for i := 0; i < c.nCpu; i++ { + // Can take an "args interface{}" parameter and make deep copies + // instead. + argSlice = append(argSlice, argsFactory()) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = traverse.Each(c.nCpu, func(threadIdx int) error { + nIter := (((threadIdx + 1) * nJob) / c.nCpu) - ((threadIdx * nJob) / c.nCpu) + _ = bvf(argSlice[threadIdx], nIter) + return nil + }) + } + }) + if !success { + panic("benchmark failed") + } + } +} diff --git a/simd/multibyte_amd64.go b/simd/multibyte_amd64.go index 3de8bb68..cb84e2b2 100644 --- a/simd/multibyte_amd64.go +++ b/simd/multibyte_amd64.go @@ -1,7 +1,8 @@ -// Copyright 2018 GRAIL, Inc. All rights reserved. +// Copyright 2021 GRAIL, Inc. All rights reserved. // Use of this source code is governed by the Apache-2.0 // license that can be found in the LICENSE file. 
+//go:build amd64 && !appengine // +build amd64,!appengine package simd @@ -24,6 +25,9 @@ import ( // *** the following functions are defined in multibyte_amd64.s +//go:noescape +func index16SSE2Asm(main unsafe.Pointer, val, nElem int) int + //go:noescape func reverse16InplaceSSSE3Asm(main unsafe.Pointer, nElem int) @@ -44,7 +48,7 @@ func Memset16Raw(dst, valPtr unsafe.Pointer, nElem int) { if nElem < BytesPerWord/2 { for idx := 0; idx != nElem; idx++ { *((*uint16)(dst)) = val - dst = unsafe.Pointer(uintptr(dst) + 2) + dst = unsafe.Add(dst, 2) } return } @@ -53,9 +57,9 @@ func Memset16Raw(dst, valPtr unsafe.Pointer, nElem int) { dstWordsIter := dst for widx := 0; widx != nWordMinus1; widx++ { *((*uintptr)(dstWordsIter)) = valWord - dstWordsIter = unsafe.Pointer(uintptr(dstWordsIter) + BytesPerWord) + dstWordsIter = unsafe.Add(dstWordsIter, BytesPerWord) } - dstWordsIter = unsafe.Pointer(uintptr(dst) + uintptr(nElem)*2 - BytesPerWord) + dstWordsIter = unsafe.Add(dst, nElem*2-BytesPerWord) *((*uintptr)(dstWordsIter)) = valWord } @@ -75,9 +79,9 @@ func Memset32Raw(dst, valPtr unsafe.Pointer, nElem int) { dstWordsIter := dst for widx := 0; widx != nWordMinus1; widx++ { *((*uintptr)(dstWordsIter)) = valWord - dstWordsIter = unsafe.Pointer(uintptr(dstWordsIter) + BytesPerWord) + dstWordsIter = unsafe.Add(dstWordsIter, BytesPerWord) } - dstWordsIter = unsafe.Pointer(uintptr(dst) + uintptr(nElem)*4 - BytesPerWord) + dstWordsIter = unsafe.Add(dst, nElem*4-BytesPerWord) *((*uintptr)(dstWordsIter)) = valWord } @@ -93,6 +97,21 @@ func RepeatU16(dst []uint16, val uint16) { Memset16Raw(unsafe.Pointer(dstHeader.Data), unsafe.Pointer(&val), dstHeader.Len) } +// IndexU16 returns the index of the first instance of val in main, or -1 if +// val is not present in main. 
+func IndexU16(main []uint16, val uint16) int { + if len(main) < 8 { + for i, v := range main { + if v == val { + return i + } + } + return -1 + } + mainHeader := (*reflect.SliceHeader)(unsafe.Pointer(&main)) + return index16SSE2Asm(unsafe.Pointer(mainHeader.Data), int(val), mainHeader.Len) +} + // (Add a function which has the original little-endian byte-slice semantics if // we ever need it.) @@ -102,13 +121,13 @@ func Reverse16InplaceRaw(main unsafe.Pointer, nElem int) { if nElem <= 8 { nElemDiv2 := nElem >> 1 fwdIter := main - revIter := unsafe.Pointer(uintptr(main) + uintptr((nElem-1)*2)) + revIter := unsafe.Add(main, (nElem-1)*2) for idx := 0; idx != nElemDiv2; idx++ { origLeftVal := *((*uint16)(fwdIter)) *((*uint16)(fwdIter)) = *((*uint16)(revIter)) *((*uint16)(revIter)) = origLeftVal - fwdIter = unsafe.Pointer(uintptr(fwdIter) + 2) - revIter = unsafe.Pointer(uintptr(revIter) - 2) + fwdIter = unsafe.Add(fwdIter, 2) + revIter = unsafe.Add(revIter, -2) } return } @@ -119,12 +138,12 @@ func Reverse16InplaceRaw(main unsafe.Pointer, nElem int) { // and sets dst[pos] := src[ct - 1 - pos] for each position. func Reverse16Raw(dst, src unsafe.Pointer, nElem int) { if nElem < 8 { - srcIter := unsafe.Pointer(uintptr(src) + uintptr((nElem-1)*2)) + srcIter := unsafe.Add(src, (nElem-1)*2) dstIter := dst for idx := 0; idx != nElem; idx++ { *((*uint16)(dstIter)) = *((*uint16)(srcIter)) - srcIter = unsafe.Pointer(uintptr(srcIter) - 2) - dstIter = unsafe.Pointer(uintptr(dstIter) + 2) + srcIter = unsafe.Add(srcIter, -2) + dstIter = unsafe.Add(dstIter, 2) } return } diff --git a/simd/multibyte_amd64.s b/simd/multibyte_amd64.s index 6c1cf55e..038cd258 100644 --- a/simd/multibyte_amd64.s +++ b/simd/multibyte_amd64.s @@ -9,6 +9,68 @@ GLOBL ·Reverse16<>(SB), 24, $16 // NOPTR = 16, RODATA = 8 +TEXT ·index16SSE2Asm(SB),4,$0-32 + // index16SSE2Asm scans main[], searching for the first instance of + // val. If no instances are found, it returns -1. + // It requires nElem >= 8. 
+ // The implementation is based on a loop which uses _mm_cmpeq_epi16() + // to scan 8 uint16s in parallel, and _mm_movemask_epi8() to extract + // the result of that scan. It is similar to firstLeq8 in cmp_amd64.s. + + // There's a ~10% benefit from 2x-unrolling the main loop so that only + // one test is performed per loop iteration (i.e. just look at the + // bitwise-or of the comparison results, and backtrack a bit on a hit). + // I'll leave that on the table for now to keep the logic simpler. + + // Register allocation: + // AX: pointer to start of main[] + // BX: nElem - 8 + // CX: current index + // X0: vector with 8 copies of val + MOVQ main+0(FP), AX + + // clang compiles _mm_set1_epi16() to this, I'll trust it. + MOVQ val+8(FP), X0 + PSHUFLW $0xe0, X0, X0 + PSHUFD $0, X0, X0 + + MOVQ nElem+16(FP), BX + SUBQ $8, BX + XORL CX, CX + +index16SSE2AsmLoop: + // Scan 8 elements starting from &(main[CX]). + MOVOU (AX)(CX*2), X1 + PCMPEQW X0, X1 + PMOVMSKB X1, DX + // Bits 2k and 2k+1 are now set in DX iff the uint16 at position k + // compared equal. + TESTQ DX, DX + JNE index16SSE2AsmFound + ADDQ $8, CX + CMPQ BX, CX + JG index16SSE2AsmLoop + + // Scan the last 8 elements; this may partially overlap with the + // previous scan. + MOVQ BX, CX + MOVOU (AX)(CX*2), X1 + PCMPEQW X0, X1 + PMOVMSKB X1, DX + TESTQ DX, DX + JNE index16SSE2AsmFound + // No match found, return -1. + MOVQ $-1, ret+24(FP) + RET + +index16SSE2AsmFound: + BSFQ DX, AX + // AX now has the index of the lowest set bit in DX. + SHRQ $1, AX + ADDQ CX, AX + MOVQ AX, ret+24(FP) + RET + TEXT ·reverse16InplaceSSSE3Asm(SB),4,$0-16 // This is only called with nElem > 8. So we can safely divide this // into two cases: diff --git a/simd/multibyte_appengine.go b/simd/multibyte_appengine.go new file mode 100644 index 00000000..217ca469 --- /dev/null +++ b/simd/multibyte_appengine.go @@ -0,0 +1,76 @@ +// Copyright 2018 GRAIL, Inc. All rights reserved. 
+// Use of this source code is governed by the Apache-2.0 +// license that can be found in the LICENSE file. + +// +build appengine + +package simd + +// This file contains functions which operate on slices of 2- or 4-byte +// elements (typically small structs or integers) in ways that differ from the +// corresponding operations on single-byte elements. +// In this context, there is little point in making the interface based on +// []byte, since the caller will need to unsafely cast to it. Instead, most +// functions take unsafe.Pointer(s) and a count, and have names ending in +// 'Raw'; the caller should write safe wrappers around them when appropriate. +// We provide sample wrappers for the int16 and uint16 cases. (Originally did +// this for int32/uint32, but turns out the compiler has hardcoded +// optimizations for those cases which are currently missing for {u}int16.) + +// RepeatI16 fills dst[] with the given int16. +func RepeatI16(dst []int16, val int16) { + for i := range dst { + dst[i] = val + } +} + +// RepeatU16 fills dst[] with the given uint16. +func RepeatU16(dst []uint16, val uint16) { + for i := range dst { + dst[i] = val + } +} + +// ReverseI16Inplace reverses a []int16 in-place. +func ReverseI16Inplace(main []int16) { + nElem := len(main) + nElemDiv2 := nElem >> 1 + for i, j := 0, nElem-1; i != nElemDiv2; i, j = i+1, j-1 { + main[i], main[j] = main[j], main[i] + } +} + +// ReverseU16Inplace reverses a []uint16 in-place. +func ReverseU16Inplace(main []uint16) { + nElem := len(main) + nElemDiv2 := nElem >> 1 + for i, j := 0, nElem-1; i != nElemDiv2; i, j = i+1, j-1 { + main[i], main[j] = main[j], main[i] + } +} + +// ReverseI16 sets dst[len(src) - 1 - pos] := src[pos] for each position in +// src. It panics if len(src) != len(dst). 
+func ReverseI16(dst, src []int16) { + if len(dst) != len(src) { + panic("ReverseI16() requires len(src) == len(dst).") + } + nElemMinus1 := len(dst) - 1 + for i := range dst { + dst[i] = src[nElemMinus1-i] + } +} + +// ReverseU16 sets dst[len(src) - 1 - pos] := src[pos] for each position in +// src. It panics if len(src) != len(dst). +func ReverseU16(dst, src []uint16) { + if len(dst) != len(src) { + panic("ReverseU16() requires len(src) == len(dst).") + } + nElemMinus1 := len(dst) - 1 + for i := range dst { + dst[i] = src[nElemMinus1-i] + } +} + +// Benchmark results suggest that Reverse32Raw is unimportant. diff --git a/simd/multibyte_generic.go b/simd/multibyte_generic.go new file mode 100644 index 00000000..0a3c4a6b --- /dev/null +++ b/simd/multibyte_generic.go @@ -0,0 +1,137 @@ +// Copyright 2021 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache-2.0 +// license that can be found in the LICENSE file. + +//go:build !amd64 && !appengine +// +build !amd64,!appengine + +package simd + +import ( + "reflect" + "unsafe" +) + +// This file contains functions which operate on slices of 2- or 4-byte +// elements (typically small structs or integers) in ways that differ from the +// corresponding operations on single-byte elements. +// In this context, there is little point in making the interface based on +// []byte, since the caller will need to unsafely cast to it. Instead, most +// functions take unsafe.Pointer(s) and a count, and have names ending in +// 'Raw'; the caller should write safe wrappers around them when appropriate. +// We provide sample wrappers for the int16 and uint16 cases. (Originally did +// this for int32/uint32, but turns out the compiler has hardcoded +// optimizations for those cases which are currently missing for {u}int16.) + +// Memset16Raw assumes dst points to an array of nElem 2-byte elements, and +// valPtr points to a single 2-byte element. It fills dst with copies of +// *valPtr. 
+func Memset16Raw(dst, valPtr unsafe.Pointer, nElem int) { + val := *((*uint16)(valPtr)) + for idx := 0; idx != nElem; idx++ { + *((*uint16)(dst)) = val + dst = unsafe.Add(dst, 2) + } +} + +// Memset32Raw assumes dst points to an array of nElem 4-byte elements, and +// valPtr points to a single 4-byte element. It fills dst with copies of +// *valPtr. +func Memset32Raw(dst, valPtr unsafe.Pointer, nElem int) { + val := *((*uint32)(valPtr)) + for idx := 0; idx != nElem; idx++ { + *((*uint32)(dst)) = val + dst = unsafe.Add(dst, 4) + } +} + +// RepeatI16 fills dst[] with the given int16. +func RepeatI16(dst []int16, val int16) { + dstHeader := (*reflect.SliceHeader)(unsafe.Pointer(&dst)) + Memset16Raw(unsafe.Pointer(dstHeader.Data), unsafe.Pointer(&val), dstHeader.Len) +} + +// RepeatU16 fills dst[] with the given uint16. +func RepeatU16(dst []uint16, val uint16) { + dstHeader := (*reflect.SliceHeader)(unsafe.Pointer(&dst)) + Memset16Raw(unsafe.Pointer(dstHeader.Data), unsafe.Pointer(&val), dstHeader.Len) +} + +// IndexU16 returns the index of the first instance of val in main, or -1 if +// val is not present in main. +func IndexU16(main []uint16, val uint16) int { + for i, v := range main { + if v == val { + return i + } + } + return -1 +} + +// (Add a function which has the original little-endian byte-slice semantics if +// we ever need it.) + +// Reverse16InplaceRaw assumes main points to an array of ct 2-byte elements, +// and reverses it in-place. 
+func Reverse16InplaceRaw(main unsafe.Pointer, nElem int) { + nElemDiv2 := nElem >> 1 + fwdIter := main + revIter := unsafe.Add(main, (nElem-1)*2) + for idx := 0; idx != nElemDiv2; idx++ { + origLeftVal := *((*uint16)(fwdIter)) + *((*uint16)(fwdIter)) = *((*uint16)(revIter)) + *((*uint16)(revIter)) = origLeftVal + fwdIter = unsafe.Add(fwdIter, 2) + revIter = unsafe.Add(revIter, -2) + } +} + +// Reverse16Raw assumes dst and src both point to arrays of ct 2-byte elements, +// and sets dst[pos] := src[ct - 1 - pos] for each position. +func Reverse16Raw(dst, src unsafe.Pointer, nElem int) { + srcIter := unsafe.Add(src, (nElem-1)*2) + dstIter := dst + for idx := 0; idx != nElem; idx++ { + *((*uint16)(dstIter)) = *((*uint16)(srcIter)) + srcIter = unsafe.Add(srcIter, -2) + dstIter = unsafe.Add(dstIter, 2) + } +} + +// ReverseI16Inplace reverses a []int16 in-place. +func ReverseI16Inplace(main []int16) { + mainHeader := (*reflect.SliceHeader)(unsafe.Pointer(&main)) + Reverse16InplaceRaw(unsafe.Pointer(mainHeader.Data), mainHeader.Len) +} + +// ReverseU16Inplace reverses a []uint16 in-place. +func ReverseU16Inplace(main []uint16) { + mainHeader := (*reflect.SliceHeader)(unsafe.Pointer(&main)) + Reverse16InplaceRaw(unsafe.Pointer(mainHeader.Data), mainHeader.Len) +} + +// ReverseI16 sets dst[len(src) - 1 - pos] := src[pos] for each position in +// src. It panics if len(src) != len(dst). +func ReverseI16(dst, src []int16) { + srcHeader := (*reflect.SliceHeader)(unsafe.Pointer(&src)) + dstHeader := (*reflect.SliceHeader)(unsafe.Pointer(&dst)) + nElem := srcHeader.Len + if nElem != dstHeader.Len { + panic("ReverseI16() requires len(src) == len(dst).") + } + Reverse16Raw(unsafe.Pointer(dstHeader.Data), unsafe.Pointer(srcHeader.Data), nElem) +} + +// ReverseU16 sets dst[len(src) - 1 - pos] := src[pos] for each position in +// src. It panics if len(src) != len(dst). 
+func ReverseU16(dst, src []uint16) { + srcHeader := (*reflect.SliceHeader)(unsafe.Pointer(&src)) + dstHeader := (*reflect.SliceHeader)(unsafe.Pointer(&dst)) + nElem := srcHeader.Len + if nElem != dstHeader.Len { + panic("ReverseU16() requires len(src) == len(dst).") + } + Reverse16Raw(unsafe.Pointer(dstHeader.Data), unsafe.Pointer(srcHeader.Data), nElem) +} + +// Benchmark results suggest that Reverse32Raw is unimportant. diff --git a/simd/multibyte_test.go b/simd/multibyte_test.go index 1a8a31d7..81f604c6 100644 --- a/simd/multibyte_test.go +++ b/simd/multibyte_test.go @@ -2,133 +2,20 @@ // Use of this source code is governed by the Apache-2.0 // license that can be found in the LICENSE file. +// +build !appengine + package simd_test import ( "math/rand" "reflect" - "runtime" "testing" "unsafe" "github.com/grailbio/base/simd" + "github.com/grailbio/testutil/expect" ) -/* -Initial benchmark results: - MacBook Pro (15-inch, 2016) - 2.7 GHz Intel Core i7, 16 GB 2133 MHz LPDDR3 - -Benchmark_Memset16Short1-8 20 79792679 ns/op -Benchmark_Memset16Short4-8 100 21460685 ns/op -Benchmark_Memset16ShortMax-8 100 19242532 ns/op -Benchmark_Memset16Long1-8 1 1209730588 ns/op -Benchmark_Memset16Long4-8 1 1630931319 ns/op -Benchmark_Memset16LongMax-8 1 2098725129 ns/op - -Benchmark_Reverse16Short1-8 20 86850717 ns/op -Benchmark_Reverse16Short4-8 50 26629273 ns/op -Benchmark_Reverse16ShortMax-8 100 21015725 ns/op -Benchmark_Reverse16Long1-8 1 1241551853 ns/op -Benchmark_Reverse16Long4-8 1 1691636166 ns/op -Benchmark_Reverse16LongMax-8 1 2201613448 ns/op - -For comparison, memset16: -Benchmark_Memset16Short1-8 5 254778732 ns/op -Benchmark_Memset16Short4-8 20 68925278 ns/op -Benchmark_Memset16ShortMax-8 20 60629923 ns/op -Benchmark_Memset16Long1-8 1 1261998317 ns/op -Benchmark_Memset16Long4-8 1 1684414682 ns/op -Benchmark_Memset16LongMax-8 1 2203954500 ns/op - -reverseU16Slow: -Benchmark_Reverse16Short1-8 10 180262413 ns/op -Benchmark_Reverse16Short4-8 30 49862651 ns/op 
-Benchmark_Reverse16ShortMax-8 10 114370495 ns/op -Benchmark_Reverse16Long1-8 1 3367505528 ns/op -Benchmark_Reverse16Long4-8 1 1707333366 ns/op -Benchmark_Reverse16LongMax-8 1 2175367071 ns/op -*/ - -func memset16Subtask(dst []uint16, nIter int) int { - for iter := 0; iter < nIter; iter++ { - simd.RepeatU16(dst, 0x201) - } - return int(dst[0]) -} - -func memset16SubtaskFuture(dst []uint16, nIter int) chan int { - future := make(chan int) - go func() { future <- memset16Subtask(dst, nIter) }() - return future -} - -func multiMemset16(dsts [][]uint16, cpus int, nJob int) { - sumFutures := make([]chan int, cpus) - shardSizeBase := nJob / cpus - shardRemainder := nJob - shardSizeBase*cpus - shardSizeP1 := shardSizeBase + 1 - var taskIdx int - for ; taskIdx < shardRemainder; taskIdx++ { - sumFutures[taskIdx] = memset16SubtaskFuture(dsts[taskIdx], shardSizeP1) - } - for ; taskIdx < cpus; taskIdx++ { - sumFutures[taskIdx] = memset16SubtaskFuture(dsts[taskIdx], shardSizeBase) - } - var sum int - for taskIdx = 0; taskIdx < cpus; taskIdx++ { - sum += <-sumFutures[taskIdx] - } -} - -func benchmarkMemset16(cpus int, nByte int, nJob int, b *testing.B) { - if cpus > runtime.NumCPU() { - b.Skipf("only have %v cpus", runtime.NumCPU()) - } - - nU16 := (nByte + 1) >> 1 - mainSlices := make([][]uint16, cpus) - for ii := range mainSlices { - // Add 31 to prevent false sharing. - newArr := make([]uint16, nU16, nU16+31) - for jj := 0; jj < nU16; jj++ { - newArr[jj] = uint16(jj * 3) - } - mainSlices[ii] = newArr[:nU16] - } - for i := 0; i < b.N; i++ { - multiMemset16(mainSlices, cpus, nJob) - } -} - -// Base sequence in length-150 .bam read occupies 75 bytes, so 75 is a good -// size for the short-array benchmark. 
-func Benchmark_Memset16Short1(b *testing.B) { - benchmarkMemset16(1, 75, 9999999, b) -} - -func Benchmark_Memset16Short4(b *testing.B) { - benchmarkMemset16(4, 75, 9999999, b) -} - -func Benchmark_Memset16ShortMax(b *testing.B) { - benchmarkMemset16(runtime.NumCPU(), 75, 9999999, b) -} - -// GRCh37 chromosome 1 length is 249250621, so that's a plausible long-array -// use case. -func Benchmark_Memset16Long1(b *testing.B) { - benchmarkMemset16(1, 249250621, 50, b) -} - -func Benchmark_Memset16Long4(b *testing.B) { - benchmarkMemset16(4, 249250621, 50, b) -} - -func Benchmark_Memset16LongMax(b *testing.B) { - benchmarkMemset16(runtime.NumCPU(), 249250621, 50, b) -} - // The compiler clearly recognizes this; performance is almost // indistinguishable from handcoded assembly. func memset32Builtin(dst []uint32, val uint32) { @@ -140,6 +27,7 @@ func memset32Builtin(dst []uint32, val uint32) { func TestMemset32(t *testing.T) { maxSize := 500 nIter := 200 + rand.Seed(1) main1Arr := make([]uint32, maxSize) main2Arr := make([]uint32, maxSize) for iter := 0; iter < nIter; iter++ { @@ -162,7 +50,7 @@ func TestMemset32(t *testing.T) { } } -func memset16(dst []uint16, val uint16) { +func memset16Standard(dst []uint16, val uint16) { // This tends to be better than the range-for loop, though it's less // clear-cut than the memset case. 
nDst := len(dst) @@ -177,6 +65,7 @@ func memset16(dst []uint16, val uint16) { func TestMemset16(t *testing.T) { maxSize := 500 nIter := 200 + rand.Seed(1) main1Arr := make([]uint16, maxSize) main2Arr := make([]uint16, maxSize) for iter := 0; iter < nIter; iter++ { @@ -187,7 +76,7 @@ func TestMemset16(t *testing.T) { main2Slice := main2Arr[sliceStart:sliceEnd] sentinel := uint16(rand.Uint32()) main2Arr[sliceEnd] = sentinel - memset16(main1Slice, u16Val) + memset16Standard(main1Slice, u16Val) simd.RepeatU16(main2Slice, u16Val) if !reflect.DeepEqual(main1Slice, main2Slice) { t.Fatal("Mismatched RepeatU16 result.") @@ -198,79 +87,150 @@ func TestMemset16(t *testing.T) { } } -func reverse16Subtask(main []uint16, nIter int) int { - for iter := 0; iter < nIter; iter++ { - simd.ReverseU16Inplace(main) - } - return int(main[0]) -} +/* +Benchmark results: + MacBook Pro (15-inch, 2016) + 2.7 GHz Intel Core i7, 16 GB 2133 MHz LPDDR3 -func reverse16SubtaskFuture(main []uint16, nIter int) chan int { - future := make(chan int) - go func() { future <- reverse16Subtask(main, nIter) }() - return future +Benchmark_Memset16/SIMDShort1Cpu-8 10 140130606 ns/op +Benchmark_Memset16/SIMDShortHalfCpu-8 50 37087600 ns/op +Benchmark_Memset16/SIMDShortAllCpu-8 50 35361817 ns/op +Benchmark_Memset16/SIMDLong1Cpu-8 1 1157494604 ns/op +Benchmark_Memset16/SIMDLongHalfCpu-8 2 921843584 ns/op +Benchmark_Memset16/SIMDLongAllCpu-8 2 960652822 ns/op +Benchmark_Memset16/StandardShort1Cpu-8 5 343877390 ns/op +Benchmark_Memset16/StandardShortHalfCpu-8 20 88295789 ns/op +Benchmark_Memset16/StandardShortAllCpu-8 20 86026817 ns/op +Benchmark_Memset16/StandardLong1Cpu-8 1 1038072481 ns/op +Benchmark_Memset16/StandardLongHalfCpu-8 2 979292703 ns/op +Benchmark_Memset16/StandardLongAllCpu-8 1 1052316741 ns/op +*/ + +type u16Args struct { + main []uint16 } -func multiReverse16(mains [][]uint16, cpus int, nJob int) { - sumFutures := make([]chan int, cpus) - shardSizeBase := nJob / cpus - shardRemainder := nJob - 
shardSizeBase*cpus - shardSizeP1 := shardSizeBase + 1 - var taskIdx int - for ; taskIdx < shardRemainder; taskIdx++ { - sumFutures[taskIdx] = reverse16SubtaskFuture(mains[taskIdx], shardSizeP1) - } - for ; taskIdx < cpus; taskIdx++ { - sumFutures[taskIdx] = reverse16SubtaskFuture(mains[taskIdx], shardSizeBase) - } - var sum int - for taskIdx = 0; taskIdx < cpus; taskIdx++ { - sum += <-sumFutures[taskIdx] +func memset16SimdSubtask(args interface{}, nIter int) int { + a := args.(u16Args) + for iter := 0; iter < nIter; iter++ { + simd.RepeatU16(a.main, 0x201) } + return int(a.main[0]) } -func benchmarkReverse16(cpus int, nByte int, nJob int, b *testing.B) { - if cpus > runtime.NumCPU() { - b.Skipf("only have %v cpus", runtime.NumCPU()) +func memset16StandardSubtask(args interface{}, nIter int) int { + a := args.(u16Args) + for iter := 0; iter < nIter; iter++ { + memset16Standard(a.main, 0x201) } + return int(a.main[0]) +} - nU16 := (nByte + 1) >> 1 - mainSlices := make([][]uint16, cpus) - for ii := range mainSlices { - // Add 31 to prevent false sharing. 
- newArr := make([]uint16, nU16, nU16+31) - for jj := 0; jj < nU16; jj++ { - newArr[jj] = uint16(jj * 3) - } - mainSlices[ii] = newArr[:nU16] +func Benchmark_Memset16(b *testing.B) { + funcs := []taggedMultiBenchVarargsFunc{ + { + f: memset16SimdSubtask, + tag: "SIMD", + }, + { + f: memset16StandardSubtask, + tag: "Standard", + }, } - for i := 0; i < b.N; i++ { - multiReverse16(mainSlices, cpus, nJob) + for _, f := range funcs { + multiBenchmarkVarargs(f.f, f.tag+"Short", 9999999, func() interface{} { + return u16Args{ + main: make([]uint16, 75, 75+31), + } + }, b) + multiBenchmarkVarargs(f.f, f.tag+"Long", 50, func() interface{} { + return u16Args{ + main: make([]uint16, 249250622/2, 249250622/2+31), + } + }, b) } } -func Benchmark_Reverse16Short1(b *testing.B) { - benchmarkReverse16(1, 75, 9999999, b) +func indexU16Standard(main []uint16, val uint16) int { + for i, v := range main { + if v == val { + return i + } + } + return -1 } -func Benchmark_Reverse16Short4(b *testing.B) { - benchmarkReverse16(4, 75, 9999999, b) +func TestIndexU16(t *testing.T) { + // Generate nOuterIter random length-arrLen []uint16s, and perform nInnerIter + // random searches on each slice. 
+ arrLen := 50000 + nOuterIter := 5 + nInnerIter := 100 + valLimit := 65536 // maximum uint16 is 65535 + rand.Seed(1) + mainArr := make([]uint16, arrLen) + for outerIdx := 0; outerIdx < nOuterIter; outerIdx++ { + for i := range mainArr { + mainArr[i] = uint16(rand.Intn(valLimit)) + } + for innerIdx := 0; innerIdx < nInnerIter; innerIdx++ { + needle := uint16(rand.Intn(valLimit)) + expected := indexU16Standard(mainArr, needle) + actual := simd.IndexU16(mainArr, needle) + expect.EQ(t, expected, actual) + } + } } -func Benchmark_Reverse16ShortMax(b *testing.B) { - benchmarkReverse16(runtime.NumCPU(), 75, 9999999, b) -} +const indexU16TestLimit = 100 -func Benchmark_Reverse16Long1(b *testing.B) { - benchmarkReverse16(1, 249250621, 50, b) +func indexU16SimdSubtask(args interface{}, nIter int) int { + a := args.(u16Args) + sum := 0 + needle := uint16(0) + for iter := 0; iter < nIter; iter++ { + sum += simd.IndexU16(a.main, needle) + needle++ + if needle == indexU16TestLimit { + needle = 0 + } + } + return sum } -func Benchmark_Reverse16Long4(b *testing.B) { - benchmarkReverse16(4, 249250621, 50, b) +func indexU16StandardSubtask(args interface{}, nIter int) int { + a := args.(u16Args) + sum := 0 + needle := uint16(0) + for iter := 0; iter < nIter; iter++ { + sum += indexU16Standard(a.main, needle) + needle++ + if needle == indexU16TestLimit { + needle = 0 + } + } + return sum } -func Benchmark_Reverse16LongMax(b *testing.B) { - benchmarkReverse16(runtime.NumCPU(), 249250621, 50, b) +// Single-threaded performance is ~4x as good in my testing. 
+func Benchmark_IndexU16(b *testing.B) { + funcs := []taggedMultiBenchVarargsFunc{ + { + f: indexU16SimdSubtask, + tag: "SIMD", + }, + { + f: indexU16StandardSubtask, + tag: "Standard", + }, + } + for _, f := range funcs { + multiBenchmarkVarargs(f.f, f.tag+"Long", 50, func() interface{} { + return u16Args{ + main: make([]uint16, 4000000, 4000000+31), + } + }, b) + } } func reverseU16Slow(main []uint16) { @@ -284,6 +244,7 @@ func reverseU16Slow(main []uint16) { func TestReverse16(t *testing.T) { maxSize := 500 nIter := 200 + rand.Seed(1) main1Arr := make([]uint16, maxSize) main2Arr := make([]uint16, maxSize) main3Arr := make([]uint16, maxSize) @@ -324,3 +285,63 @@ func TestReverse16(t *testing.T) { } } } + +/* +Benchmark results: + MacBook Pro (15-inch, 2016) + 2.7 GHz Intel Core i7, 16 GB 2133 MHz LPDDR3 + +Benchmark_ReverseU16Inplace/SIMDShort1Cpu-8 20 102899505 ns/op +Benchmark_ReverseU16Inplace/SIMDShortHalfCpu-8 50 32918441 ns/op +Benchmark_ReverseU16Inplace/SIMDShortAllCpu-8 30 38848510 ns/op +Benchmark_ReverseU16Inplace/SIMDLong1Cpu-8 1 1116384992 ns/op +Benchmark_ReverseU16Inplace/SIMDLongHalfCpu-8 2 880730467 ns/op +Benchmark_ReverseU16Inplace/SIMDLongAllCpu-8 2 943204867 ns/op +Benchmark_ReverseU16Inplace/SlowShort1Cpu-8 3 443056373 ns/op +Benchmark_ReverseU16Inplace/SlowShortHalfCpu-8 10 117142962 ns/op +Benchmark_ReverseU16Inplace/SlowShortAllCpu-8 10 159087579 ns/op +Benchmark_ReverseU16Inplace/SlowLong1Cpu-8 1 3158497662 ns/op +Benchmark_ReverseU16Inplace/SlowLongHalfCpu-8 2 967619258 ns/op +Benchmark_ReverseU16Inplace/SlowLongAllCpu-8 2 978231337 ns/op +*/ + +func reverseU16InplaceSimdSubtask(args interface{}, nIter int) int { + a := args.(u16Args) + for iter := 0; iter < nIter; iter++ { + simd.ReverseU16Inplace(a.main) + } + return int(a.main[0]) +} + +func reverseU16InplaceSlowSubtask(args interface{}, nIter int) int { + a := args.(u16Args) + for iter := 0; iter < nIter; iter++ { + reverseU16Slow(a.main) + } + return int(a.main[0]) +} + +func 
Benchmark_ReverseU16Inplace(b *testing.B) { + funcs := []taggedMultiBenchVarargsFunc{ + { + f: reverseU16InplaceSimdSubtask, + tag: "SIMD", + }, + { + f: reverseU16InplaceSlowSubtask, + tag: "Slow", + }, + } + for _, f := range funcs { + multiBenchmarkVarargs(f.f, f.tag+"Short", 9999999, func() interface{} { + return u16Args{ + main: make([]uint16, 75, 75+31), + } + }, b) + multiBenchmarkVarargs(f.f, f.tag+"Long", 50, func() interface{} { + return u16Args{ + main: make([]uint16, 249250622/2, 249250622/2+31), + } + }, b) + } +} diff --git a/simd/or_amd64.go b/simd/or_amd64.go index 9ef67204..42c6bc3d 100644 --- a/simd/or_amd64.go +++ b/simd/or_amd64.go @@ -1,8 +1,10 @@ -// Code generated from " ../gtl/generate.py --prefix=Or -DOPCHAR=| --package=simd --output=or_amd64.go bitwise_amd64.go.tpl ". DO NOT EDIT. -// Copyright 2018 GRAIL, Inc. All rights reserved. +// Code generated by "../gtl/generate.py --prefix=Or -DOPCHAR=| --package=simd --output=or_amd64.go bitwise_amd64.go.tpl". DO NOT EDIT. + +// Copyright 2021 GRAIL, Inc. All rights reserved. // Use of this source code is governed by the Apache-2.0 // license that can be found in the LICENSE file. +//go:build amd64 && !appengine // +build amd64,!appengine package simd @@ -12,7 +14,7 @@ import ( "unsafe" ) -// OrUnsafeInplace sets main[pos] := arg[pos] | main[pos] for every position +// OrUnsafeInplace sets main[pos] := main[pos] | arg[pos] for every position // in main[]. // // WARNING: This is a function designed to be used in inner loops, which makes @@ -30,18 +32,18 @@ import ( // changed. 
func OrUnsafeInplace(main, arg []byte) { mainLen := len(main) - argHeader := (*reflect.SliceHeader)(unsafe.Pointer(&arg)) - mainHeader := (*reflect.SliceHeader)(unsafe.Pointer(&main)) - argWordsIter := unsafe.Pointer(argHeader.Data) - mainWordsIter := unsafe.Pointer(mainHeader.Data) + argData := unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&arg)).Data) + mainData := unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&main)).Data) + argWordsIter := argData + mainWordsIter := mainData if mainLen > 2*BytesPerWord { nWordMinus2 := (mainLen - BytesPerWord - 1) >> Log2BytesPerWord for widx := 0; widx < nWordMinus2; widx++ { mainWord := *((*uintptr)(mainWordsIter)) argWord := *((*uintptr)(argWordsIter)) *((*uintptr)(mainWordsIter)) = mainWord | argWord - mainWordsIter = unsafe.Pointer(uintptr(mainWordsIter) + BytesPerWord) - argWordsIter = unsafe.Pointer(uintptr(argWordsIter) + BytesPerWord) + mainWordsIter = unsafe.Add(mainWordsIter, BytesPerWord) + argWordsIter = unsafe.Add(argWordsIter, BytesPerWord) } } else if mainLen <= BytesPerWord { mainWord := *((*uintptr)(mainWordsIter)) @@ -56,8 +58,8 @@ func OrUnsafeInplace(main, arg []byte) { mainWord1 := *((*uintptr)(mainWordsIter)) argWord1 := *((*uintptr)(argWordsIter)) finalOffset := uintptr(mainLen - BytesPerWord) - mainFinalWordPtr := unsafe.Pointer(mainHeader.Data + finalOffset) - argFinalWordPtr := unsafe.Pointer(argHeader.Data + finalOffset) + mainFinalWordPtr := unsafe.Add(mainData, finalOffset) + argFinalWordPtr := unsafe.Add(argData, finalOffset) mainWord2 := *((*uintptr)(mainFinalWordPtr)) argWord2 := *((*uintptr)(argFinalWordPtr)) *((*uintptr)(mainWordsIter)) = mainWord1 | argWord1 @@ -82,25 +84,25 @@ func OrInplace(main, arg []byte) { } return } - argHeader := (*reflect.SliceHeader)(unsafe.Pointer(&arg)) - mainHeader := (*reflect.SliceHeader)(unsafe.Pointer(&main)) - argWordsIter := unsafe.Pointer(argHeader.Data) - mainWordsIter := unsafe.Pointer(mainHeader.Data) + argData := 
unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&arg)).Data) + mainData := unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&main)).Data) + argWordsIter := argData + mainWordsIter := mainData if mainLen > 2*BytesPerWord { nWordMinus2 := (mainLen - BytesPerWord - 1) >> Log2BytesPerWord for widx := 0; widx < nWordMinus2; widx++ { mainWord := *((*uintptr)(mainWordsIter)) argWord := *((*uintptr)(argWordsIter)) *((*uintptr)(mainWordsIter)) = mainWord | argWord - mainWordsIter = unsafe.Pointer(uintptr(mainWordsIter) + BytesPerWord) - argWordsIter = unsafe.Pointer(uintptr(argWordsIter) + BytesPerWord) + mainWordsIter = unsafe.Add(mainWordsIter, BytesPerWord) + argWordsIter = unsafe.Add(argWordsIter, BytesPerWord) } } mainWord1 := *((*uintptr)(mainWordsIter)) argWord1 := *((*uintptr)(argWordsIter)) finalOffset := uintptr(mainLen - BytesPerWord) - mainFinalWordPtr := unsafe.Pointer(mainHeader.Data + finalOffset) - argFinalWordPtr := unsafe.Pointer(argHeader.Data + finalOffset) + mainFinalWordPtr := unsafe.Add(mainData, finalOffset) + argFinalWordPtr := unsafe.Add(argData, finalOffset) mainWord2 := *((*uintptr)(mainFinalWordPtr)) argWord2 := *((*uintptr)(argFinalWordPtr)) *((*uintptr)(mainWordsIter)) = mainWord1 | argWord1 @@ -135,9 +137,9 @@ func OrUnsafe(dst, src1, src2 []byte) { src1Word := *((*uintptr)(src1Iter)) src2Word := *((*uintptr)(src2Iter)) *((*uintptr)(dstIter)) = src1Word | src2Word - src1Iter = unsafe.Pointer(uintptr(src1Iter) + BytesPerWord) - src2Iter = unsafe.Pointer(uintptr(src2Iter) + BytesPerWord) - dstIter = unsafe.Pointer(uintptr(dstIter) + BytesPerWord) + src1Iter = unsafe.Add(src1Iter, BytesPerWord) + src2Iter = unsafe.Add(src2Iter, BytesPerWord) + dstIter = unsafe.Add(dstIter, BytesPerWord) } } @@ -154,27 +156,27 @@ func Or(dst, src1, src2 []byte) { } return } - src1Header := (*reflect.SliceHeader)(unsafe.Pointer(&src1)) - src2Header := (*reflect.SliceHeader)(unsafe.Pointer(&src2)) - dstHeader := (*reflect.SliceHeader)(unsafe.Pointer(&dst)) 
+ src1Data := unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&src1)).Data) + src2Data := unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&src2)).Data) + dstData := unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&dst)).Data) nWordMinus1 := (dstLen - 1) >> Log2BytesPerWord - src1Iter := unsafe.Pointer(src1Header.Data) - src2Iter := unsafe.Pointer(src2Header.Data) - dstIter := unsafe.Pointer(dstHeader.Data) + src1Iter := src1Data + src2Iter := src2Data + dstIter := dstData for widx := 0; widx < nWordMinus1; widx++ { src1Word := *((*uintptr)(src1Iter)) src2Word := *((*uintptr)(src2Iter)) *((*uintptr)(dstIter)) = src1Word | src2Word - src1Iter = unsafe.Pointer(uintptr(src1Iter) + BytesPerWord) - src2Iter = unsafe.Pointer(uintptr(src2Iter) + BytesPerWord) - dstIter = unsafe.Pointer(uintptr(dstIter) + BytesPerWord) + src1Iter = unsafe.Add(src1Iter, BytesPerWord) + src2Iter = unsafe.Add(src2Iter, BytesPerWord) + dstIter = unsafe.Add(dstIter, BytesPerWord) } // No store-forwarding problem here. 
finalOffset := uintptr(dstLen - BytesPerWord) - src1Iter = unsafe.Pointer(src1Header.Data + finalOffset) - src2Iter = unsafe.Pointer(src2Header.Data + finalOffset) - dstIter = unsafe.Pointer(dstHeader.Data + finalOffset) + src1Iter = unsafe.Add(src1Data, finalOffset) + src2Iter = unsafe.Add(src2Data, finalOffset) + dstIter = unsafe.Add(dstData, finalOffset) src1Word := *((*uintptr)(src1Iter)) src2Word := *((*uintptr)(src2Iter)) *((*uintptr)(dstIter)) = src1Word | src2Word @@ -197,14 +199,14 @@ func Or(dst, src1, src2 []byte) { func OrConst8UnsafeInplace(main []byte, val byte) { mainLen := len(main) argWord := 0x101010101010101 * uintptr(val) - mainHeader := (*reflect.SliceHeader)(unsafe.Pointer(&main)) - mainWordsIter := unsafe.Pointer(mainHeader.Data) + mainData := unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&main)).Data) + mainWordsIter := mainData if mainLen > 2*BytesPerWord { nWordMinus2 := (mainLen - BytesPerWord - 1) >> Log2BytesPerWord for widx := 0; widx < nWordMinus2; widx++ { mainWord := *((*uintptr)(mainWordsIter)) *((*uintptr)(mainWordsIter)) = mainWord | argWord - mainWordsIter = unsafe.Pointer(uintptr(mainWordsIter) + BytesPerWord) + mainWordsIter = unsafe.Add(mainWordsIter, BytesPerWord) } } else if mainLen <= BytesPerWord { mainWord := *((*uintptr)(mainWordsIter)) @@ -213,7 +215,7 @@ func OrConst8UnsafeInplace(main []byte, val byte) { } mainWord1 := *((*uintptr)(mainWordsIter)) finalOffset := uintptr(mainLen - BytesPerWord) - mainFinalWordPtr := unsafe.Pointer(mainHeader.Data + finalOffset) + mainFinalWordPtr := unsafe.Add(mainData, finalOffset) mainWord2 := *((*uintptr)(mainFinalWordPtr)) *((*uintptr)(mainWordsIter)) = mainWord1 | argWord *((*uintptr)(mainFinalWordPtr)) = mainWord2 | argWord @@ -230,19 +232,19 @@ func OrConst8Inplace(main []byte, val byte) { return } argWord := 0x101010101010101 * uintptr(val) - mainHeader := (*reflect.SliceHeader)(unsafe.Pointer(&main)) - mainWordsIter := unsafe.Pointer(mainHeader.Data) + mainData := 
unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&main)).Data) + mainWordsIter := mainData if mainLen > 2*BytesPerWord { nWordMinus2 := (mainLen - BytesPerWord - 1) >> Log2BytesPerWord for widx := 0; widx < nWordMinus2; widx++ { mainWord := *((*uintptr)(mainWordsIter)) *((*uintptr)(mainWordsIter)) = mainWord | argWord - mainWordsIter = unsafe.Pointer(uintptr(mainWordsIter) + BytesPerWord) + mainWordsIter = unsafe.Add(mainWordsIter, BytesPerWord) } } mainWord1 := *((*uintptr)(mainWordsIter)) finalOffset := uintptr(mainLen - BytesPerWord) - mainFinalWordPtr := unsafe.Pointer(mainHeader.Data + finalOffset) + mainFinalWordPtr := unsafe.Add(mainData, finalOffset) mainWord2 := *((*uintptr)(mainFinalWordPtr)) *((*uintptr)(mainWordsIter)) = mainWord1 | argWord *((*uintptr)(mainFinalWordPtr)) = mainWord2 | argWord @@ -274,8 +276,8 @@ func OrConst8Unsafe(dst, src []byte, val byte) { for widx := 0; widx < nWord; widx++ { srcWord := *((*uintptr)(srcIter)) *((*uintptr)(dstIter)) = srcWord | argWord - srcIter = unsafe.Pointer(uintptr(srcIter) + BytesPerWord) - dstIter = unsafe.Pointer(uintptr(dstIter) + BytesPerWord) + srcIter = unsafe.Add(srcIter, BytesPerWord) + dstIter = unsafe.Add(dstIter, BytesPerWord) } } @@ -292,22 +294,22 @@ func OrConst8(dst, src []byte, val byte) { } return } - srcHeader := (*reflect.SliceHeader)(unsafe.Pointer(&src)) - dstHeader := (*reflect.SliceHeader)(unsafe.Pointer(&dst)) + srcData := unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&src)).Data) + dstData := unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&dst)).Data) nWordMinus1 := (dstLen - 1) >> Log2BytesPerWord argWord := 0x101010101010101 * uintptr(val) - srcIter := unsafe.Pointer(srcHeader.Data) - dstIter := unsafe.Pointer(dstHeader.Data) + srcIter := unsafe.Pointer(srcData) + dstIter := unsafe.Pointer(dstData) for widx := 0; widx < nWordMinus1; widx++ { srcWord := *((*uintptr)(srcIter)) *((*uintptr)(dstIter)) = srcWord | argWord - srcIter = unsafe.Pointer(uintptr(srcIter) + 
BytesPerWord) - dstIter = unsafe.Pointer(uintptr(dstIter) + BytesPerWord) + srcIter = unsafe.Add(srcIter, BytesPerWord) + dstIter = unsafe.Add(dstIter, BytesPerWord) } finalOffset := uintptr(dstLen - BytesPerWord) - srcIter = unsafe.Pointer(srcHeader.Data + finalOffset) - dstIter = unsafe.Pointer(dstHeader.Data + finalOffset) + srcIter = unsafe.Add(srcData, finalOffset) + dstIter = unsafe.Add(dstData, finalOffset) srcWord := *((*uintptr)(srcIter)) *((*uintptr)(dstIter)) = srcWord | argWord } diff --git a/simd/or_generic.go b/simd/or_generic.go new file mode 100644 index 00000000..dc34c8f6 --- /dev/null +++ b/simd/or_generic.go @@ -0,0 +1,135 @@ +// Code generated by " ../gtl/generate.py --prefix=Or -DOPCHAR=| --package=simd --output=or_generic.go bitwise_generic.go.tpl ". DO NOT EDIT. + +// Copyright 2018 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache-2.0 +// license that can be found in the LICENSE file. + +// +build !amd64 appengine + +package simd + +// OrUnsafeInplace sets main[pos] := main[pos] | arg[pos] for every position +// in main[]. +// +// WARNING: This is a function designed to be used in inner loops, which makes +// assumptions about length and capacity which aren't checked at runtime. Use +// the safe version of this function when that's a problem. +// Assumptions #2-3 are always satisfied when the last +// potentially-size-increasing operation on arg[] is {Re}makeUnsafe(), +// ResizeUnsafe(), or XcapUnsafe(), and the same is true for main[]. +// +// 1. len(arg) and len(main) must be equal. +// +// 2. Capacities are at least RoundUpPow2(len(main) + 1, bytesPerVec). +// +// 3. The caller does not care if a few bytes past the end of main[] are +// changed. +func OrUnsafeInplace(main, arg []byte) { + for i, x := range main { + main[i] = x | arg[i] + } +} + +// OrInplace sets main[pos] := main[pos] | arg[pos] for every position in +// main[]. It panics if slice lengths don't match. 
+func OrInplace(main, arg []byte) { + if len(arg) != len(main) { + panic("OrInplace() requires len(arg) == len(main).") + } + for i, x := range main { + main[i] = x | arg[i] + } +} + +// OrUnsafe sets dst[pos] := src1[pos] | src2[pos] for every position in dst. +// +// WARNING: This is a function designed to be used in inner loops, which makes +// assumptions about length and capacity which aren't checked at runtime. Use +// the safe version of this function when that's a problem. +// Assumptions #2-3 are always satisfied when the last +// potentially-size-increasing operation on src1[] is {Re}makeUnsafe(), +// ResizeUnsafe(), or XcapUnsafe(), and the same is true for src2[] and dst[]. +// +// 1. len(src1), len(src2), and len(dst) must be equal. +// +// 2. Capacities are at least RoundUpPow2(len(dst) + 1, bytesPerVec). +// +// 3. The caller does not care if a few bytes past the end of dst[] are +// changed. +func OrUnsafe(dst, src1, src2 []byte) { + for i, x := range src1 { + dst[i] = x | src2[i] + } +} + +// Or sets dst[pos] := src1[pos] | src2[pos] for every position in dst. It +// panics if slice lengths don't match. +func Or(dst, src1, src2 []byte) { + dstLen := len(dst) + if (len(src1) != dstLen) || (len(src2) != dstLen) { + panic("Or() requires len(src1) == len(src2) == len(dst).") + } + for i, x := range src1 { + dst[i] = x | src2[i] + } +} + +// OrConst8UnsafeInplace sets main[pos] := main[pos] | val for every position +// in main[]. +// +// WARNING: This is a function designed to be used in inner loops, which makes +// assumptions about length and capacity which aren't checked at runtime. Use +// the safe version of this function when that's a problem. +// These assumptions are always satisfied when the last +// potentially-size-increasing operation on main[] is {Re}makeUnsafe(), +// ResizeUnsafe(), or XcapUnsafe(). +// +// 1. cap(main) is at least RoundUpPow2(len(main) + 1, bytesPerVec). +// +// 2. 
The caller does not care if a few bytes past the end of main[] are +// changed. +func OrConst8UnsafeInplace(main []byte, val byte) { + for i, x := range main { + main[i] = x | val + } +} + +// OrConst8Inplace sets main[pos] := main[pos] | val for every position in +// main[]. +func OrConst8Inplace(main []byte, val byte) { + for i, x := range main { + main[i] = x | val + } +} + +// OrConst8Unsafe sets dst[pos] := src[pos] | val for every position in dst. +// +// WARNING: This is a function designed to be used in inner loops, which makes +// assumptions about length and capacity which aren't checked at runtime. Use +// the safe version of this function when that's a problem. +// Assumptions #2-3 are always satisfied when the last +// potentially-size-increasing operation on src[] is {Re}makeUnsafe(), +// ResizeUnsafe(), or XcapUnsafe(), and the same is true for dst[]. +// +// 1. len(src) and len(dst) must be equal. +// +// 2. Capacities are at least RoundUpPow2(len(dst) + 1, bytesPerVec). +// +// 3. The caller does not care if a few bytes past the end of dst[] are +// changed. +func OrConst8Unsafe(dst, src []byte, val byte) { + for i, x := range src { + dst[i] = x | val + } +} + +// OrConst8 sets dst[pos] := src[pos] | val for every position in dst. It +// panics if slice lengths don't match. +func OrConst8(dst, src []byte, val byte) { + if len(src) != len(dst) { + panic("OrConst8() requires len(src) == len(dst).") + } + for i, x := range src { + dst[i] = x | val + } +} diff --git a/simd/simd_amd64.go b/simd/simd_amd64.go index 32b6e6a1..b1f45e4f 100644 --- a/simd/simd_amd64.go +++ b/simd/simd_amd64.go @@ -1,7 +1,8 @@ -// Copyright 2018 GRAIL, Inc. All rights reserved. +// Copyright 2021 GRAIL, Inc. All rights reserved. // Use of this source code is governed by the Apache-2.0 // license that can be found in the LICENSE file. 
+//go:build amd64 && !appengine
 // +build amd64,!appengine
 
 package simd
 
@@ -11,12 +12,15 @@ import (
 	"reflect"
 	"unsafe"
 
-	gunsafe "github.com/grailbio/base/unsafe"
+	"golang.org/x/sys/cpu"
 )
 
 // amd64 compile-time constants.
 
 // BytesPerWord is the number of bytes in a machine word.
+// We don't use unsafe.Sizeof(uintptr(1)) since there are advantages to having
+// this as an untyped constant, and there's essentially no drawback since this
+// is an _amd64-specific file.
 const BytesPerWord = 8
 
 // Log2BytesPerWord is log2(BytesPerWord). This is relevant for manual
@@ -24,14 +28,45 @@ const BytesPerWord = 8
 // not (e.g. dividend is of signed int type).
 const Log2BytesPerWord = uint(3)
 
-// const minPageSize = 4096 may be relevant for safe functions soon.
+// BitsPerWord is the number of bits in a machine word.
+const BitsPerWord = BytesPerWord * 8
+
+// This must be at least (maximum supported vector size) / 16.
+const nibbleLookupDup = 1
+
+// NibbleLookupTable represents a parallel-byte-substitution operation f, where
+// every byte b in a byte-slice is replaced with
+// f(b) := shuffle[0][b & 15] for b <= 127, and
+// f(b) := 0 for b > 127.
+// (The second part is usually irrelevant in practice, but must be defined this
+// way to allow _mm_shuffle_epi8()/_mm256_shuffle_epi8()/_mm512_shuffle_epi8()
+// to be used to implement the operation efficiently.)
+// It's named NibbleLookupTable rather than ByteLookupTable since only the
+// bottom nibble of each byte can be used for table lookup.
+// It potentially stores multiple adjacent copies of the lookup table since
+// that speeds up the AVX2 and AVX-512 use cases (the table can be loaded with
+// a single _mm256_loadu_si256 operation, instead of e.g. _mm_loadu_si128
+// followed by _mm256_set_m128i with the same argument twice), and the typical
+// use case involves initializing very few tables and using them many, many
+// times.
+type NibbleLookupTable struct { + shuffle [nibbleLookupDup][16]byte +} + +// Get performs the b <= 127 part of the lookup operation described above. +// The b > 127 branch is omitted because in many use cases (e.g. +// PackedNibbleLookup below), it can be proven that b > 127 is impossible, and +// removing the if-statement is a significant performance win when it's +// possible. +func (t *NibbleLookupTable) Get(b byte) byte { + return t.shuffle[0][b] +} -// These could be compile-time constants for now, but not after AVX2 -// autodetection is added. +// const minPageSize = 4096 may be relevant for safe functions soon. // bytesPerVec is the size of the maximum-width vector that may be used. It is -// currently always 16, but it will be set to larger values at runtime in the -// future when AVX2/AVX-512/etc. is detected. +// at least 16, but will soon be set to 32 if AVX2 support is detected. It +// may be set to 64 in the future when AVX-512 is detected. var bytesPerVec int // log2BytesPerVec supports efficient division by bytesPerVec. @@ -39,14 +74,6 @@ var log2BytesPerVec uint // *** the following functions are defined in simd_amd64.s -// Strictly speaking, hasSSE42Asm() duplicates code in e.g. -// github.com/klauspost/cpuid , but it's literally only a few bytes. -// Todo: look into replacing this with go:linkname exploitation of the -// runtime's cpuid check results, and empty import of runtime. - -//go:noescape -func hasSSE42Asm() bool - // There was a unpackedNibbleLookupInplaceSSSE3Asm function here, but it // actually benchmarked worse than the general-case function. 
@@ -80,16 +107,24 @@ func reverse8InplaceSSSE3Asm(main unsafe.Pointer, nByte int) //go:noescape func reverse8SSSE3Asm(dst, src unsafe.Pointer, nByte int) +//go:noescape +func bitFromEveryByteSSE2Asm(dst, src unsafe.Pointer, lshift, nDstByte int) + // *** end assembly function signatures func init() { - if !hasSSE42Asm() { + if !cpu.X86.HasSSE42 { panic("SSE4.2 required.") } bytesPerVec = 16 log2BytesPerVec = 4 } +// BytesPerVec is an accessor for the bytesPerVec package variable. +func BytesPerVec() int { + return bytesPerVec +} + // RoundUpPow2 returns val rounded up to a multiple of alignment, assuming // alignment is a power of 2. func RoundUpPow2(val, alignment int) int { @@ -123,7 +158,7 @@ func MakeUnsafe(len int) []byte { func RemakeUnsafe(bufptr *[]byte, len int) { minCap := len + bytesPerVec if minCap <= cap(*bufptr) { - gunsafe.ExtendBytes(bufptr, len) + *bufptr = (*bufptr)[:len] return } // This is likely to be called in an inner loop processing variable-size @@ -138,7 +173,7 @@ func RemakeUnsafe(bufptr *[]byte, len int) { func ResizeUnsafe(bufptr *[]byte, len int) { minCap := len + bytesPerVec if minCap <= cap(*bufptr) { - gunsafe.ExtendBytes(bufptr, len) + *bufptr = (*bufptr)[:len] return } dst := make([]byte, len, RoundUpPow2(minCap+(minCap/8), bytesPerVec)) @@ -149,9 +184,6 @@ func ResizeUnsafe(bufptr *[]byte, len int) { // XcapUnsafe is shorthand for ResizeUnsafe's most common use case (no length // change, just want to ensure sufficient capacity). func XcapUnsafe(bufptr *[]byte) { - // mid-stack inlining isn't yet working as I write this, but it should be - // available soon enough: - // https://github.com/golang/go/issues/19348 ResizeUnsafe(bufptr, len(*bufptr)) } @@ -166,7 +198,7 @@ func XcapUnsafe(bufptr *[]byte) { // past the end of dst[] are changed. Use the safe version of this function if // any of these properties are problematic. 
// These assumptions are always satisfied when the last -// potentially-size-increasing operation on dst[] is {Re}makeUnsafe(), +// potentially-size-increasing operation on dst[] is {Make,Remake}Unsafe(), // ResizeUnsafe(), or XcapUnsafe(). func Memset8Unsafe(dst []byte, val byte) { dstHeader := (*reflect.SliceHeader)(unsafe.Pointer(&dst)) @@ -180,7 +212,7 @@ func Memset8Unsafe(dst []byte, val byte) { dstWordsIter := unsafe.Pointer(dstHeader.Data) for widx := 0; widx < nWord; widx++ { *((*uintptr)(dstWordsIter)) = valWord - dstWordsIter = unsafe.Pointer(uintptr(dstWordsIter) + BytesPerWord) + dstWordsIter = unsafe.Add(dstWordsIter, BytesPerWord) } } @@ -196,18 +228,26 @@ func Memset8(dst []byte, val byte) { } return } - dstHeader := (*reflect.SliceHeader)(unsafe.Pointer(&dst)) + dstData := unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&dst)).Data) valWord := uintptr(0x0101010101010101) * uintptr(val) nWordMinus1 := (dstLen - 1) >> Log2BytesPerWord - dstWordsIter := unsafe.Pointer(dstHeader.Data) + dstWordsIter := dstData for widx := 0; widx < nWordMinus1; widx++ { *((*uintptr)(dstWordsIter)) = valWord - dstWordsIter = unsafe.Pointer(uintptr(dstWordsIter) + BytesPerWord) + dstWordsIter = unsafe.Add(dstWordsIter, BytesPerWord) } - dstWordsIter = unsafe.Pointer(dstHeader.Data + uintptr(dstLen) - BytesPerWord) + dstWordsIter = unsafe.Add(dstData, dstLen-BytesPerWord) *((*uintptr)(dstWordsIter)) = valWord } +// MakeNibbleLookupTable generates a NibbleLookupTable from a [16]byte. +func MakeNibbleLookupTable(table [16]byte) (t NibbleLookupTable) { + for i := range t.shuffle { + t.shuffle[i] = table + } + return +} + // UnpackedNibbleLookupUnsafeInplace replaces the bytes in main[] as follows: // if value < 128, set to table[value & 15] // otherwise, set to 0 @@ -216,14 +256,14 @@ func Memset8(dst []byte, val byte) { // assumptions about capacity which aren't checked at runtime. Use the safe // version of this function when that's a problem. 
// These assumptions are always satisfied when the last -// potentially-size-increasing operation on main[] is {Re}makeUnsafe(), +// potentially-size-increasing operation on main[] is {Make,Remake}Unsafe(), // ResizeUnsafe(), or XcapUnsafe(). // // 1. cap(main) must be at least RoundUpPow2(len(main) + 1, bytesPerVec). // // 2. The caller does not care if a few bytes past the end of main[] are // changed. -func UnpackedNibbleLookupUnsafeInplace(main []byte, tablePtr *[16]byte) { +func UnpackedNibbleLookupUnsafeInplace(main []byte, tablePtr *NibbleLookupTable) { mainLen := len(main) mainHeader := (*reflect.SliceHeader)(unsafe.Pointer(&main)) if mainLen <= 16 { @@ -239,7 +279,7 @@ func UnpackedNibbleLookupUnsafeInplace(main []byte, tablePtr *[16]byte) { // UnpackedNibbleLookupInplace replaces the bytes in main[] as follows: // if value < 128, set to table[value & 15] // otherwise, set to 0 -func UnpackedNibbleLookupInplace(main []byte, tablePtr *[16]byte) { +func UnpackedNibbleLookupInplace(main []byte, tablePtr *NibbleLookupTable) { // May want to define variants of these functions which have undefined // results for input values in [16, 128); this will be useful for // cross-platform ARM/x86 code. @@ -250,7 +290,7 @@ func UnpackedNibbleLookupInplace(main []byte, tablePtr *[16]byte) { // justifications for exporting Unsafe functions at all.) for pos, curByte := range main { if curByte < 128 { - curByte = tablePtr[curByte&15] + curByte = tablePtr.Get(curByte & 15) } else { curByte = 0 } @@ -279,7 +319,7 @@ func UnpackedNibbleLookupInplace(main []byte, tablePtr *[16]byte) { // // 3. The caller does not care if a few bytes past the end of dst[] are // changed. 
-func UnpackedNibbleLookupUnsafe(dst, src []byte, tablePtr *[16]byte) { +func UnpackedNibbleLookupUnsafe(dst, src []byte, tablePtr *NibbleLookupTable) { srcHeader := (*reflect.SliceHeader)(unsafe.Pointer(&src)) dstHeader := (*reflect.SliceHeader)(unsafe.Pointer(&dst)) unpackedNibbleLookupSSSE3Asm(unsafe.Pointer(dstHeader.Data), unsafe.Pointer(srcHeader.Data), unsafe.Pointer(tablePtr), srcHeader.Len) @@ -289,7 +329,7 @@ func UnpackedNibbleLookupUnsafe(dst, src []byte, tablePtr *[16]byte) { // if src[pos] < 128, set dst[pos] := table[src[pos] & 15] // otherwise, set dst[pos] := 0 // It panics if len(src) != len(dst). -func UnpackedNibbleLookup(dst, src []byte, tablePtr *[16]byte) { +func UnpackedNibbleLookup(dst, src []byte, tablePtr *NibbleLookupTable) { srcLen := len(src) if len(dst) != srcLen { panic("UnpackedNibbleLookup() requires len(src) == len(dst).") @@ -297,7 +337,7 @@ func UnpackedNibbleLookup(dst, src []byte, tablePtr *[16]byte) { if srcLen < 16 { for pos, curByte := range src { if curByte < 128 { - curByte = tablePtr[curByte&15] + curByte = tablePtr.Get(curByte & 15) } else { curByte = 0 } @@ -310,6 +350,30 @@ func UnpackedNibbleLookup(dst, src []byte, tablePtr *[16]byte) { unpackedNibbleLookupOddSSSE3Asm(unsafe.Pointer(dstHeader.Data), unsafe.Pointer(srcHeader.Data), unsafe.Pointer(tablePtr), srcLen) } +// UnpackedNibbleLookupS is a variant of UnpackedNibbleLookup() that takes +// string src. 
+func UnpackedNibbleLookupS(dst []byte, src string, tablePtr *NibbleLookupTable) { + srcLen := len(src) + if len(dst) != srcLen { + panic("UnpackedNibbleLookupS() requires len(src) == len(dst).") + } + if srcLen < 16 { + for pos := range src { + curByte := src[pos] + if curByte < 128 { + curByte = tablePtr.Get(curByte & 15) + } else { + curByte = 0 + } + dst[pos] = curByte + } + return + } + srcHeader := (*reflect.StringHeader)(unsafe.Pointer(&src)) + dstHeader := (*reflect.SliceHeader)(unsafe.Pointer(&dst)) + unpackedNibbleLookupOddSSSE3Asm(unsafe.Pointer(dstHeader.Data), unsafe.Pointer(srcHeader.Data), unsafe.Pointer(tablePtr), srcLen) +} + // PackedNibbleLookupUnsafe sets the bytes in dst[] as follows: // if pos is even, dst[pos] := table[src[pos / 2] & 15] // if pos is odd, dst[pos] := table[src[pos / 2] >> 4] @@ -328,7 +392,7 @@ func UnpackedNibbleLookup(dst, src []byte, tablePtr *[16]byte) { // // 3. The caller does not care if a few bytes past the end of dst[] are // changed. -func PackedNibbleLookupUnsafe(dst, src []byte, tablePtr *[16]byte) { +func PackedNibbleLookupUnsafe(dst, src []byte, tablePtr *NibbleLookupTable) { // Note that this is not the correct order for .bam seq[] unpacking; use // biosimd.UnpackAndReplaceSeqUnsafe() for that. srcHeader := (*reflect.SliceHeader)(unsafe.Pointer(&src)) @@ -344,7 +408,7 @@ func PackedNibbleLookupUnsafe(dst, src []byte, tablePtr *[16]byte) { // Nothing bad happens if len(dst) is odd and some high bits in the last src[] // byte are set, though it's generally good practice to ensure that case // doesn't come up. -func PackedNibbleLookup(dst, src []byte, tablePtr *[16]byte) { +func PackedNibbleLookup(dst, src []byte, tablePtr *NibbleLookupTable) { // This takes ~15% longer than the unsafe function on the short-array // benchmark. 
dstLen := len(dst) @@ -356,8 +420,8 @@ func PackedNibbleLookup(dst, src []byte, tablePtr *[16]byte) { if nSrcFullByte < 16 { for srcPos := 0; srcPos < nSrcFullByte; srcPos++ { srcByte := src[srcPos] - dst[2*srcPos] = tablePtr[srcByte&15] - dst[2*srcPos+1] = tablePtr[srcByte>>4] + dst[2*srcPos] = tablePtr.Get(srcByte & 15) + dst[2*srcPos+1] = tablePtr.Get(srcByte >> 4) } } else { srcHeader := (*reflect.SliceHeader)(unsafe.Pointer(&src)) @@ -366,7 +430,7 @@ func PackedNibbleLookup(dst, src []byte, tablePtr *[16]byte) { } if srcOdd == 1 { srcByte := src[nSrcFullByte] - dst[2*nSrcFullByte] = tablePtr[srcByte&15] + dst[2*nSrcFullByte] = tablePtr.Get(srcByte & 15) } } @@ -448,27 +512,27 @@ func Reverse8Unsafe(dst, src []byte) { } return } - srcHeader := (*reflect.SliceHeader)(unsafe.Pointer(&src)) - dstHeader := (*reflect.SliceHeader)(unsafe.Pointer(&dst)) + srcData := unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&src)).Data) + dstData := unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&dst)).Data) if nByte < 16 { // use bswap64 on a word at a time nWordMinus1 := (nByte - 1) >> Log2BytesPerWord finalOffset := uintptr(nByte) - BytesPerWord - srcIter := unsafe.Pointer(srcHeader.Data + finalOffset) - dstIter := unsafe.Pointer(dstHeader.Data) + srcIter := unsafe.Add(srcData, finalOffset) + dstIter := dstData for widx := 0; widx < nWordMinus1; widx++ { srcWord := *((*uintptr)(srcIter)) *((*uintptr)(dstIter)) = uintptr(bits.ReverseBytes64(uint64(srcWord))) - srcIter = unsafe.Pointer(uintptr(srcIter) - BytesPerWord) - dstIter = unsafe.Pointer(uintptr(dstIter) - BytesPerWord) + srcIter = unsafe.Add(srcIter, -BytesPerWord) + dstIter = unsafe.Add(dstIter, -BytesPerWord) } - srcFirstWordPtr := unsafe.Pointer(srcHeader.Data) - dstLastWordPtr := unsafe.Pointer(dstHeader.Data + finalOffset) + srcFirstWordPtr := unsafe.Pointer(srcData) + dstLastWordPtr := unsafe.Add(dstData, finalOffset) srcWord := *((*uintptr)(srcFirstWordPtr)) *((*uintptr)(dstLastWordPtr)) = 
uintptr(bits.ReverseBytes64(uint64(srcWord))) return } - reverse8SSSE3Asm(unsafe.Pointer(dstHeader.Data), unsafe.Pointer(srcHeader.Data), nByte) + reverse8SSSE3Asm(dstData, srcData, nByte) } // Reverse8 sets dst[pos] := src[len(src) - 1 - pos] for every position in src. @@ -486,25 +550,61 @@ func Reverse8(dst, src []byte) { } return } - srcHeader := (*reflect.SliceHeader)(unsafe.Pointer(&src)) - dstHeader := (*reflect.SliceHeader)(unsafe.Pointer(&dst)) + srcData := unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&src)).Data) + dstData := unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&dst)).Data) if nByte < 16 { // use bswap64 on a word at a time nWordMinus1 := (nByte - 1) >> Log2BytesPerWord finalOffset := uintptr(nByte) - BytesPerWord - srcIter := unsafe.Pointer(srcHeader.Data + finalOffset) - dstIter := unsafe.Pointer(dstHeader.Data) + srcIter := unsafe.Add(srcData, finalOffset) + dstIter := dstData for widx := 0; widx < nWordMinus1; widx++ { srcWord := *((*uintptr)(srcIter)) *((*uintptr)(dstIter)) = uintptr(bits.ReverseBytes64(uint64(srcWord))) - srcIter = unsafe.Pointer(uintptr(srcIter) - BytesPerWord) - dstIter = unsafe.Pointer(uintptr(dstIter) - BytesPerWord) + srcIter = unsafe.Add(srcIter, -BytesPerWord) + dstIter = unsafe.Add(dstIter, -BytesPerWord) } - srcFirstWordPtr := unsafe.Pointer(srcHeader.Data) - dstLastWordPtr := unsafe.Pointer(dstHeader.Data + finalOffset) + srcFirstWordPtr := srcData + dstLastWordPtr := unsafe.Add(dstData, finalOffset) srcWord := *((*uintptr)(srcFirstWordPtr)) *((*uintptr)(dstLastWordPtr)) = uintptr(bits.ReverseBytes64(uint64(srcWord))) return } - reverse8SSSE3Asm(unsafe.Pointer(dstHeader.Data), unsafe.Pointer(srcHeader.Data), nByte) + reverse8SSSE3Asm(dstData, srcData, nByte) +} + +// BitFromEveryByte fills dst[] with a bitarray containing every 8th bit from +// src[], starting with bitIdx, where bitIdx is in [0,7]. If len(src) is not +// divisible by 8, extra bits in the last filled byte of dst are set to zero. 
+// +// For example, if src[] is +// 0x1f 0x33 0x0d 0x00 0x51 0xcc 0x34 0x59 0x44 +// and bitIdx is 2, bit 2 from every byte is +// 1 0 1 0 0 1 1 0 1 +// so dst[] is filled with +// 0x65 0x01. +// +// - It panics if len(dst) < (len(src) + 7) / 8, or bitIdx isn't in [0,7]. +// - If dst is larger than necessary, the extra bytes are not changed. +func BitFromEveryByte(dst, src []byte, bitIdx int) { + requiredDstLen := (len(src) + 7) >> 3 + if (len(dst) < requiredDstLen) || (uint(bitIdx) > 7) { + panic("BitFromEveryByte requires len(dst) >= (len(src) + 7) / 8 and 0 <= bitIdx < 8.") + } + nSrcVecByte := len(src) &^ (bytesPerVec - 1) + if nSrcVecByte != 0 { + bitFromEveryByteSSE2Asm(unsafe.Pointer(&dst[0]), unsafe.Pointer(&src[0]), 7-bitIdx, nSrcVecByte>>3) + } + remainder := len(src) - nSrcVecByte + if remainder != 0 { + // Not optimized since it isn't expected to matter. + srcLast := src[nSrcVecByte:] + dstLast := dst[nSrcVecByte>>3 : requiredDstLen] + for i := range dstLast { + dstLast[i] = 0 + } + for i, b := range srcLast { + dstLast[i>>3] |= ((b >> uint32(bitIdx)) & 1) << uint32(i&7) + } + } } diff --git a/simd/simd_amd64.s b/simd/simd_amd64.s index 837d8447..61ba8774 100644 --- a/simd/simd_amd64.s +++ b/simd/simd_amd64.s @@ -6,23 +6,13 @@ DATA ·Mask0f0f<>+0x00(SB)/8, $0x0f0f0f0f0f0f0f0f DATA ·Mask0f0f<>+0x08(SB)/8, $0x0f0f0f0f0f0f0f0f - GLOBL ·Mask0f0f<>(SB), 24, $16 // NOPTR = 16, RODATA = 8 + GLOBL ·Mask0f0f<>(SB), 24, $16 + DATA ·Reverse8<>+0x00(SB)/8, $0x08090a0b0c0d0e0f DATA ·Reverse8<>+0x08(SB)/8, $0x0001020304050607 GLOBL ·Reverse8<>(SB), 24, $16 -// This was forked from github.com/willf/bitset . -// Some form of AVX2/AVX-512 detection will probably be added later. -TEXT ·hasSSE42Asm(SB),4,$0-1 - MOVQ $1, AX - CPUID - // CPUID function explicitly fills CX register. - SHRQ $23, CX - ANDQ $1, CX - MOVB CX, ret+0(FP) - RET - TEXT ·unpackedNibbleLookupTinyInplaceSSSE3Asm(SB),4,$0-16 // DI = pointer to current main[] element. 
MOVQ main+0(FP), DI @@ -400,3 +390,59 @@ reverse8SSSE3Loop: PSHUFB X0, X1 MOVOU X1, (R9) RET + +TEXT ·bitFromEveryByteSSE2Asm(SB),4,$0-32 + // bitFromEveryByteSSE2Asm grabs a single bit from every src[] byte, + // and packs them into dst[]. + // The implementation is based on the _mm_movemask_epi8() instruction, + // which grabs the *high* bit from each byte, so this function takes a + // 'lshift' argument instead of the wrapper's bitIdx. + + // Register allocation: + // AX: pointer to start of dst + // BX: pointer to start of src + // CX: nDstByte (must be even), minus 2 to support 2x unroll + // (rule of thumb: if the loop is less than ~10 operations, + // unrolling is likely to make a noticeable difference with + // minimal effort; otherwise don't bother) + // DX: loop counter + // SI, DI: intermediate movemask results + // + // X0: lshift + MOVQ dst+0(FP), AX + MOVQ src+8(FP), BX + MOVQ lshift+16(FP), X0 + + MOVQ nDstByte+24(FP), CX + SUBQ $2, CX + // Compilers emit this instead of XORQ DX,DX since it's smaller and has + // the same effect. + XORL DX, DX + + CMPQ CX, DX + JLE bitFromEveryByteSSE2AsmOdd + +bitFromEveryByteSSE2AsmLoop: + MOVOU (BX)(DX*8), X1 + MOVOU 16(BX)(DX*8), X2 + PSLLQ X0, X1 + PSLLQ X0, X2 + PMOVMSKB X1, SI + PMOVMSKB X2, DI + MOVW SI, (AX)(DX*1) + MOVW DI, 2(AX)(DX*1) + ADDQ $4, DX + CMPQ CX, DX + JG bitFromEveryByteSSE2AsmLoop + + JL bitFromEveryByteSSE2AsmFinish + + // Move this label up one line if we ever need to accept nDstByte == 0. +bitFromEveryByteSSE2AsmOdd: + MOVOU (BX)(DX*8), X1 + PSLLQ X0, X1 + PMOVMSKB X1, SI + MOVW SI, (AX)(DX*1) + +bitFromEveryByteSSE2AsmFinish: + RET diff --git a/simd/simd_generic.go b/simd/simd_generic.go new file mode 100644 index 00000000..0d04b73d --- /dev/null +++ b/simd/simd_generic.go @@ -0,0 +1,446 @@ +// Copyright 2018 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache-2.0 +// license that can be found in the LICENSE file. 
+ +//go:build !amd64 || appengine + +package simd + +// amd64 compile-time constants. + +// BytesPerWord is the number of bytes in a machine word. +// We don't use unsafe.Sizeof(uintptr(1)) since there are advantages to having +// this as an untyped constant, and there's essentially no drawback since this +// is an _amd64-specific file. +const BytesPerWord = 8 + +// Log2BytesPerWord is log2(BytesPerWord). This is relevant for manual +// bit-shifting when we know that's a safe way to divide and the compiler does +// not (e.g. dividend is of signed int type). +const Log2BytesPerWord = uint(3) + +// BitsPerWord is the number of bits in a machine word. +const BitsPerWord = BytesPerWord * 8 + +// This must be at least / 16. +const nibbleLookupDup = 1 + +// NibbleLookupTable represents a parallel-byte-substitution operation f, where +// every byte b in a byte-slice is replaced with +// f(b) := shuffle[0][b & 15] for b <= 127, and +// f(b) := 0 for b > 127. +// (The second part is usually irrelevant in practice, but must be defined this +// way to allow _mm_shuffle_epi8()/_mm256_shuffle_epi8()/_mm512_shuffle_epi8() +// to be used to implement the operation efficiently.) +// It's named NibbleLookupTable rather than ByteLookupTable since only the +// bottom nibble of each byte can be used for table lookup. +// It potentially stores multiple adjacent copies of the lookup table since +// that speeds up the AVX2 and AVX-512 use cases (the table can be loaded with +// a single _mm256_loadu_si256 operation, instead of e.g. _mm_loadu_si128 +// followed by _mm256_set_m128i with the same argument twice), and the typical +// use case involves initializing very few tables and using them many, many +// times. +type NibbleLookupTable struct { + shuffle [nibbleLookupDup][16]byte +} + +func (t *NibbleLookupTable) Get(b byte) byte { + return t.shuffle[0][b] +} + +// const minPageSize = 4096 may be relevant for safe functions soon. 
+ +// These could be compile-time constants for now, but not after AVX2 +// autodetection is added. + +// bytesPerVec is the size of the maximum-width vector that may be used. It is +// currently always 16, but it will be set to larger values at runtime in the +// future when AVX2/AVX-512/etc. is detected. +var bytesPerVec int + +// log2BytesPerVec supports efficient division by bytesPerVec. +var log2BytesPerVec uint + +func init() { + bytesPerVec = 16 + log2BytesPerVec = 4 +} + +// BytesPerVec is an accessor for the bytesPerVec package variable. +func BytesPerVec() int { + return bytesPerVec +} + +// RoundUpPow2 returns val rounded up to a multiple of alignment, assuming +// alignment is a power of 2. +func RoundUpPow2(val, alignment int) int { + return (val + alignment - 1) & (^(alignment - 1)) +} + +// DivUpPow2 efficiently divides a number by a power-of-2 divisor. (This works +// for negative dividends since the language specifies arithmetic right-shifts +// of signed numbers. I'm pretty sure this doesn't have a performance +// penalty.) +func DivUpPow2(dividend, divisor int, log2Divisor uint) int { + return (dividend + divisor - 1) >> log2Divisor +} + +// MakeUnsafe returns a byte slice of the given length which is guaranteed to +// have enough capacity for all Unsafe functions in this package to work. (It +// is not itself an unsafe function: allocated memory is zero-initialized.) +// Note that Unsafe functions occasionally have other caveats: e.g. +// PopcntUnsafe also requires relevant bytes past the end of the slice to be +// zeroed out. +func MakeUnsafe(len int) []byte { + // Although no planned function requires more than + // RoundUpPow2(len+1, bytesPerVec) capacity, it is necessary to add + // bytesPerVec instead to make subslicing safe. + return make([]byte, len, len+bytesPerVec) +} + +// RemakeUnsafe reuses the given buffer if it has sufficient capacity; +// otherwise it does the same thing as MakeUnsafe. 
It does NOT preserve +// existing contents of buf[]; use ResizeUnsafe() for that. +func RemakeUnsafe(bufptr *[]byte, len int) { + minCap := len + bytesPerVec + // This is likely to be called in an inner loop processing variable-size + // inputs, so mild exponential growth is appropriate. + *bufptr = make([]byte, len, RoundUpPow2(minCap+(minCap/8), bytesPerVec)) +} + +// ResizeUnsafe changes the length of buf and ensures it has enough extra +// capacity to be passed to this package's Unsafe functions. Existing buf[] +// contents are preserved (with possible truncation), though when length is +// increased, new bytes might not be zero-initialized. +func ResizeUnsafe(bufptr *[]byte, len int) { + minCap := len + bytesPerVec + dst := make([]byte, len, RoundUpPow2(minCap+(minCap/8), bytesPerVec)) + copy(dst, *bufptr) + *bufptr = dst +} + +// XcapUnsafe is shorthand for ResizeUnsafe's most common use case (no length +// change, just want to ensure sufficient capacity). +func XcapUnsafe(bufptr *[]byte) { + // mid-stack inlining isn't yet working as I write this, but it should be + // available soon enough: + // https://github.com/golang/go/issues/19348 + ResizeUnsafe(bufptr, len(*bufptr)) +} + +// Memset8Unsafe sets all values of dst[] to the given byte. (This is intended +// for val != 0. It is better to use a range-for loop for val == 0 since the +// compiler has a hardcoded optimization for that case; see +// https://github.com/golang/go/issues/5373 .) +// +// WARNING: This is a function designed to be used in inner loops, which +// assumes without checking that capacity is at least RoundUpPow2(len(dst), +// bytesPerVec). It also assumes that the caller does not care if a few bytes +// past the end of dst[] are changed. Use the safe version of this function if +// any of these properties are problematic. +// These assumptions are always satisfied when the last +// potentially-size-increasing operation on dst[] is {Re}makeUnsafe(), +// ResizeUnsafe(), or XcapUnsafe(). 
+func Memset8Unsafe(dst []byte, val byte) { + for pos := range dst { + dst[pos] = val + } +} + +// Memset8 sets all values of dst[] to the given byte. (This is intended for +// val != 0. It is better to use a range-for loop for val == 0 since the +// compiler has a hardcoded optimization for that case.) +func Memset8(dst []byte, val byte) { + for pos := range dst { + dst[pos] = val + } +} + +// MakeNibbleLookupTable generates a NibbleLookupTable from a [16]byte. +func MakeNibbleLookupTable(table [16]byte) (t NibbleLookupTable) { + for i := range t.shuffle { + t.shuffle[i] = table + } + return +} + +// UnpackedNibbleLookupUnsafeInplace replaces the bytes in main[] as follows: +// if value < 128, set to table[value & 15] +// otherwise, set to 0 +// +// WARNING: This is a function designed to be used in inner loops, which makes +// assumptions about capacity which aren't checked at runtime. Use the safe +// version of this function when that's a problem. +// These assumptions are always satisfied when the last +// potentially-size-increasing operation on main[] is {Re}makeUnsafe(), +// ResizeUnsafe(), or XcapUnsafe(). +// +// 1. cap(main) must be at least RoundUpPow2(len(main) + 1, bytesPerVec). +// +// 2. The caller does not care if a few bytes past the end of main[] are +// changed. 
+func UnpackedNibbleLookupUnsafeInplace(main []byte, tablePtr *NibbleLookupTable) { + for pos, curByte := range main { + if curByte < 128 { + curByte = tablePtr.shuffle[0][curByte&15] + } else { + curByte = 0 + } + main[pos] = curByte + } +} + +// UnpackedNibbleLookupInplace replaces the bytes in main[] as follows: +// if value < 128, set to table[value & 15] +// otherwise, set to 0 +func UnpackedNibbleLookupInplace(main []byte, tablePtr *NibbleLookupTable) { + for pos, curByte := range main { + if curByte < 128 { + curByte = tablePtr.shuffle[0][curByte&15] + } else { + curByte = 0 + } + main[pos] = curByte + } +} + +// UnpackedNibbleLookupUnsafe sets the bytes in dst[] as follows: +// if src[pos] < 128, set dst[pos] := table[src[pos] & 15] +// otherwise, set dst[pos] := 0 +// +// WARNING: This is a function designed to be used in inner loops, which makes +// assumptions about length and capacity which aren't checked at runtime. Use +// the safe version of this function when that's a problem. +// Assumptions #2-3 are always satisfied when the last +// potentially-size-increasing operation on src[] is {Re}makeUnsafe(), +// ResizeUnsafe(), or XcapUnsafe(), and the same is true for dst[]. +// +// 1. len(src) and len(dst) are equal. +// +// 2. Capacities are at least RoundUpPow2(len(src) + 1, bytesPerVec). +// +// 3. The caller does not care if a few bytes past the end of dst[] are +// changed. +func UnpackedNibbleLookupUnsafe(dst, src []byte, tablePtr *NibbleLookupTable) { + for pos, curByte := range src { + if curByte < 128 { + curByte = tablePtr.shuffle[0][curByte&15] + } else { + curByte = 0 + } + dst[pos] = curByte + } +} + +// UnpackedNibbleLookup sets the bytes in dst[] as follows: +// if src[pos] < 128, set dst[pos] := table[src[pos] & 15] +// otherwise, set dst[pos] := 0 +// It panics if len(src) != len(dst). 
+func UnpackedNibbleLookup(dst, src []byte, tablePtr *NibbleLookupTable) { + if len(dst) != len(src) { + panic("UnpackedNibbleLookup() requires len(src) == len(dst).") + } + for pos, curByte := range src { + if curByte < 128 { + curByte = tablePtr.shuffle[0][curByte&15] + } else { + curByte = 0 + } + dst[pos] = curByte + } +} + +// UnpackedNibbleLookupS is a variant of UnpackedNibbleLookup() that takes +// string src. +func UnpackedNibbleLookupS(dst []byte, src string, tablePtr *NibbleLookupTable) { + srcLen := len(src) + if len(dst) != srcLen { + panic("UnpackedNibbleLookupS() requires len(src) == len(dst).") + } + for pos := range src { + curByte := src[pos] + if curByte < 128 { + curByte = tablePtr.Get(curByte & 15) + } else { + curByte = 0 + } + dst[pos] = curByte + } + return +} + +// PackedNibbleLookupUnsafe sets the bytes in dst[] as follows: +// if pos is even, dst[pos] := table[src[pos / 2] & 15] +// if pos is odd, dst[pos] := table[src[pos / 2] >> 4] +// +// WARNING: This is a function designed to be used in inner loops, which makes +// assumptions about length and capacity which aren't checked at runtime. Use +// the safe version of this function when that's a problem. +// Assumptions #2-#3 are always satisfied when the last +// potentially-size-increasing operation on src[] is {Re}makeUnsafe(), +// ResizeUnsafe(), or XcapUnsafe(), and the same is true for dst[]. +// +// 1. len(src) == (len(dst) + 1) / 2. +// +// 2. Capacity of src is at least RoundUpPow2(len(src) + 1, bytesPerVec), and +// the same is true for dst. +// +// 3. The caller does not care if a few bytes past the end of dst[] are +// changed. 
+func PackedNibbleLookupUnsafe(dst, src []byte, tablePtr *NibbleLookupTable) { + dstLen := len(dst) + nSrcFullByte := dstLen >> 1 + srcOdd := dstLen & 1 + for srcPos := 0; srcPos < nSrcFullByte; srcPos++ { + srcByte := src[srcPos] + dst[2*srcPos] = tablePtr.shuffle[0][srcByte&15] + dst[2*srcPos+1] = tablePtr.shuffle[0][srcByte>>4] + } + if srcOdd == 1 { + srcByte := src[nSrcFullByte] + dst[2*nSrcFullByte] = tablePtr.shuffle[0][srcByte&15] + } +} + +// PackedNibbleLookup sets the bytes in dst[] as follows: +// if pos is even, dst[pos] := table[src[pos / 2] & 15] +// if pos is odd, dst[pos] := table[src[pos / 2] >> 4] +// It panics if len(src) != (len(dst) + 1) / 2. +// +// Nothing bad happens if len(dst) is odd and some high bits in the last src[] +// byte are set, though it's generally good practice to ensure that case +// doesn't come up. +func PackedNibbleLookup(dst, src []byte, tablePtr *NibbleLookupTable) { + dstLen := len(dst) + nSrcFullByte := dstLen >> 1 + srcOdd := dstLen & 1 + if len(src) != nSrcFullByte+srcOdd { + panic("PackedNibbleLookup() requires len(src) == (len(dst) + 1) / 2.") + } + for srcPos := 0; srcPos < nSrcFullByte; srcPos++ { + srcByte := src[srcPos] + dst[2*srcPos] = tablePtr.shuffle[0][srcByte&15] + dst[2*srcPos+1] = tablePtr.shuffle[0][srcByte>>4] + } + if srcOdd == 1 { + srcByte := src[nSrcFullByte] + dst[2*nSrcFullByte] = tablePtr.shuffle[0][srcByte&15] + } +} + +// Interleave8Unsafe sets the bytes in dst[] as follows: +// if pos is even, dst[pos] := even[pos/2] +// if pos is odd, dst[pos] := odd[pos/2] +// +// WARNING: This is a function designed to be used in inner loops, which makes +// assumptions about length and capacity which aren't checked at runtime. Use +// the safe version of this function when that's a problem. +// Assumptions #2-3 are always satisfied when the last +// potentially-size-increasing operation on dst[] is {Re}makeUnsafe(), +// ResizeUnsafe(), or XcapUnsafe(), and the same is true for even[] and odd[]. 
+// +// 1. len(even) = (len(dst) + 1) / 2, and len(odd) = len(dst) / 2. +// +// 2. cap(dst) >= RoundUpPow2(len(dst) + 1, bytesPerVec), +// cap(even) >= RoundUpPow2(len(even) + 1, bytesPerVec), and +// cap(odd) >= RoundUpPow2(len(odd) + 1, bytesPerVec). +// +// 3. The caller does not care if a few bytes past the end of dst[] are +// changed. +func Interleave8Unsafe(dst, even, odd []byte) { + dstLen := len(dst) + evenLen := (dstLen + 1) >> 1 + oddLen := dstLen >> 1 + for idx, oddByte := range odd { + dst[2*idx] = even[idx] + dst[2*idx+1] = oddByte + } + if oddLen != evenLen { + dst[oddLen*2] = even[oddLen] + } +} + +// Interleave8 sets the bytes in dst[] as follows: +// if pos is even, dst[pos] := even[pos/2] +// if pos is odd, dst[pos] := odd[pos/2] +// It panics if ((len(dst) + 1) / 2) != len(even), or (len(dst) / 2) != +// len(odd). +func Interleave8(dst, even, odd []byte) { + // This is ~6-20% slower than the unsafe function on the short-array + // benchmark. + dstLen := len(dst) + evenLen := (dstLen + 1) >> 1 + oddLen := dstLen >> 1 + if (len(even) != evenLen) || (len(odd) != oddLen) { + panic("Interleave8() requires len(even) == len(dst) + 1) / 2, and len(odd) == len(dst) / 2.") + } + for idx, oddByte := range odd { + dst[2*idx] = even[idx] + dst[2*idx+1] = oddByte + } + if oddLen != evenLen { + dst[oddLen*2] = even[oddLen] + } +} + +// Reverse8Inplace reverses the bytes in main[]. (There is no unsafe version +// of this function.) +func Reverse8Inplace(main []byte) { + nByte := len(main) + nByteDiv2 := nByte >> 1 + for idx, invIdx := 0, nByte-1; idx != nByteDiv2; idx, invIdx = idx+1, invIdx-1 { + main[idx], main[invIdx] = main[invIdx], main[idx] + } +} + +// Reverse8Unsafe sets dst[pos] := src[len(src) - 1 - pos] for every position +// in src. +// +// WARNING: This does not verify len(dst) == len(src); call the safe version of +// this function if you want that. 
+func Reverse8Unsafe(dst, src []byte) { + nByte := len(src) + nByteMinus1 := nByte - 1 + for idx := 0; idx != nByte; idx++ { + dst[nByteMinus1-idx] = src[idx] + } +} + +// Reverse8 sets dst[pos] := src[len(src) - 1 - pos] for every position in src. +// It panics if len(src) != len(dst). +func Reverse8(dst, src []byte) { + nByte := len(src) + if nByte != len(dst) { + panic("Reverse8() requires len(src) == len(dst).") + } + nByteMinus1 := nByte - 1 + for idx := 0; idx != nByte; idx++ { + dst[nByteMinus1-idx] = src[idx] + } +} + +// BitFromEveryByte fills dst[] with a bitarray containing every 8th bit from +// src[], starting with bitIdx, where bitIdx is in [0,7]. If len(src) is not +// divisible by 8, extra bits in the last filled byte of dst are set to zero. +// For example, if src[] is +// 0x1f 0x33 0x0d 0x00 0x51 0xcc 0x34 0x59 0x44 +// and bitIdx is 2, bit 2 from every byte is +// 1 0 1 0 0 1 1 0 1 +// so dst[] is filled with +// 0x65 0x01. +// +// - It panics if len(dst) < (len(src) + 7) / 8, or bitIdx isn't in [0,7]. +// - If dst is larger than necessary, the extra bytes are not changed. 
+func BitFromEveryByte(dst, src []byte, bitIdx int) { + requiredDstLen := (len(src) + 7) >> 3 + if (len(dst) < requiredDstLen) || (uint(bitIdx) > 7) { + panic("BitFromEveryByte requires len(dst) >= (len(src) + 7) / 8 and 0 <= bitIdx < 8.") + } + dst = dst[:requiredDstLen] + for i := range dst { + dst[i] = 0 + } + for i, b := range src { + dst[i>>3] |= ((b >> uint32(bitIdx)) & 1) << uint32(i&7) + } +} diff --git a/simd/simd_test.go b/simd/simd_test.go index 1fdecb12..2171a9c1 100644 --- a/simd/simd_test.go +++ b/simd/simd_test.go @@ -6,198 +6,17 @@ package simd_test import ( "bytes" + "encoding/binary" "math/rand" - "runtime" "testing" "github.com/grailbio/base/simd" + "github.com/grailbio/testutil/assert" ) -/* -Initial benchmark results: - MacBook Pro (15-inch, 2016) - 2.7 GHz Intel Core i7, 16 GB 2133 MHz LPDDR3 - -Benchmark_Memset8Short1-8 20 56542414 ns/op -Benchmark_Memset8Short4-8 100 15969877 ns/op -Benchmark_Memset8ShortMax-8 100 15780214 ns/op -Benchmark_Memset8Long1-8 1 1279094415 ns/op -Benchmark_Memset8Long4-8 1 1902840097 ns/op -Benchmark_Memset8LongMax-8 1 2715013574 ns/op - -Benchmark_UnpackedNibbleLookupShort1-8 20 66268257 ns/op -Benchmark_UnpackedNibbleLookupShort4-8 100 18575755 ns/op -Benchmark_UnpackedNibbleLookupShortMax-8 100 17474281 ns/op -Benchmark_UnpackedNibbleLookupLong1-8 1 1330878965 ns/op -Benchmark_UnpackedNibbleLookupLong4-8 1 1977241995 ns/op -Benchmark_UnpackedNibbleLookupLongMax-8 1 2793933818 ns/op - -Benchmark_PackedNibbleLookupShort1-8 20 80579763 ns/op -Benchmark_PackedNibbleLookupShort4-8 100 23488681 ns/op -Benchmark_PackedNibbleLookupShortMax-8 100 21701360 ns/op -Benchmark_PackedNibbleLookupLong1-8 1 1470408074 ns/op -Benchmark_PackedNibbleLookupLong4-8 1 2103843655 ns/op -Benchmark_PackedNibbleLookupLongMax-8 1 2767976716 ns/op - -Benchmark_InterleaveShort1-8 10 122320311 ns/op -Benchmark_InterleaveShort4-8 50 33240437 ns/op -Benchmark_InterleaveShortMax-8 50 27383249 ns/op -Benchmark_InterleaveLong1-8 1 1557992496 
ns/op -Benchmark_InterleaveLong4-8 1 2177311837 ns/op -Benchmark_InterleaveLongMax-8 1 2838302958 ns/op - -Benchmark_Reverse8Short1-8 20 66878761 ns/op -Benchmark_Reverse8Short4-8 100 18888361 ns/op -Benchmark_Reverse8ShortMax-8 100 17845626 ns/op -Benchmark_Reverse8Long1-8 1 1274790843 ns/op -Benchmark_Reverse8Long4-8 1 1962669700 ns/op -Benchmark_Reverse8LongMax-8 1 2719838443 ns/op - -For comparison, memset8: -Benchmark_Memset8Short1-8 5 270933107 ns/op -Benchmark_Memset8Short4-8 20 78389931 ns/op -Benchmark_Memset8ShortMax-8 20 66983738 ns/op -Benchmark_Memset8Long1-8 1 1342542739 ns/op -Benchmark_Memset8Long4-8 1 1944395002 ns/op -Benchmark_Memset8LongMax-8 1 2757737157 ns/op - -memset-to-zero range for loop: -Benchmark_Memset8Short1-8 30 37976858 ns/op -Benchmark_Memset8Short4-8 50 25033805 ns/op -Benchmark_Memset8ShortMax-8 100 14801649 ns/op -Benchmark_Memset8Long1-8 3 448067523 ns/op -Benchmark_Memset8Long4-8 1 1361988705 ns/op -Benchmark_Memset8LongMax-8 1 2126505354 ns/op -(Note that this is usually better than simd.Memset8. This is due to reduced -function call overhead and use of AVX2 (with cache-bypassing stores in the AVX2 ->32 MiB case); there was no advantage to replacing simd.Memset8 with the -non-AVX2 portion of runtime.memclr_amd64.) 
- -unpackedNibbleLookupInplaceSlow (&15 removed, bytes restricted to 0..15): -Benchmark_UnpackedNibbleLookupShort1-8 2 524170511 ns/op -Benchmark_UnpackedNibbleLookupShort4-8 10 147371412 ns/op -Benchmark_UnpackedNibbleLookupShortMax-8 10 142252262 ns/op -Benchmark_UnpackedNibbleLookupLong1-8 1 8123456605 ns/op -Benchmark_UnpackedNibbleLookupLong4-8 1 5069456472 ns/op -Benchmark_UnpackedNibbleLookupLongMax-8 1 3929059263 ns/op - -packedNibbleLookupSlow: -Benchmark_PackedNibbleLookupShort1-8 2 572680365 ns/op -Benchmark_PackedNibbleLookupShort4-8 10 158619127 ns/op -Benchmark_PackedNibbleLookupShortMax-8 10 155940159 ns/op -Benchmark_PackedNibbleLookupLong1-8 1 8956157310 ns/op -Benchmark_PackedNibbleLookupLong4-8 1 3226223964 ns/op -Benchmark_PackedNibbleLookupLongMax-8 1 3788519710 ns/op - -interleaveSlow: -Benchmark_InterleaveShort1-8 2 779212342 ns/op -Benchmark_InterleaveShort4-8 5 207224364 ns/op -Benchmark_InterleaveShortMax-8 5 213528353 ns/op -Benchmark_InterleaveLong1-8 1 6926143664 ns/op -Benchmark_InterleaveLong4-8 1 2745455753 ns/op -Benchmark_InterleaveLongMax-8 1 3664858002 ns/op - -reverseSlow: -Benchmark_Reverse8Short1-8 3 423063894 ns/op -Benchmark_Reverse8Short4-8 10 112274707 ns/op -Benchmark_Reverse8ShortMax-8 10 196379771 ns/op -Benchmark_Reverse8Long1-8 1 6270445876 ns/op -Benchmark_Reverse8Long4-8 1 3965932146 ns/op -Benchmark_Reverse8LongMax-8 1 3349559784 ns/op -*/ - -/* -type QLetter struct { - L byte - Q byte -} -*/ - -func memset8Subtask(dst []byte, nIter int) int { - for iter := 0; iter < nIter; iter++ { - // Compiler-recognized range-for loop, for comparison. 
- /* - for pos := range dst { - dst[pos] = 0 - } - */ - simd.Memset8Unsafe(dst, 78) - } - return int(dst[0]) -} - -func memset8SubtaskFuture(dst []byte, nIter int) chan int { - future := make(chan int) - go func() { future <- memset8Subtask(dst, nIter) }() - return future -} - -func multiMemset8(dsts [][]byte, cpus int, nJob int) { - sumFutures := make([]chan int, cpus) - shardSizeBase := nJob / cpus - shardRemainder := nJob - shardSizeBase*cpus - shardSizeP1 := shardSizeBase + 1 - var taskIdx int - for ; taskIdx < shardRemainder; taskIdx++ { - sumFutures[taskIdx] = memset8SubtaskFuture(dsts[taskIdx], shardSizeP1) - } - for ; taskIdx < cpus; taskIdx++ { - sumFutures[taskIdx] = memset8SubtaskFuture(dsts[taskIdx], shardSizeBase) - } - var sum int - for taskIdx = 0; taskIdx < cpus; taskIdx++ { - sum += <-sumFutures[taskIdx] - } -} - -func benchmarkMemset8(cpus int, nByte int, nJob int, b *testing.B) { - if cpus > runtime.NumCPU() { - b.Skipf("only have %v cpus", runtime.NumCPU()) - } - - mainSlices := make([][]byte, cpus) - for ii := range mainSlices { - // Add 63 to prevent false sharing. - newArr := simd.MakeUnsafe(nByte + 63) - for jj := 0; jj < nByte; jj++ { - newArr[jj] = byte(jj * 3) - } - mainSlices[ii] = newArr[:nByte] - } - for i := 0; i < b.N; i++ { - multiMemset8(mainSlices, cpus, nJob) - } -} - -// Base sequence in length-150 .bam read occupies 75 bytes, so 75 is a good -// size for the short-array benchmark. -func Benchmark_Memset8Short1(b *testing.B) { - benchmarkMemset8(1, 75, 9999999, b) -} - -func Benchmark_Memset8Short4(b *testing.B) { - benchmarkMemset8(4, 75, 9999999, b) -} - -func Benchmark_Memset8ShortMax(b *testing.B) { - benchmarkMemset8(runtime.NumCPU(), 75, 9999999, b) -} - -// GRCh37 chromosome 1 length is 249250621, so that's a plausible long-array -// use case. 
-func Benchmark_Memset8Long1(b *testing.B) { - benchmarkMemset8(1, 249250621, 50, b) -} - -func Benchmark_Memset8Long4(b *testing.B) { - benchmarkMemset8(4, 249250621, 50, b) -} - -func Benchmark_Memset8LongMax(b *testing.B) { - benchmarkMemset8(runtime.NumCPU(), 249250621, 50, b) -} - -func memset8(dst []byte, val byte) { +// This is the most-frequently-recommended implementation. It's decent, so the +// suffix is 'Standard' instead of 'Slow'. +func memset8Standard(dst []byte, val byte) { dstLen := len(dst) if dstLen != 0 { dst[0] = val @@ -220,7 +39,7 @@ func TestMemset8(t *testing.T) { main2Slice := main2Arr[sliceStart:sliceEnd] main3Slice := main3Arr[sliceStart:sliceEnd] byteVal := byte(rand.Intn(256)) - memset8(main1Slice, byteVal) + memset8Standard(main1Slice, byteVal) simd.Memset8Unsafe(main2Slice, byteVal) if !bytes.Equal(main1Slice, main2Slice) { t.Fatal("Mismatched Memset8Unsafe result.") @@ -240,91 +59,94 @@ func TestMemset8(t *testing.T) { } } -func unpackedNibbleLookupSubtask(main []byte, nIter int) int { - table := [...]byte{0, 1, 0, 2, 8, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0} - for iter := 0; iter < nIter; iter++ { - // Note that this uses the result of one lookup operation as the input to - // the next. - // (Given the current table, all values should be 1 or 0 after 3 or more - // iterations.) 
- simd.UnpackedNibbleLookupUnsafeInplace(main, &table) - } - return int(main[0]) -} +/* +Benchmark results: + MacBook Pro (15-inch, 2016) + 2.7 GHz Intel Core i7, 16 GB 2133 MHz LPDDR3 -func unpackedNibbleLookupSubtaskFuture(main []byte, nIter int) chan int { - future := make(chan int) - go func() { future <- unpackedNibbleLookupSubtask(main, nIter) }() - return future -} +Benchmark_Memset8/SIMDShort1Cpu-8 20 62706981 ns/op +Benchmark_Memset8/SIMDShortHalfCpu-8 100 17559573 ns/op +Benchmark_Memset8/SIMDShortAllCpu-8 100 17149982 ns/op +Benchmark_Memset8/SIMDLong1Cpu-8 1 1101524485 ns/op +Benchmark_Memset8/SIMDLongHalfCpu-8 2 925331938 ns/op +Benchmark_Memset8/SIMDLongAllCpu-8 2 971422170 ns/op +Benchmark_Memset8/StandardShort1Cpu-8 5 314689466 ns/op +Benchmark_Memset8/StandardShortHalfCpu-8 20 88260588 ns/op +Benchmark_Memset8/StandardShortAllCpu-8 20 84317546 ns/op +Benchmark_Memset8/StandardLong1Cpu-8 1 1082736141 ns/op +Benchmark_Memset8/StandardLongHalfCpu-8 2 992904776 ns/op +Benchmark_Memset8/StandardLongAllCpu-8 1 1052452033 ns/op +Benchmark_Memset8/RangeZeroShort1Cpu-8 30 44907924 ns/op +Benchmark_Memset8/RangeZeroShortHalfCpu-8 100 24173280 ns/op +Benchmark_Memset8/RangeZeroShortAllCpu-8 100 14991003 ns/op +Benchmark_Memset8/RangeZeroLong1Cpu-8 3 401003587 ns/op +Benchmark_Memset8/RangeZeroLongHalfCpu-8 3 400711072 ns/op +Benchmark_Memset8/RangeZeroLongAllCpu-8 3 404863223 ns/op + +Notes: simd.Memset8 is broadly useful for short arrays, though usually a bit +worse than memclr. However, memclr wins handily in the 249 MB long case on the +test machine, thanks to AVX2 (and, in the AVX2 subroutine, cache-bypassing +stores). +When the simd.Memset8 AVX2 implementation is written, it should obviously +imitate what memclr is doing. 
+*/ -func multiUnpackedNibbleLookup(mains [][]byte, cpus int, nJob int) { - sumFutures := make([]chan int, cpus) - shardSizeBase := nJob / cpus - shardRemainder := nJob - shardSizeBase*cpus - shardSizeP1 := shardSizeBase + 1 - var taskIdx int - for ; taskIdx < shardRemainder; taskIdx++ { - sumFutures[taskIdx] = unpackedNibbleLookupSubtaskFuture(mains[taskIdx], shardSizeP1) - } - for ; taskIdx < cpus; taskIdx++ { - sumFutures[taskIdx] = unpackedNibbleLookupSubtaskFuture(mains[taskIdx], shardSizeBase) - } - var sum int - for taskIdx = 0; taskIdx < cpus; taskIdx++ { - sum += <-sumFutures[taskIdx] +func memset8SimdSubtask(dst, src []byte, nIter int) int { + for iter := 0; iter < nIter; iter++ { + simd.Memset8(dst, 78) } + return int(dst[0]) } -func benchmarkUnpackedNibbleLookup(cpus int, nByte int, nJob int, b *testing.B) { - if cpus > runtime.NumCPU() { - b.Skipf("only have %v cpus", runtime.NumCPU()) +func memset8StandardSubtask(dst, src []byte, nIter int) int { + for iter := 0; iter < nIter; iter++ { + memset8Standard(dst, 78) } + return int(dst[0]) +} - mainSlices := make([][]byte, cpus) - for ii := range mainSlices { - // Add 63 to prevent false sharing. - newArr := simd.MakeUnsafe(nByte + 63) - for jj := 0; jj < nByte; jj++ { - newArr[jj] = byte(jj * 3) +func memset8RangeZeroSubtask(dst, src []byte, nIter int) int { + for iter := 0; iter < nIter; iter++ { + // Compiler-recognized loop, which gets converted to a memclr call with + // fancier optimizations than simd.Memset8. 
+ for pos := range dst { + dst[pos] = 0 } - mainSlices[ii] = newArr[:nByte] - } - for i := 0; i < b.N; i++ { - multiUnpackedNibbleLookup(mainSlices, cpus, nJob) } + return int(dst[0]) } -func Benchmark_UnpackedNibbleLookupShort1(b *testing.B) { - benchmarkUnpackedNibbleLookup(1, 75, 9999999, b) -} - -func Benchmark_UnpackedNibbleLookupShort4(b *testing.B) { - benchmarkUnpackedNibbleLookup(4, 75, 9999999, b) -} - -func Benchmark_UnpackedNibbleLookupShortMax(b *testing.B) { - benchmarkUnpackedNibbleLookup(runtime.NumCPU(), 75, 9999999, b) -} - -func Benchmark_UnpackedNibbleLookupLong1(b *testing.B) { - benchmarkUnpackedNibbleLookup(1, 249250621, 50, b) -} - -func Benchmark_UnpackedNibbleLookupLong4(b *testing.B) { - benchmarkUnpackedNibbleLookup(4, 249250621, 50, b) -} - -func Benchmark_UnpackedNibbleLookupLongMax(b *testing.B) { - benchmarkUnpackedNibbleLookup(runtime.NumCPU(), 249250621, 50, b) +func Benchmark_Memset8(b *testing.B) { + funcs := []taggedMultiBenchFunc{ + { + f: memset8SimdSubtask, + tag: "SIMD", + }, + { + f: memset8StandardSubtask, + tag: "Standard", + }, + { + f: memset8RangeZeroSubtask, + tag: "RangeZero", + }, + } + for _, f := range funcs { + // Base sequence in length-150 .bam read occupies 75 bytes, so 75 is a good + // size for the short-array benchmark. + multiBenchmark(f.f, f.tag+"Short", 75, 0, 9999999, b) + // GRCh37 chromosome 1 length is 249250621, so that's a plausible + // long-array use case. + multiBenchmark(f.f, f.tag+"Long", 249250621, 0, 50, b) + } } // This only matches UnpackedNibbleLookupInplace when all bytes < 128; the test // has been restricted accordingly. _mm_shuffle_epi8()'s treatment of bytes >= // 128 usually isn't relevant. 
-func unpackedNibbleLookupInplaceSlow(main []byte, tablePtr *[16]byte) { +func unpackedNibbleLookupInplaceSlow(main []byte, tablePtr *simd.NibbleLookupTable) { for idx := range main { - main[idx] = tablePtr[main[idx]&15] + main[idx] = tablePtr.Get(main[idx] & 15) } } @@ -336,7 +158,7 @@ func TestUnpackedNibbleLookup(t *testing.T) { main3Arr := simd.MakeUnsafe(maxSize) main4Arr := simd.MakeUnsafe(maxSize) main5Arr := simd.MakeUnsafe(maxSize) - table := [...]byte{0, 1, 0, 2, 8, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0} + table := simd.MakeNibbleLookupTable([16]byte{0, 1, 0, 2, 8, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0}) for iter := 0; iter < nIter; iter++ { sliceStart := rand.Intn(maxSize) sliceEnd := sliceStart + rand.Intn(maxSize-sliceStart) @@ -384,97 +206,74 @@ func TestUnpackedNibbleLookup(t *testing.T) { } } -func packedNibbleLookupSubtask(dst, src []byte, nIter int) int { - table := [...]byte{0, 1, 0, 2, 8, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0} +/* +Benchmark results: + MacBook Pro (15-inch, 2016) + 2.7 GHz Intel Core i7, 16 GB 2133 MHz LPDDR3 + +Benchmark_UnpackedNibbleLookupInplace/SIMDShort1Cpu-8 20 76720863 ns/op +Benchmark_UnpackedNibbleLookupInplace/SIMDShortHalfCpu-8 50 22968008 ns/op +Benchmark_UnpackedNibbleLookupInplace/SIMDShortAllCpu-8 100 18896633 ns/op +Benchmark_UnpackedNibbleLookupInplace/SIMDLong1Cpu-8 1 1046243684 ns/op +Benchmark_UnpackedNibbleLookupInplace/SIMDLongHalfCpu-8 2 861622838 ns/op +Benchmark_UnpackedNibbleLookupInplace/SIMDLongAllCpu-8 2 944384349 ns/op +Benchmark_UnpackedNibbleLookupInplace/SlowShort1Cpu-8 2 532267799 ns/op +Benchmark_UnpackedNibbleLookupInplace/SlowShortHalfCpu-8 10 144993320 ns/op +Benchmark_UnpackedNibbleLookupInplace/SlowShortAllCpu-8 10 146218387 ns/op +Benchmark_UnpackedNibbleLookupInplace/SlowLong1Cpu-8 1 7745668548 ns/op +Benchmark_UnpackedNibbleLookupInplace/SlowLongHalfCpu-8 1 2169127851 ns/op +Benchmark_UnpackedNibbleLookupInplace/SlowLongAllCpu-8 1 2164900359 ns/op +*/ + +func unpackedNibbleLookupInplaceSimdSubtask(dst, 
src []byte, nIter int) int { + table := simd.MakeNibbleLookupTable([16]byte{0, 1, 0, 2, 8, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0}) for iter := 0; iter < nIter; iter++ { - simd.PackedNibbleLookupUnsafe(dst, src, &table) + // Note that this uses the result of one lookup operation as the input to + // the next. + // (Given the current table, all values should be 1 or 0 after 3 or more + // iterations.) + simd.UnpackedNibbleLookupInplace(dst, &table) } return int(dst[0]) } -func packedNibbleLookupSubtaskFuture(dst, src []byte, nIter int) chan int { - future := make(chan int) - go func() { future <- packedNibbleLookupSubtask(dst, src, nIter) }() - return future -} - -func multiPackedNibbleLookup(dsts, srcs [][]byte, cpus int, nJob int) { - sumFutures := make([]chan int, cpus) - shardSizeBase := nJob / cpus - shardRemainder := nJob - shardSizeBase*cpus - shardSizeP1 := shardSizeBase + 1 - var taskIdx int - for ; taskIdx < shardRemainder; taskIdx++ { - sumFutures[taskIdx] = packedNibbleLookupSubtaskFuture(dsts[taskIdx], srcs[taskIdx], shardSizeP1) - } - for ; taskIdx < cpus; taskIdx++ { - sumFutures[taskIdx] = packedNibbleLookupSubtaskFuture(dsts[taskIdx], srcs[taskIdx], shardSizeBase) - } - var sum int - for taskIdx = 0; taskIdx < cpus; taskIdx++ { - sum += <-sumFutures[taskIdx] +func unpackedNibbleLookupInplaceSlowSubtask(dst, src []byte, nIter int) int { + table := simd.MakeNibbleLookupTable([16]byte{0, 1, 0, 2, 8, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0}) + for iter := 0; iter < nIter; iter++ { + unpackedNibbleLookupInplaceSlow(dst, &table) } + return int(dst[0]) } -func benchmarkPackedNibbleLookup(cpus int, nDstByte int, nJob int, b *testing.B) { - if cpus > runtime.NumCPU() { - b.Skipf("only have %v cpus", runtime.NumCPU()) - } - - srcSlices := make([][]byte, cpus) - dstSlices := make([][]byte, cpus) - nSrcByte := (nDstByte + 1) / 2 - for ii := range srcSlices { - // Add 63 to prevent false sharing. 
- newArr := simd.MakeUnsafe(nSrcByte + 63) - for jj := 0; jj < nSrcByte; jj++ { - newArr[jj] = byte(jj * 3) - } - srcSlices[ii] = newArr[:nSrcByte] - newArr = simd.MakeUnsafe(nDstByte + 63) - dstSlices[ii] = newArr[:nDstByte] +func Benchmark_UnpackedNibbleLookupInplace(b *testing.B) { + funcs := []taggedMultiBenchFunc{ + { + f: unpackedNibbleLookupInplaceSimdSubtask, + tag: "SIMD", + }, + { + f: unpackedNibbleLookupInplaceSlowSubtask, + tag: "Slow", + }, } - for i := 0; i < b.N; i++ { - multiPackedNibbleLookup(dstSlices, srcSlices, cpus, nJob) + for _, f := range funcs { + multiBenchmark(f.f, f.tag+"Short", 75, 0, 9999999, b) + multiBenchmark(f.f, f.tag+"Long", 249250621, 0, 50, b) } } -func Benchmark_PackedNibbleLookupShort1(b *testing.B) { - benchmarkPackedNibbleLookup(1, 75, 9999999, b) -} - -func Benchmark_PackedNibbleLookupShort4(b *testing.B) { - benchmarkPackedNibbleLookup(4, 75, 9999999, b) -} - -func Benchmark_PackedNibbleLookupShortMax(b *testing.B) { - benchmarkPackedNibbleLookup(runtime.NumCPU(), 75, 9999999, b) -} - -func Benchmark_PackedNibbleLookupLong1(b *testing.B) { - benchmarkPackedNibbleLookup(1, 249250621, 50, b) -} - -func Benchmark_PackedNibbleLookupLong4(b *testing.B) { - benchmarkPackedNibbleLookup(4, 249250621, 50, b) -} - -func Benchmark_PackedNibbleLookupLongMax(b *testing.B) { - benchmarkPackedNibbleLookup(runtime.NumCPU(), 249250621, 50, b) -} - -func packedNibbleLookupSlow(dst, src []byte, tablePtr *[16]byte) { +func packedNibbleLookupSlow(dst, src []byte, tablePtr *simd.NibbleLookupTable) { dstLen := len(dst) nSrcFullByte := dstLen / 2 srcOdd := dstLen & 1 for srcPos := 0; srcPos < nSrcFullByte; srcPos++ { srcByte := src[srcPos] - dst[2*srcPos] = tablePtr[srcByte&15] - dst[2*srcPos+1] = tablePtr[srcByte>>4] + dst[2*srcPos] = tablePtr.Get(srcByte & 15) + dst[2*srcPos+1] = tablePtr.Get(srcByte >> 4) } if srcOdd == 1 { srcByte := src[nSrcFullByte] - dst[2*nSrcFullByte] = tablePtr[srcByte&15] + dst[2*nSrcFullByte] = tablePtr.Get(srcByte 
& 15) } } @@ -485,7 +284,7 @@ func TestPackedNibbleLookup(t *testing.T) { srcArr := simd.MakeUnsafe(maxSrcSize) dst1Arr := simd.MakeUnsafe(maxDstSize) dst2Arr := simd.MakeUnsafe(maxDstSize) - table := [...]byte{0, 1, 0, 2, 8, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0} + table := simd.MakeNibbleLookupTable([16]byte{0, 1, 0, 2, 8, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0}) for iter := 0; iter < nIter; iter++ { srcSliceStart := rand.Intn(maxSrcSize) dstSliceStart := srcSliceStart * 2 @@ -518,84 +317,81 @@ func TestPackedNibbleLookup(t *testing.T) { } } -func interleaveSubtask(dst, src []byte, nIter int) int { +/* +Benchmark results: + MacBook Pro (15-inch, 2016) + 2.7 GHz Intel Core i7, 16 GB 2133 MHz LPDDR3 + +Benchmark_PackedNibbleLookup/UnsafeShort1Cpu-8 10 143501956 ns/op +Benchmark_PackedNibbleLookup/UnsafeShortHalfCpu-8 30 38748958 ns/op +Benchmark_PackedNibbleLookup/UnsafeShortAllCpu-8 50 31982398 ns/op +Benchmark_PackedNibbleLookup/UnsafeLong1Cpu-8 1 1372142640 ns/op +Benchmark_PackedNibbleLookup/UnsafeLongHalfCpu-8 1 1236198290 ns/op +Benchmark_PackedNibbleLookup/UnsafeLongAllCpu-8 1 1265315746 ns/op +Benchmark_PackedNibbleLookup/SIMDShort1Cpu-8 10 158155872 ns/op +Benchmark_PackedNibbleLookup/SIMDShortHalfCpu-8 30 43098347 ns/op +Benchmark_PackedNibbleLookup/SIMDShortAllCpu-8 30 37593692 ns/op +Benchmark_PackedNibbleLookup/SIMDLong1Cpu-8 1 1407559630 ns/op +Benchmark_PackedNibbleLookup/SIMDLongHalfCpu-8 1 1244569913 ns/op +Benchmark_PackedNibbleLookup/SIMDLongAllCpu-8 1 1245648867 ns/op +Benchmark_PackedNibbleLookup/SlowShort1Cpu-8 1 1322739228 ns/op +Benchmark_PackedNibbleLookup/SlowShortHalfCpu-8 3 381551545 ns/op +Benchmark_PackedNibbleLookup/SlowShortAllCpu-8 3 361846656 ns/op +Benchmark_PackedNibbleLookup/SlowLong1Cpu-8 1 9990188206 ns/op +Benchmark_PackedNibbleLookup/SlowLongHalfCpu-8 1 2855687759 ns/op +Benchmark_PackedNibbleLookup/SlowLongAllCpu-8 1 2877628266 ns/op + +Notes: Unsafe version of this function is also benchmarked, since the +short-array safety penalty is 
a bit high here. This is mainly an indicator of +room for improvement in the safe function; I think it's clear at this point +that we'll probably never need to use the Unsafe interface. +*/ + +func packedNibbleLookupUnsafeSubtask(dst, src []byte, nIter int) int { + table := simd.MakeNibbleLookupTable([16]byte{0, 1, 0, 2, 8, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0}) for iter := 0; iter < nIter; iter++ { - simd.Interleave8Unsafe(dst, src, src) + simd.PackedNibbleLookupUnsafe(dst, src, &table) } return int(dst[0]) } -func interleaveSubtaskFuture(dst, src []byte, nIter int) chan int { - future := make(chan int) - go func() { future <- interleaveSubtask(dst, src, nIter) }() - return future -} - -func multiInterleave(dsts, srcs [][]byte, cpus int, nJob int) { - sumFutures := make([]chan int, cpus) - shardSizeBase := nJob / cpus - shardRemainder := nJob - shardSizeBase*cpus - shardSizeP1 := shardSizeBase + 1 - var taskIdx int - for ; taskIdx < shardRemainder; taskIdx++ { - sumFutures[taskIdx] = interleaveSubtaskFuture(dsts[taskIdx], srcs[taskIdx], shardSizeP1) - } - for ; taskIdx < cpus; taskIdx++ { - sumFutures[taskIdx] = interleaveSubtaskFuture(dsts[taskIdx], srcs[taskIdx], shardSizeBase) - } - var sum int - for taskIdx = 0; taskIdx < cpus; taskIdx++ { - sum += <-sumFutures[taskIdx] +func packedNibbleLookupSimdSubtask(dst, src []byte, nIter int) int { + table := simd.MakeNibbleLookupTable([16]byte{0, 1, 0, 2, 8, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0}) + for iter := 0; iter < nIter; iter++ { + simd.PackedNibbleLookup(dst, src, &table) } + return int(dst[0]) } -func benchmarkInterleave(cpus int, nSrcByte int, nJob int, b *testing.B) { - if cpus > runtime.NumCPU() { - b.Skipf("only have %v cpus", runtime.NumCPU()) +func packedNibbleLookupSlowSubtask(dst, src []byte, nIter int) int { + table := simd.MakeNibbleLookupTable([16]byte{0, 1, 0, 2, 8, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0}) + for iter := 0; iter < nIter; iter++ { + packedNibbleLookupSlow(dst, src, &table) } + return int(dst[0]) +} - 
srcSlices := make([][]byte, cpus) - dstSlices := make([][]byte, cpus) - nDstByte := nSrcByte * 2 - for ii := range srcSlices { - // Add 63 to prevent false sharing. - newArr := simd.MakeUnsafe(nSrcByte + 63) - for jj := 0; jj < nSrcByte; jj++ { - newArr[jj] = byte(jj * 3) - } - srcSlices[ii] = newArr[:nSrcByte] - newArr = simd.MakeUnsafe(nDstByte + 63) - dstSlices[ii] = newArr[:nDstByte] +func Benchmark_PackedNibbleLookup(b *testing.B) { + funcs := []taggedMultiBenchFunc{ + { + f: packedNibbleLookupUnsafeSubtask, + tag: "Unsafe", + }, + { + f: packedNibbleLookupSimdSubtask, + tag: "SIMD", + }, + { + f: packedNibbleLookupSlowSubtask, + tag: "Slow", + }, } - for i := 0; i < b.N; i++ { - multiInterleave(dstSlices, srcSlices, cpus, nJob) + for _, f := range funcs { + multiBenchmark(f.f, f.tag+"Short", 150, 75, 9999999, b) + multiBenchmark(f.f, f.tag+"Long", 249250621, 249250622/2, 50, b) } } -func Benchmark_InterleaveShort1(b *testing.B) { - benchmarkInterleave(1, 75, 9999999, b) -} - -func Benchmark_InterleaveShort4(b *testing.B) { - benchmarkInterleave(4, 75, 9999999, b) -} - -func Benchmark_InterleaveShortMax(b *testing.B) { - benchmarkInterleave(runtime.NumCPU(), 75, 9999999, b) -} - -func Benchmark_InterleaveLong1(b *testing.B) { - benchmarkInterleave(1, 124625311, 50, b) -} - -func Benchmark_InterleaveLong4(b *testing.B) { - benchmarkInterleave(4, 124625311, 50, b) -} - -func Benchmark_InterleaveLongMax(b *testing.B) { - benchmarkInterleave(runtime.NumCPU(), 124625311, 50, b) -} - func interleaveSlow(dst, even, odd []byte) { dstLen := len(dst) evenLen := (dstLen + 1) >> 1 @@ -650,80 +446,73 @@ func TestInterleave(t *testing.T) { } } -func reverse8Subtask(main []byte, nIter int) int { +/* +Benchmark results: + MacBook Pro (15-inch, 2016) + 2.7 GHz Intel Core i7, 16 GB 2133 MHz LPDDR3 + +Benchmark_Interleave/UnsafeShort1Cpu-8 10 124397567 ns/op +Benchmark_Interleave/UnsafeShortHalfCpu-8 50 33427370 ns/op +Benchmark_Interleave/UnsafeShortAllCpu-8 50 27522495 ns/op 
+Benchmark_Interleave/UnsafeLong1Cpu-8 1 1364788736 ns/op +Benchmark_Interleave/UnsafeLongHalfCpu-8 1 1194034677 ns/op +Benchmark_Interleave/UnsafeLongAllCpu-8 1 1240540994 ns/op +Benchmark_Interleave/SIMDShort1Cpu-8 10 143574503 ns/op +Benchmark_Interleave/SIMDShortHalfCpu-8 30 40429942 ns/op +Benchmark_Interleave/SIMDShortAllCpu-8 50 30500450 ns/op +Benchmark_Interleave/SIMDLong1Cpu-8 1 1281952758 ns/op +Benchmark_Interleave/SIMDLongHalfCpu-8 1 1210134670 ns/op +Benchmark_Interleave/SIMDLongAllCpu-8 1 1284786977 ns/op +Benchmark_Interleave/SlowShort1Cpu-8 2 880545817 ns/op +Benchmark_Interleave/SlowShortHalfCpu-8 5 234673823 ns/op +Benchmark_Interleave/SlowShortAllCpu-8 5 230332535 ns/op +Benchmark_Interleave/SlowLong1Cpu-8 1 6669283712 ns/op +Benchmark_Interleave/SlowLongHalfCpu-8 1 1860713287 ns/op +Benchmark_Interleave/SlowLongAllCpu-8 1 1807886977 ns/op +*/ + +func interleaveUnsafeSubtask(dst, src []byte, nIter int) int { for iter := 0; iter < nIter; iter++ { - simd.Reverse8Inplace(main) + simd.Interleave8Unsafe(dst, src, src) } - return int(main[0]) -} - -func reverse8SubtaskFuture(main []byte, nIter int) chan int { - future := make(chan int) - go func() { future <- reverse8Subtask(main, nIter) }() - return future + return int(dst[0]) } -func multiReverse8(mains [][]byte, cpus int, nJob int) { - sumFutures := make([]chan int, cpus) - shardSizeBase := nJob / cpus - shardRemainder := nJob - shardSizeBase*cpus - shardSizeP1 := shardSizeBase + 1 - var taskIdx int - for ; taskIdx < shardRemainder; taskIdx++ { - sumFutures[taskIdx] = reverse8SubtaskFuture(mains[taskIdx], shardSizeP1) - } - for ; taskIdx < cpus; taskIdx++ { - sumFutures[taskIdx] = reverse8SubtaskFuture(mains[taskIdx], shardSizeBase) - } - var sum int - for taskIdx = 0; taskIdx < cpus; taskIdx++ { - sum += <-sumFutures[taskIdx] +func interleaveSimdSubtask(dst, src []byte, nIter int) int { + for iter := 0; iter < nIter; iter++ { + simd.Interleave8(dst, src, src) } + return int(dst[0]) } -func 
benchmarkReverse8(cpus int, nByte int, nJob int, b *testing.B) { - if cpus > runtime.NumCPU() { - b.Skipf("only have %v cpus", runtime.NumCPU()) +func interleaveSlowSubtask(dst, src []byte, nIter int) int { + for iter := 0; iter < nIter; iter++ { + interleaveSlow(dst, src, src) } + return int(dst[0]) +} - mainSlices := make([][]byte, cpus) - for ii := range mainSlices { - // Add 63 to prevent false sharing. - newArr := simd.MakeUnsafe(nByte + 63) - for jj := 0; jj < nByte; jj++ { - newArr[jj] = byte(jj * 3) - } - mainSlices[ii] = newArr[:nByte] +func Benchmark_Interleave(b *testing.B) { + funcs := []taggedMultiBenchFunc{ + { + f: interleaveUnsafeSubtask, + tag: "Unsafe", + }, + { + f: interleaveSimdSubtask, + tag: "SIMD", + }, + { + f: interleaveSlowSubtask, + tag: "Slow", + }, } - for i := 0; i < b.N; i++ { - multiReverse8(mainSlices, cpus, nJob) + for _, f := range funcs { + multiBenchmark(f.f, f.tag+"Short", 150, 75, 9999999, b) + multiBenchmark(f.f, f.tag+"Long", 124625311*2, 124625311, 50, b) } } -func Benchmark_Reverse8Short1(b *testing.B) { - benchmarkReverse8(1, 75, 9999999, b) -} - -func Benchmark_Reverse8Short4(b *testing.B) { - benchmarkReverse8(4, 75, 9999999, b) -} - -func Benchmark_Reverse8ShortMax(b *testing.B) { - benchmarkReverse8(runtime.NumCPU(), 75, 9999999, b) -} - -func Benchmark_Reverse8Long1(b *testing.B) { - benchmarkReverse8(1, 249250621, 50, b) -} - -func Benchmark_Reverse8Long4(b *testing.B) { - benchmarkReverse8(4, 249250621, 50, b) -} - -func Benchmark_Reverse8LongMax(b *testing.B) { - benchmarkReverse8(runtime.NumCPU(), 249250621, 50, b) -} - func reverse8Slow(main []byte) { nByte := len(main) nByteDiv2 := nByte >> 1 @@ -781,3 +570,240 @@ func TestReverse8(t *testing.T) { } } } + +/* +Benchmark results: + MacBook Pro (15-inch, 2016) + 2.7 GHz Intel Core i7, 16 GB 2133 MHz LPDDR3 + +Benchmark_Reverse8Inplace/SIMDShort1Cpu-8 20 67121510 ns/op +Benchmark_Reverse8Inplace/SIMDShortHalfCpu-8 100 18891965 ns/op 
+Benchmark_Reverse8Inplace/SIMDShortAllCpu-8 100 16177224 ns/op +Benchmark_Reverse8Inplace/SIMDLong1Cpu-8 1 1115497033 ns/op +Benchmark_Reverse8Inplace/SIMDLongHalfCpu-8 2 885764257 ns/op +Benchmark_Reverse8Inplace/SIMDLongAllCpu-8 2 941948715 ns/op +Benchmark_Reverse8Inplace/SlowShort1Cpu-8 3 398662666 ns/op +Benchmark_Reverse8Inplace/SlowShortHalfCpu-8 10 105618119 ns/op +Benchmark_Reverse8Inplace/SlowShortAllCpu-8 10 184808267 ns/op +Benchmark_Reverse8Inplace/SlowLong1Cpu-8 1 5665556658 ns/op +Benchmark_Reverse8Inplace/SlowLongHalfCpu-8 1 1597487158 ns/op +Benchmark_Reverse8Inplace/SlowLongAllCpu-8 1 1616963854 ns/op +*/ + +func reverse8InplaceSimdSubtask(dst, src []byte, nIter int) int { + for iter := 0; iter < nIter; iter++ { + simd.Reverse8Inplace(dst) + } + return int(dst[0]) +} + +func reverse8InplaceSlowSubtask(dst, src []byte, nIter int) int { + for iter := 0; iter < nIter; iter++ { + reverse8Slow(dst) + } + return int(dst[0]) +} + +func Benchmark_Reverse8Inplace(b *testing.B) { + funcs := []taggedMultiBenchFunc{ + { + f: reverse8InplaceSimdSubtask, + tag: "SIMD", + }, + { + f: reverse8InplaceSlowSubtask, + tag: "Slow", + }, + } + for _, f := range funcs { + multiBenchmark(f.f, f.tag+"Short", 75, 0, 9999999, b) + multiBenchmark(f.f, f.tag+"Long", 249250621, 0, 50, b) + } +} + +func bitFromEveryByteSlow(dst, src []byte, bitIdx int) { + requiredDstLen := (len(src) + 7) >> 3 + if (len(dst) < requiredDstLen) || (uint(bitIdx) > 7) { + panic("BitFromEveryByte requires len(dst) >= (len(src) + 7) / 8 and 0 <= bitIdx < 8.") + } + dst = dst[:requiredDstLen] + for i := range dst { + dst[i] = 0 + } + for i, b := range src { + dst[i>>3] |= ((b >> uint32(bitIdx)) & 1) << uint32(i&7) + } +} + +func bitFromEveryByteFancyNoasm(dst, src []byte, bitIdx int) { + requiredDstLen := (len(src) + 7) >> 3 + if (len(dst) < requiredDstLen) || (uint(bitIdx) > 7) { + panic("BitFromEveryByte requires len(dst) >= (len(src) + 7) / 8 and 0 <= bitIdx < 8.") + } + nSrcFullWord := len(src) 
>> 3
+	for i := 0; i < nSrcFullWord; i++ {
+		// Tried using a unsafeBytesToWords function on src in place of
+		// binary.LittleEndian.Uint64, and it barely made any difference.
+		srcWord := binary.LittleEndian.Uint64(src[i*8:i*8+8]) >> uint32(bitIdx)
+
+		srcWord &= 0x101010101010101
+
+		// Before this operation, the bits of interest are at positions 0, 8, 16,
+		// 24, 32, 40, 48, and 56 in srcWord, and all other bits are guaranteed to
+		// be zero.
+		//
+		// Suppose the bit at position 16 is set, and no other bits are set. What
+		// does multiplication by the magic number 0x102040810204080 accomplish?
+		// Well, the magic number has bits set at positions 7, 14, 21, 28, 35, 42,
+		// 49, and 56. Multiplying by 2^16 is equivalent to left-shifting by 16,
+		// so the product has bits set at positions (7+16), (14+16), (21+16),
+		// (28+16), (35+16), (42+16), and the last two overflow off the top end.
+		//
+		// Now suppose the bits at position 0 and 16 are both set. The result is
+		// then the sum of (2^0) * [magic number] + (2^16) * [magic number]. The
+		// first term in this sum has bits set at positions 7, 14, ..., 56.
+		// Critically, *none of these bits overlap with the second term*, so there
+		// are no 'carries' when we add the two terms together. So the final
+		// product has bits set at positions 7, 14, 21, 23, 28, 30, 35, 37, 42, 44,
+		// 49, 51, 56, and 58.
+		//
+		// It turns out that none of the bits in any of the 8 terms of this product
+		// have overlapping positions. So the multiplication operation just makes
+		// a bunch of left-shifted copies of the original bits... 
and in + // particular, bits 56-63 of the product are: + // 56: original bit 0, left-shifted 56 + // 57: original bit 8, left-shifted 49 + // 58: original bit 16, left-shifted 42 + // 59: original bit 24, left-shifted 35 + // 60: original bit 32, left-shifted 28 + // 61: original bit 40, left-shifted 21 + // 62: original bit 48, left-shifted 14 + // 63: original bit 56, left-shifted 7 + // Thus, right-shifting the product by 56 gives us the byte we want. + // + // This is a very esoteric algorithm, and it doesn't have much direct + // application because all 64-bit x86 processors provide an assembly + // instruction which lets you do this >6 times as quickly. Occasionally + // the idea of using multiplication to create staggered left-shifted copies + // of bits does genuinely come in handy, though. + dst[i] = byte((srcWord * 0x102040810204080) >> 56) + } + if nSrcFullWord != requiredDstLen { + srcLast := src[nSrcFullWord*8:] + dstLast := dst[nSrcFullWord:requiredDstLen] + for i := range dstLast { + dstLast[i] = 0 + } + for i, b := range srcLast { + dstLast[i>>3] |= ((b >> uint32(bitIdx)) & 1) << uint32(i&7) + } + } +} + +func TestBitFromEveryByte(t *testing.T) { + maxSize := 500 + nIter := 200 + rand.Seed(1) + srcArr := make([]byte, maxSize) + dstArr1 := make([]byte, maxSize) + dstArr2 := make([]byte, maxSize) + dstArr3 := make([]byte, maxSize) + for iter := 0; iter < nIter; iter++ { + sliceStart := rand.Intn(maxSize) + srcSize := rand.Intn(maxSize - sliceStart) + srcSliceEnd := sliceStart + srcSize + srcSlice := srcArr[sliceStart:srcSliceEnd] + + minDstSize := (srcSize + 7) >> 3 + dstSliceEnd := sliceStart + minDstSize + dstSlice1 := dstArr1[sliceStart:dstSliceEnd] + dstSlice2 := dstArr2[sliceStart:dstSliceEnd] + dstSlice3 := dstArr3[sliceStart:dstSliceEnd] + + for ii := range srcSlice { + srcSlice[ii] = byte(rand.Intn(256)) + } + sentinel := byte(rand.Intn(256)) + dstArr2[dstSliceEnd] = sentinel + + bitIdx := rand.Intn(8) + bitFromEveryByteSlow(dstSlice1, 
srcSlice, bitIdx) + simd.BitFromEveryByte(dstSlice2, srcSlice, bitIdx) + assert.EQ(t, dstSlice1, dstSlice2) + assert.EQ(t, sentinel, dstArr2[dstSliceEnd]) + + // Also validate the assembly-free multiplication-based algorithm. + sentinel = byte(rand.Intn(256)) + dstArr3[dstSliceEnd] = sentinel + bitFromEveryByteFancyNoasm(dstSlice3, srcSlice, bitIdx) + assert.EQ(t, dstSlice1, dstSlice3) + assert.EQ(t, sentinel, dstArr3[dstSliceEnd]) + } +} + +/* +Benchmark results: + MacBook Pro (15-inch, 2016) + 2.7 GHz Intel Core i7, 16 GB 2133 MHz LPDDR3 + +Benchmark_BitFromEveryByte/SIMDLong1Cpu-8 200 6861450 ns/op +Benchmark_BitFromEveryByte/SIMDLongHalfCpu-8 200 7360937 ns/op +Benchmark_BitFromEveryByte/SIMDLongAllCpu-8 200 8846261 ns/op +Benchmark_BitFromEveryByte/FancyNoasmLong1Cpu-8 20 58756902 ns/op +Benchmark_BitFromEveryByte/FancyNoasmLongHalfCpu-8 100 17244847 ns/op +Benchmark_BitFromEveryByte/FancyNoasmLongAllCpu-8 100 16624282 ns/op +Benchmark_BitFromEveryByte/SlowLong1Cpu-8 3 422073091 ns/op +Benchmark_BitFromEveryByte/SlowLongHalfCpu-8 10 117732813 ns/op +Benchmark_BitFromEveryByte/SlowLongAllCpu-8 10 114903556 ns/op + +Notes: 1Cpu has higher throughput than HalfCpu/AllCpu on this test machine due +to L3 cache saturation: multiBenchmarkDstSrc makes each goroutine process its +own ~4 MB job, rather than splitting a single job into smaller pieces, and a +15-inch 2016 MacBook Pro has a 8 MB L3 cache. If you shrink the test size to +len(src)=400000, HalfCpu outperforms 1Cpu by the expected amount. + +I'm leaving this unusual benchmark result here since (i) it corresponds to how +we actually need to use the function, and (ii) this phenomenon is definitely +worth knowing about. 
+*/ + +func bitFromEveryByteSimdSubtask(dst, src []byte, nIter int) int { + for iter := 0; iter < nIter; iter++ { + simd.BitFromEveryByte(dst, src, 0) + } + return int(dst[0]) +} + +func bitFromEveryByteFancyNoasmSubtask(dst, src []byte, nIter int) int { + for iter := 0; iter < nIter; iter++ { + bitFromEveryByteFancyNoasm(dst, src, 0) + } + return int(dst[0]) +} + +func bitFromEveryByteSlowSubtask(dst, src []byte, nIter int) int { + for iter := 0; iter < nIter; iter++ { + bitFromEveryByteSlow(dst, src, 0) + } + return int(dst[0]) +} + +func Benchmark_BitFromEveryByte(b *testing.B) { + funcs := []taggedMultiBenchFunc{ + { + f: bitFromEveryByteSimdSubtask, + tag: "SIMD", + }, + { + f: bitFromEveryByteFancyNoasmSubtask, + tag: "FancyNoasm", + }, + { + f: bitFromEveryByteSlowSubtask, + tag: "Slow", + }, + } + for _, f := range funcs { + multiBenchmark(f.f, f.tag+"Long", 4091904/8, 4091904, 50, b) + } +} diff --git a/simd/xor_amd64.go b/simd/xor_amd64.go index c1a6d8d4..f09218d5 100644 --- a/simd/xor_amd64.go +++ b/simd/xor_amd64.go @@ -1,8 +1,10 @@ -// Code generated from " ../gtl/generate.py --prefix=Xor -DOPCHAR=^ --package=simd --output=xor_amd64.go bitwise_amd64.go.tpl ". DO NOT EDIT. -// Copyright 2018 GRAIL, Inc. All rights reserved. +// Code generated by "../gtl/generate.py --prefix=Xor -DOPCHAR=^ --package=simd --output=xor_amd64.go bitwise_amd64.go.tpl". DO NOT EDIT. + +// Copyright 2021 GRAIL, Inc. All rights reserved. // Use of this source code is governed by the Apache-2.0 // license that can be found in the LICENSE file. +//go:build amd64 && !appengine // +build amd64,!appengine package simd @@ -12,7 +14,7 @@ import ( "unsafe" ) -// XorUnsafeInplace sets main[pos] := arg[pos] ^ main[pos] for every position +// XorUnsafeInplace sets main[pos] := main[pos] ^ arg[pos] for every position // in main[]. // // WARNING: This is a function designed to be used in inner loops, which makes @@ -30,18 +32,18 @@ import ( // changed. 
func XorUnsafeInplace(main, arg []byte) { mainLen := len(main) - argHeader := (*reflect.SliceHeader)(unsafe.Pointer(&arg)) - mainHeader := (*reflect.SliceHeader)(unsafe.Pointer(&main)) - argWordsIter := unsafe.Pointer(argHeader.Data) - mainWordsIter := unsafe.Pointer(mainHeader.Data) + argData := unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&arg)).Data) + mainData := unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&main)).Data) + argWordsIter := argData + mainWordsIter := mainData if mainLen > 2*BytesPerWord { nWordMinus2 := (mainLen - BytesPerWord - 1) >> Log2BytesPerWord for widx := 0; widx < nWordMinus2; widx++ { mainWord := *((*uintptr)(mainWordsIter)) argWord := *((*uintptr)(argWordsIter)) *((*uintptr)(mainWordsIter)) = mainWord ^ argWord - mainWordsIter = unsafe.Pointer(uintptr(mainWordsIter) + BytesPerWord) - argWordsIter = unsafe.Pointer(uintptr(argWordsIter) + BytesPerWord) + mainWordsIter = unsafe.Add(mainWordsIter, BytesPerWord) + argWordsIter = unsafe.Add(argWordsIter, BytesPerWord) } } else if mainLen <= BytesPerWord { mainWord := *((*uintptr)(mainWordsIter)) @@ -56,8 +58,8 @@ func XorUnsafeInplace(main, arg []byte) { mainWord1 := *((*uintptr)(mainWordsIter)) argWord1 := *((*uintptr)(argWordsIter)) finalOffset := uintptr(mainLen - BytesPerWord) - mainFinalWordPtr := unsafe.Pointer(mainHeader.Data + finalOffset) - argFinalWordPtr := unsafe.Pointer(argHeader.Data + finalOffset) + mainFinalWordPtr := unsafe.Add(mainData, finalOffset) + argFinalWordPtr := unsafe.Add(argData, finalOffset) mainWord2 := *((*uintptr)(mainFinalWordPtr)) argWord2 := *((*uintptr)(argFinalWordPtr)) *((*uintptr)(mainWordsIter)) = mainWord1 ^ argWord1 @@ -82,25 +84,25 @@ func XorInplace(main, arg []byte) { } return } - argHeader := (*reflect.SliceHeader)(unsafe.Pointer(&arg)) - mainHeader := (*reflect.SliceHeader)(unsafe.Pointer(&main)) - argWordsIter := unsafe.Pointer(argHeader.Data) - mainWordsIter := unsafe.Pointer(mainHeader.Data) + argData := 
unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&arg)).Data) + mainData := unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&main)).Data) + argWordsIter := argData + mainWordsIter := mainData if mainLen > 2*BytesPerWord { nWordMinus2 := (mainLen - BytesPerWord - 1) >> Log2BytesPerWord for widx := 0; widx < nWordMinus2; widx++ { mainWord := *((*uintptr)(mainWordsIter)) argWord := *((*uintptr)(argWordsIter)) *((*uintptr)(mainWordsIter)) = mainWord ^ argWord - mainWordsIter = unsafe.Pointer(uintptr(mainWordsIter) + BytesPerWord) - argWordsIter = unsafe.Pointer(uintptr(argWordsIter) + BytesPerWord) + mainWordsIter = unsafe.Add(mainWordsIter, BytesPerWord) + argWordsIter = unsafe.Add(argWordsIter, BytesPerWord) } } mainWord1 := *((*uintptr)(mainWordsIter)) argWord1 := *((*uintptr)(argWordsIter)) finalOffset := uintptr(mainLen - BytesPerWord) - mainFinalWordPtr := unsafe.Pointer(mainHeader.Data + finalOffset) - argFinalWordPtr := unsafe.Pointer(argHeader.Data + finalOffset) + mainFinalWordPtr := unsafe.Add(mainData, finalOffset) + argFinalWordPtr := unsafe.Add(argData, finalOffset) mainWord2 := *((*uintptr)(mainFinalWordPtr)) argWord2 := *((*uintptr)(argFinalWordPtr)) *((*uintptr)(mainWordsIter)) = mainWord1 ^ argWord1 @@ -135,9 +137,9 @@ func XorUnsafe(dst, src1, src2 []byte) { src1Word := *((*uintptr)(src1Iter)) src2Word := *((*uintptr)(src2Iter)) *((*uintptr)(dstIter)) = src1Word ^ src2Word - src1Iter = unsafe.Pointer(uintptr(src1Iter) + BytesPerWord) - src2Iter = unsafe.Pointer(uintptr(src2Iter) + BytesPerWord) - dstIter = unsafe.Pointer(uintptr(dstIter) + BytesPerWord) + src1Iter = unsafe.Add(src1Iter, BytesPerWord) + src2Iter = unsafe.Add(src2Iter, BytesPerWord) + dstIter = unsafe.Add(dstIter, BytesPerWord) } } @@ -154,27 +156,27 @@ func Xor(dst, src1, src2 []byte) { } return } - src1Header := (*reflect.SliceHeader)(unsafe.Pointer(&src1)) - src2Header := (*reflect.SliceHeader)(unsafe.Pointer(&src2)) - dstHeader := 
(*reflect.SliceHeader)(unsafe.Pointer(&dst)) + src1Data := unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&src1)).Data) + src2Data := unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&src2)).Data) + dstData := unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&dst)).Data) nWordMinus1 := (dstLen - 1) >> Log2BytesPerWord - src1Iter := unsafe.Pointer(src1Header.Data) - src2Iter := unsafe.Pointer(src2Header.Data) - dstIter := unsafe.Pointer(dstHeader.Data) + src1Iter := src1Data + src2Iter := src2Data + dstIter := dstData for widx := 0; widx < nWordMinus1; widx++ { src1Word := *((*uintptr)(src1Iter)) src2Word := *((*uintptr)(src2Iter)) *((*uintptr)(dstIter)) = src1Word ^ src2Word - src1Iter = unsafe.Pointer(uintptr(src1Iter) + BytesPerWord) - src2Iter = unsafe.Pointer(uintptr(src2Iter) + BytesPerWord) - dstIter = unsafe.Pointer(uintptr(dstIter) + BytesPerWord) + src1Iter = unsafe.Add(src1Iter, BytesPerWord) + src2Iter = unsafe.Add(src2Iter, BytesPerWord) + dstIter = unsafe.Add(dstIter, BytesPerWord) } // No store-forwarding problem here. 
finalOffset := uintptr(dstLen - BytesPerWord) - src1Iter = unsafe.Pointer(src1Header.Data + finalOffset) - src2Iter = unsafe.Pointer(src2Header.Data + finalOffset) - dstIter = unsafe.Pointer(dstHeader.Data + finalOffset) + src1Iter = unsafe.Add(src1Data, finalOffset) + src2Iter = unsafe.Add(src2Data, finalOffset) + dstIter = unsafe.Add(dstData, finalOffset) src1Word := *((*uintptr)(src1Iter)) src2Word := *((*uintptr)(src2Iter)) *((*uintptr)(dstIter)) = src1Word ^ src2Word @@ -197,14 +199,14 @@ func Xor(dst, src1, src2 []byte) { func XorConst8UnsafeInplace(main []byte, val byte) { mainLen := len(main) argWord := 0x101010101010101 * uintptr(val) - mainHeader := (*reflect.SliceHeader)(unsafe.Pointer(&main)) - mainWordsIter := unsafe.Pointer(mainHeader.Data) + mainData := unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&main)).Data) + mainWordsIter := mainData if mainLen > 2*BytesPerWord { nWordMinus2 := (mainLen - BytesPerWord - 1) >> Log2BytesPerWord for widx := 0; widx < nWordMinus2; widx++ { mainWord := *((*uintptr)(mainWordsIter)) *((*uintptr)(mainWordsIter)) = mainWord ^ argWord - mainWordsIter = unsafe.Pointer(uintptr(mainWordsIter) + BytesPerWord) + mainWordsIter = unsafe.Add(mainWordsIter, BytesPerWord) } } else if mainLen <= BytesPerWord { mainWord := *((*uintptr)(mainWordsIter)) @@ -213,7 +215,7 @@ func XorConst8UnsafeInplace(main []byte, val byte) { } mainWord1 := *((*uintptr)(mainWordsIter)) finalOffset := uintptr(mainLen - BytesPerWord) - mainFinalWordPtr := unsafe.Pointer(mainHeader.Data + finalOffset) + mainFinalWordPtr := unsafe.Add(mainData, finalOffset) mainWord2 := *((*uintptr)(mainFinalWordPtr)) *((*uintptr)(mainWordsIter)) = mainWord1 ^ argWord *((*uintptr)(mainFinalWordPtr)) = mainWord2 ^ argWord @@ -230,19 +232,19 @@ func XorConst8Inplace(main []byte, val byte) { return } argWord := 0x101010101010101 * uintptr(val) - mainHeader := (*reflect.SliceHeader)(unsafe.Pointer(&main)) - mainWordsIter := unsafe.Pointer(mainHeader.Data) + mainData := 
unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&main)).Data) + mainWordsIter := mainData if mainLen > 2*BytesPerWord { nWordMinus2 := (mainLen - BytesPerWord - 1) >> Log2BytesPerWord for widx := 0; widx < nWordMinus2; widx++ { mainWord := *((*uintptr)(mainWordsIter)) *((*uintptr)(mainWordsIter)) = mainWord ^ argWord - mainWordsIter = unsafe.Pointer(uintptr(mainWordsIter) + BytesPerWord) + mainWordsIter = unsafe.Add(mainWordsIter, BytesPerWord) } } mainWord1 := *((*uintptr)(mainWordsIter)) finalOffset := uintptr(mainLen - BytesPerWord) - mainFinalWordPtr := unsafe.Pointer(mainHeader.Data + finalOffset) + mainFinalWordPtr := unsafe.Add(mainData, finalOffset) mainWord2 := *((*uintptr)(mainFinalWordPtr)) *((*uintptr)(mainWordsIter)) = mainWord1 ^ argWord *((*uintptr)(mainFinalWordPtr)) = mainWord2 ^ argWord @@ -274,8 +276,8 @@ func XorConst8Unsafe(dst, src []byte, val byte) { for widx := 0; widx < nWord; widx++ { srcWord := *((*uintptr)(srcIter)) *((*uintptr)(dstIter)) = srcWord ^ argWord - srcIter = unsafe.Pointer(uintptr(srcIter) + BytesPerWord) - dstIter = unsafe.Pointer(uintptr(dstIter) + BytesPerWord) + srcIter = unsafe.Add(srcIter, BytesPerWord) + dstIter = unsafe.Add(dstIter, BytesPerWord) } } @@ -292,22 +294,22 @@ func XorConst8(dst, src []byte, val byte) { } return } - srcHeader := (*reflect.SliceHeader)(unsafe.Pointer(&src)) - dstHeader := (*reflect.SliceHeader)(unsafe.Pointer(&dst)) + srcData := unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&src)).Data) + dstData := unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&dst)).Data) nWordMinus1 := (dstLen - 1) >> Log2BytesPerWord argWord := 0x101010101010101 * uintptr(val) - srcIter := unsafe.Pointer(srcHeader.Data) - dstIter := unsafe.Pointer(dstHeader.Data) + srcIter := unsafe.Pointer(srcData) + dstIter := unsafe.Pointer(dstData) for widx := 0; widx < nWordMinus1; widx++ { srcWord := *((*uintptr)(srcIter)) *((*uintptr)(dstIter)) = srcWord ^ argWord - srcIter = unsafe.Pointer(uintptr(srcIter) + 
BytesPerWord) - dstIter = unsafe.Pointer(uintptr(dstIter) + BytesPerWord) + srcIter = unsafe.Add(srcIter, BytesPerWord) + dstIter = unsafe.Add(dstIter, BytesPerWord) } finalOffset := uintptr(dstLen - BytesPerWord) - srcIter = unsafe.Pointer(srcHeader.Data + finalOffset) - dstIter = unsafe.Pointer(dstHeader.Data + finalOffset) + srcIter = unsafe.Add(srcData, finalOffset) + dstIter = unsafe.Add(dstData, finalOffset) srcWord := *((*uintptr)(srcIter)) *((*uintptr)(dstIter)) = srcWord ^ argWord } diff --git a/simd/xor_generic.go b/simd/xor_generic.go new file mode 100644 index 00000000..96466a32 --- /dev/null +++ b/simd/xor_generic.go @@ -0,0 +1,135 @@ +// Code generated by " ../gtl/generate.py --prefix=Xor -DOPCHAR=^ --package=simd --output=xor_generic.go bitwise_generic.go.tpl ". DO NOT EDIT. + +// Copyright 2018 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache-2.0 +// license that can be found in the LICENSE file. + +// +build !amd64 appengine + +package simd + +// XorUnsafeInplace sets main[pos] := main[pos] ^ arg[pos] for every position +// in main[]. +// +// WARNING: This is a function designed to be used in inner loops, which makes +// assumptions about length and capacity which aren't checked at runtime. Use +// the safe version of this function when that's a problem. +// Assumptions #2-3 are always satisfied when the last +// potentially-size-increasing operation on arg[] is {Re}makeUnsafe(), +// ResizeUnsafe(), or XcapUnsafe(), and the same is true for main[]. +// +// 1. len(arg) and len(main) must be equal. +// +// 2. Capacities are at least RoundUpPow2(len(main) + 1, bytesPerVec). +// +// 3. The caller does not care if a few bytes past the end of main[] are +// changed. +func XorUnsafeInplace(main, arg []byte) { + for i, x := range main { + main[i] = x ^ arg[i] + } +} + +// XorInplace sets main[pos] := main[pos] ^ arg[pos] for every position in +// main[]. It panics if slice lengths don't match. 
+func XorInplace(main, arg []byte) { + if len(arg) != len(main) { + panic("XorInplace() requires len(arg) == len(main).") + } + for i, x := range main { + main[i] = x ^ arg[i] + } +} + +// XorUnsafe sets dst[pos] := src1[pos] ^ src2[pos] for every position in dst. +// +// WARNING: This is a function designed to be used in inner loops, which makes +// assumptions about length and capacity which aren't checked at runtime. Use +// the safe version of this function when that's a problem. +// Assumptions #2-3 are always satisfied when the last +// potentially-size-increasing operation on src1[] is {Re}makeUnsafe(), +// ResizeUnsafe(), or XcapUnsafe(), and the same is true for src2[] and dst[]. +// +// 1. len(src1), len(src2), and len(dst) must be equal. +// +// 2. Capacities are at least RoundUpPow2(len(dst) + 1, bytesPerVec). +// +// 3. The caller does not care if a few bytes past the end of dst[] are +// changed. +func XorUnsafe(dst, src1, src2 []byte) { + for i, x := range src1 { + dst[i] = x ^ src2[i] + } +} + +// Xor sets dst[pos] := src1[pos] ^ src2[pos] for every position in dst. It +// panics if slice lengths don't match. +func Xor(dst, src1, src2 []byte) { + dstLen := len(dst) + if (len(src1) != dstLen) || (len(src2) != dstLen) { + panic("Xor() requires len(src1) == len(src2) == len(dst).") + } + for i, x := range src1 { + dst[i] = x ^ src2[i] + } +} + +// XorConst8UnsafeInplace sets main[pos] := main[pos] ^ val for every position +// in main[]. +// +// WARNING: This is a function designed to be used in inner loops, which makes +// assumptions about length and capacity which aren't checked at runtime. Use +// the safe version of this function when that's a problem. +// These assumptions are always satisfied when the last +// potentially-size-increasing operation on main[] is {Re}makeUnsafe(), +// ResizeUnsafe(), or XcapUnsafe(). +// +// 1. cap(main) is at least RoundUpPow2(len(main) + 1, bytesPerVec). +// +// 2. 
The caller does not care if a few bytes past the end of main[] are +// changed. +func XorConst8UnsafeInplace(main []byte, val byte) { + for i, x := range main { + main[i] = x ^ val + } +} + +// XorConst8Inplace sets main[pos] := main[pos] ^ val for every position in +// main[]. +func XorConst8Inplace(main []byte, val byte) { + for i, x := range main { + main[i] = x ^ val + } +} + +// XorConst8Unsafe sets dst[pos] := src[pos] ^ val for every position in dst. +// +// WARNING: This is a function designed to be used in inner loops, which makes +// assumptions about length and capacity which aren't checked at runtime. Use +// the safe version of this function when that's a problem. +// Assumptions #2-3 are always satisfied when the last +// potentially-size-increasing operation on src[] is {Re}makeUnsafe(), +// ResizeUnsafe(), or XcapUnsafe(), and the same is true for dst[]. +// +// 1. len(src) and len(dst) must be equal. +// +// 2. Capacities are at least RoundUpPow2(len(dst) + 1, bytesPerVec). +// +// 3. The caller does not care if a few bytes past the end of dst[] are +// changed. +func XorConst8Unsafe(dst, src []byte, val byte) { + for i, x := range src { + dst[i] = x ^ val + } +} + +// XorConst8 sets dst[pos] := src[pos] ^ val for every position in dst. It +// panics if slice lengths don't match. +func XorConst8(dst, src []byte, val byte) { + if len(src) != len(dst) { + panic("XorConst8() requires len(src) == len(dst).") + } + for i, x := range src { + dst[i] = x ^ val + } +} diff --git a/stateio/reader.go b/stateio/reader.go new file mode 100644 index 00000000..9e03ab00 --- /dev/null +++ b/stateio/reader.go @@ -0,0 +1,133 @@ +// Copyright 2019 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. + +package stateio + +import ( + "errors" + "io" + "os" + + "github.com/grailbio/base/logio" +) + +// ErrCorrupt is returned when a corrupted log is encountered. 
+var ErrCorrupt = errors.New("corrupt state entry") + +// Restore restores the state from the last epoch in the state log +// read by the provided reader and the given limit. The returned +// state may be nil if no snapshot was defined for the epoch. +func Restore(r io.ReaderAt, limit int64) (state []byte, epoch uint64, updates *Reader, err error) { + if limit == 0 { + return nil, 0, nil, nil + } + off, err := logio.Rewind(r, limit) + if err != nil { + return + } + reader := &readerAtReader{r, off} + log := logio.NewReader(reader, off) + entry, err := log.Read() + if err != nil { + return + } + var ( + typ uint8 + data []byte + ok bool + ) + typ, epoch, data, ok = parse(entry) + if !ok { + // TODO(marius): let the user deal with this? perhaps by providing + // a utility function in package logio to skip corrupted entries. + err = ErrCorrupt + return + } + if typ == entrySnap { + // Special case: the first entry is a snapshot, so we need to restore + // the correct epoch. + epoch = uint64(off) + } else { + reader.off = int64(epoch) + log.Reset(reader, reader.off) + entry, err = log.Read() + if err != nil { + return + } + typ, _, data, ok = parse(entry) + if !ok { + err = ErrCorrupt + return + } + } + + if typ == entrySnap { + state = append([]byte{}, data...) + } else { + reader.off = int64(epoch) + log.Reset(reader, reader.off) + } + updates = &Reader{log, epoch} + return +} + +// RestoreFile is a convenience function that restores the file from +// the provided os file. +func RestoreFile(file *os.File) (state []byte, epoch uint64, updates *Reader, err error) { + off, err := file.Seek(0, io.SeekEnd) + if err != nil { + return nil, 0, nil, err + } + state, epoch, updates, err = Restore(file, off) + if _, e := file.Seek(off, io.SeekStart); e != nil && err == nil { + err = e + } + return +} + +// Reader reads a single epoch state updates. +type Reader struct { + log *logio.Reader + offset uint64 +} + +// Read returns the next state update entry. 
Read returns ErrCorrupt +// if a corrupted log entry was encountered, or logio.ErrCorrupt if a +// corrupt log file was encountered. In the latter case, the user may +// skip the corrupted entry by issuing another read. +func (r *Reader) Read() ([]byte, error) { + if r == nil { + return nil, io.EOF + } + entry, err := r.log.Read() + if err != nil { + return nil, err + } + typ, offset, data, ok := parse(entry) + if !ok { + return nil, ErrCorrupt + } + if typ == entrySnap { + return nil, io.EOF + } + if offset != r.offset { + // We should always encounter a new snapshot before an offset change. + return nil, ErrCorrupt + } + return data, nil +} + +type readerAtReader struct { + r io.ReaderAt + off int64 +} + +func (r *readerAtReader) Read(p []byte) (n int, err error) { + n, err = r.r.ReadAt(p, r.off) + if err == io.ErrUnexpectedEOF { + err = nil + } + r.off += int64(n) + return n, err +} diff --git a/stateio/stateio.go b/stateio/stateio.go new file mode 100644 index 00000000..b783d193 --- /dev/null +++ b/stateio/stateio.go @@ -0,0 +1,67 @@ +// Copyright 2019 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. + +// Package stateio implements a persistent state mechanism based on log +// files that interleave indexed state snapshots with state updates. Users maintain +// state by interaction with Reader and Writer objects. A typical application +// should reconcile to the current state before writing new log entries. New +// log entries should be written only after they are known to apply cleanly. 
+// (In the following examples, error handling is left as an exercise to the +// reader): +// +// file, err := os.OpenFile(filename, os.O_RDWR|os.O_CREATE) +// state, epoch, updates, err := stateio.RestoreFile(file) +// application.Reset(state) +// for { +// entry, err := updates.Read() +// if err == io.EOF { +// break +// } +// application.Update(entry) +// } +// +// w, err := stateio.NewFileWriter(file) +// +// // Apply new state updates: +// var update []byte = ... +// application.Update(update) +// err := w.Update(update) +// +// // Produce a new snapshot: +// var snap []byte = application.Snapshot() +// w.Snapshot(snap) +// +// Data format +// +// State files maintained by package stateio build on package logio. +// The log file contains a sequence of epochs, each beginning with a +// state snapshot (with the exception of the first epoch, which does +// not need a state snapshot). Each entry is prefixed with the type +// of the entry as well as the epoch to which the entry belongs. +// Snapshot entries are prefixed with the previous epoch, so that +// log files can be efficiently rewound. +// +// TODO(marius): support log file truncation +package stateio + +import "encoding/binary" + +const ( + entryUpdate = 1 + iota + entrySnap + + entryMax +) + +func parse(entry []byte) (typ uint8, epoch uint64, data []byte, ok bool) { + if len(entry) < 9 { + ok = false + return + } + typ = entry[0] + epoch = binary.LittleEndian.Uint64(entry[1:]) + data = entry[9:] + ok = typ < entryMax + return +} diff --git a/stateio/stateio_test.go b/stateio/stateio_test.go new file mode 100644 index 00000000..0b49a08e --- /dev/null +++ b/stateio/stateio_test.go @@ -0,0 +1,120 @@ +// Copyright 2019 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. 
+ +package stateio + +import ( + "io" + "io/ioutil" + "math/rand" + "os" + "testing" +) + +func TestStateIO(t *testing.T) { + file, cleanup := tempfile(t) + defer cleanup() + + w := NewWriter(file, 0, 0) + must(t, w.Update(entry(1))) + must(t, w.Update(entry(2))) + + state, _, updates, err := RestoreFile(file) + must(t, err) + if state != nil { + t.Fatal("non-nil snapshot") + } + b, err := updates.Read() + must(t, err) + mustEntry(t, 1, b) + b, err = updates.Read() + must(t, err) + mustEntry(t, 2, b) + mustEOF(t, updates) + + must(t, file.Truncate(0)) + state, _, updates, err = Restore(file, 0) + if state != nil { + t.Fatal("non-nil snapshot") + } + must(t, err) + mustEOF(t, updates) + + const N = 100 + w, err = NewFileWriter(file) + must(t, err) + for i := 0; i < N; i++ { + if i%10 == 0 { + must(t, w.Snapshot(entry(i))) + } else { + must(t, w.Update(entry(i))) + } + if i%5 == 0 { + // Reset the writer to make sure that writers + // resume properly. + must(t, err) + w, err = NewFileWriter(file) + must(t, err) + } + } + + state, _, updates, err = RestoreFile(file) + must(t, err) + mustEntry(t, N-10, state) + for i := N - 9; i < N; i++ { + entry, err := updates.Read() + must(t, err) + mustEntry(t, i, entry) + } + mustEOF(t, updates) +} + +func entry(n int) []byte { + b := make([]byte, n) + r := rand.New(rand.NewSource(int64(n))) + for i := range b { + b[i] = byte(r.Intn(256)) + } + return b +} + +func mustEntry(t *testing.T, n int, b []byte) { + t.Helper() + if got, want := len(b), n; got != want { + t.Fatalf("got %v, want %v", got, want) + } + r := rand.New(rand.NewSource(int64(n))) + for i := range b { + if got, want := int(b[i]), r.Intn(256); got != want { + t.Fatalf("byte %d: got %v, want %v", i, got, want) + } + } +} + +func must(t *testing.T, err error) { + t.Helper() + if err != nil { + t.Fatal(err) + } +} + +func mustEOF(t *testing.T, r *Reader) { + t.Helper() + _, err := r.Read() + if got, want := err, io.EOF; got != want { + t.Fatalf("got %v, want %v", got, 
want) + } +} + +func tempfile(t *testing.T) (file *os.File, cleanup func()) { + t.Helper() + var err error + file, err = ioutil.TempFile("", "") + if err != nil { + t.Fatal(err) + } + os.Remove(file.Name()) + cleanup = func() { file.Close() } + return +} diff --git a/stateio/writer.go b/stateio/writer.go new file mode 100644 index 00000000..ba862741 --- /dev/null +++ b/stateio/writer.go @@ -0,0 +1,97 @@ +// Copyright 2019 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. + +package stateio + +import ( + "encoding/binary" + "io" + "os" + + "github.com/grailbio/base/logio" +) + +type syncer interface { + Sync() error +} + +// Writer writes snapshots and update entries to an underlying log stream. +type Writer struct { + syncer syncer + log *logio.Writer + epoch uint64 +} + +// NewWriter initializes and returns a new state log writer which +// writes to the stream w, which is positioned at the provided +// offset. The provided epoch must be the current epoch of the log +// file. If the provided io.Writer is also a syncer: +// +// type Syncer interface { +// Sync() error +// } +// +// Then Sync() is called (and errors returned) after each log entry +// has been written. +func NewWriter(w io.Writer, off int64, epoch uint64) *Writer { + wr := &Writer{log: logio.NewWriter(w, off), epoch: epoch} + if s, ok := w.(syncer); ok { + wr.syncer = s + } + return wr +} + +// NewFileWriter initializes a state log writer from the provided +// os file. The file's contents is committed to stable storage +// after each log write. 
+func NewFileWriter(file *os.File) (*Writer, error) { + off, err := file.Seek(0, io.SeekEnd) + if err != nil { + return nil, err + } + _, epoch, _, err := Restore(file, off) + if err != nil { + return nil, err + } + off, err = file.Seek(0, io.SeekEnd) + if err != nil { + return nil, err + } + return NewWriter(file, off, epoch), nil +} + +// Snapshot writes a new snapshot to the state log. +// Subsequent updates are based on this snapshot. +func (w *Writer) Snapshot(snap []byte) error { + off := w.log.Tell() + entry := make([]byte, len(snap)+9) + entry[0] = entrySnap + binary.LittleEndian.PutUint64(entry[1:], w.epoch) + copy(entry[9:], snap) + if err := w.log.Append(entry); err != nil { + return err + } + w.epoch = uint64(off) + return w.sync() +} + +// Update writes a new state update to the log. The update +// refers to the last snapshot written. +func (w *Writer) Update(update []byte) error { + entry := make([]byte, 9+len(update)) + entry[0] = entryUpdate + binary.LittleEndian.PutUint64(entry[1:], uint64(w.epoch)) + copy(entry[9:], update) + if err := w.log.Append(entry); err != nil { + return err + } + return w.sync() +} + +func (w *Writer) sync() error { + if w.syncer == nil { + return nil + } + return w.syncer.Sync() +} diff --git a/status/http.go b/status/http.go index 22389088..7e68f3ce 100644 --- a/status/http.go +++ b/status/http.go @@ -5,10 +5,7 @@ package status import ( - "fmt" "net/http" - "text/tabwriter" - "time" ) type statusHandler struct{ *Status } @@ -21,17 +18,6 @@ func Handler(s *Status) http.Handler { func (h statusHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/plain") - now := time.Now() - for _, group := range h.Status.Groups() { - v := group.Value() - tw := tabwriter.NewWriter(w, 2, 4, 2, ' ', 0) - fmt.Fprintf(tw, "%s: %s\n", v.Title, v.Status) - for _, task := range group.Tasks() { - v := task.Value() - elapsed := now.Sub(v.Begin) - elapsed -= elapsed % time.Second - fmt.Fprintf(tw, 
"\t%s:\t%s\t%s\n", v.Title, v.Status, elapsed) - } - tw.Flush() - } + // If writing fails, there's not much we can do. + _ = h.Status.Marshal(w) } diff --git a/status/status.go b/status/status.go index 51e35cdb..65799e77 100644 --- a/status/status.go +++ b/status/status.go @@ -17,8 +17,10 @@ package status import ( "fmt" + "io" "sort" "sync" + "text/tabwriter" "time" ) @@ -301,6 +303,30 @@ func (s *Status) Groups() []*Group { return groups } +// Marshal writes s in a human-readable format to w. +func (s *Status) Marshal(w io.Writer) error { + now := time.Now() + for _, group := range s.Groups() { + v := group.Value() + tw := tabwriter.NewWriter(w, 2, 4, 2, ' ', 0) + if _, err := fmt.Fprintf(tw, "%s: %s\n", v.Title, v.Status); err != nil { + return err + } + for _, task := range group.Tasks() { + v := task.Value() + elapsed := now.Sub(v.Begin) + elapsed -= elapsed % time.Second + if _, err := fmt.Fprintf(tw, "\t%s:\t%s\t%s\n", v.Title, v.Status, elapsed); err != nil { + return err + } + } + if err := tw.Flush(); err != nil { + return err + } + } + return nil +} + func (s *Status) notify() { if s == nil { return diff --git a/status/stream.go b/status/stream.go index 5e832f48..2aa3f3ae 100644 --- a/status/stream.go +++ b/status/stream.go @@ -300,19 +300,9 @@ func (r Reporter) displaySimple(w io.Writer, status *Status) { nextReport = time.After(minSimpleReportingPeriod - elapsed) continue } - now := time.Now() - for _, group := range status.Groups() { - v := group.Value() - tw := tabwriter.NewWriter(w, 2, 4, 2, ' ', 0) - fmt.Fprintf(tw, "%s: %s\n", v.Title, v.Status) - for _, task := range group.Tasks() { - v := task.Value() - elapsed := now.Sub(v.Begin) - elapsed -= elapsed % time.Second - fmt.Fprintf(tw, "\t%s:\t%s\t%s\n", v.Title, v.Status, elapsed) - } - tw.Flush() - } + // If writing fails, there's not much we can do besides try again next + // time. 
+ _ = status.Marshal(w) lastReport = time.Now() } } diff --git a/status/term.go b/status/term.go index 66daff5a..de73f243 100644 --- a/status/term.go +++ b/status/term.go @@ -109,11 +109,11 @@ func (t *term) Clear(w io.Writer) { // Dim returns the current dimensions of the terminal. func (t *term) Dim() (width, height int) { - ws, err := getWinsize(t.fd) + ws, err := unix.IoctlGetWinsize(int(t.fd), unix.TIOCGWINSZ) if err != nil { return 80, 20 } - return int(ws.Width), int(ws.Height) + return int(ws.Col), int(ws.Row) } func isTerminal(fd uintptr) bool { @@ -121,21 +121,3 @@ func isTerminal(fd uintptr) bool { _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), termios, uintptr(unsafe.Pointer(&t)), 0, 0, 0) return err == 0 } - -type winsize struct { - Height, Width uint16 - // We pad the struct to give us plenty of headroom. (In practice, - // darwin only has 8 additional bytes.) The proper way to do this - // would be to use cgo to get the proper struct definition, but I'd - // like to avoid this if possible. - Pad [128]byte -} - -func getWinsize(fd uintptr) (*winsize, error) { - w := new(winsize) - _, _, err := unix.Syscall(unix.SYS_IOCTL, fd, uintptr(unix.TIOCGWINSZ), uintptr(unsafe.Pointer(w))) - if err == 0 { - return w, nil - } - return w, err -} diff --git a/stress/oom/oom.go b/stress/oom/oom.go new file mode 100644 index 00000000..25c86b49 --- /dev/null +++ b/stress/oom/oom.go @@ -0,0 +1,70 @@ +// Copyright 2019 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. + +// +build darwin dragonfly freebsd linux openbsd solaris netbsd + +// Package oom contains a single function to trigger Linux kernel OOMs. +package oom + +import ( + "bufio" + "log" + "os" + "strconv" + "strings" + + "golang.org/x/sys/unix" +) + +// Do attempts to OOM the process by allocating up +// to the provided number of bytes. Do never returns. 
+func Do(size int) {
+	log.Print("oom: allocating ", size, " bytes")
+	var (
+		prot = unix.PROT_READ | unix.PROT_WRITE
+		flag = unix.MAP_PRIVATE | unix.MAP_ANON | unix.MAP_NORESERVE
+	)
+	b, err := unix.Mmap(-1, 0, size, prot, flag)
+	if err != nil {
+		log.Fatal(err)
+	}
+	stride := os.Getpagesize()
+	// Touch each page so that the process gradually allocates
+	// more and more memory.
+	for i := 0; i < size; i += stride {
+		b[i] = 1
+	}
+	log.Fatal("failed to OOM process")
+}
+
+// Try attempts to OOM based on the available physical memory and
+// default overcommit heuristics. Try never returns.
+func Try() {
+	f, err := os.Open("/proc/meminfo")
+	if err != nil {
+		log.Fatal(err)
+	}
+	scan := bufio.NewScanner(f)
+	for scan.Scan() {
+		fields := strings.Fields(scan.Text())
+		if len(fields) < 3 {
+			continue
+		}
+		if fields[0] != "MemTotal:" {
+			continue
+		}
+		if fields[2] != "kB" {
+			log.Fatalf("expected kilobytes, got %s", fields[2])
+		}
+		kb, err := strconv.ParseInt(fields[1], 0, 64)
+		if err != nil {
+			log.Fatalf("parsing %q: %v", fields[1], err)
+		}
+		Do(int(kb << 10))
+	}
+	if err := scan.Err(); err != nil {
+		log.Fatal(err)
+	}
+	log.Fatal("MemTotal not found in /proc/meminfo")
+}
diff --git a/sync/ctxsync/cond.go b/sync/ctxsync/cond.go
new file mode 100644
index 00000000..34e11cdd
--- /dev/null
+++ b/sync/ctxsync/cond.go
@@ -0,0 +1,59 @@
+// Copyright 2018 GRAIL, Inc. All rights reserved.
+// Use of this source code is governed by the Apache 2.0
+// license that can be found in the LICENSE file.
+
+package ctxsync
+
+import (
+	"context"
+	"sync"
+)
+
+// A Cond is a condition variable that implements
+// a context-aware Wait.
+type Cond struct {
+	l     sync.Locker
+	waitc chan struct{}
+}
+
+// NewCond returns a new Cond based on Locker l.
+func NewCond(l sync.Locker) *Cond {
+	return &Cond{l: l}
+}
+
+// Broadcast notifies waiters of a state change. Broadcast must only
+// be called while the cond's lock is held.
+func (c *Cond) Broadcast() {
+	if c.waitc != nil {
+		close(c.waitc)
+		c.waitc = nil
+	}
+}
+
+// Done returns a channel that is closed after the next broadcast of
+// this Cond. Done must be called with the Cond's lock held; the lock
+// is released before Done returns.
+func (c *Cond) Done() <-chan struct{} {
+	if c.waitc == nil {
+		c.waitc = make(chan struct{})
+	}
+	waitc := c.waitc
+	c.l.Unlock()
+	return waitc
+}
+
+// Wait returns after the next call to Broadcast, or if the context
+// is complete. The Cond's lock must be held when calling Wait.
+// An error returns with the context's error if the context completes
+// while waiting.
+func (c *Cond) Wait(ctx context.Context) error {
+	waitc := c.Done()
+	var err error
+	select {
+	case <-waitc:
+	case <-ctx.Done():
+		err = ctx.Err()
+	}
+	c.l.Lock()
+	return err
+}
diff --git a/sync/ctxsync/cond_test.go b/sync/ctxsync/cond_test.go
new file mode 100644
index 00000000..6d342c70
--- /dev/null
+++ b/sync/ctxsync/cond_test.go
@@ -0,0 +1,58 @@
+// Copyright 2018 GRAIL, Inc. All rights reserved.
+// Use of this source code is governed by the Apache 2.0
+// license that can be found in the LICENSE file.
+ +package ctxsync + +import ( + "context" + "sync" + "testing" +) + +func TestContextCond(t *testing.T) { + var ( + mu sync.Mutex + cond = NewCond(&mu) + start, done sync.WaitGroup + ) + const N = 100 + start.Add(N) + done.Add(N) + errs := make([]error, N) + for i := 0; i < N; i++ { + go func(idx int) { + mu.Lock() + start.Done() + if err := cond.Wait(context.Background()); err != nil { + errs[idx] = err + } + mu.Unlock() + done.Done() + }(i) + } + + start.Wait() + mu.Lock() + cond.Broadcast() + mu.Unlock() + done.Wait() + for _, err := range errs { + if err != nil { + t.Fatal(err) + } + } +} + +func TestContextCondErr(t *testing.T) { + var ( + mu sync.Mutex + cond = NewCond(&mu) + ) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + mu.Lock() + if got, want := cond.Wait(ctx), context.Canceled; got != want { + t.Errorf("got %v, want %v", got, want) + } +} diff --git a/sync/ctxsync/mutex.go b/sync/ctxsync/mutex.go new file mode 100644 index 00000000..0a118cd3 --- /dev/null +++ b/sync/ctxsync/mutex.go @@ -0,0 +1,49 @@ +// Copyright 2022 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. + +package ctxsync + +import ( + "context" + "sync" + + "github.com/grailbio/base/errors" +) + +// Mutex is a context-aware mutex. It must not be copied. +// The zero value is ready to use. +type Mutex struct { + initOnce sync.Once + lockCh chan struct{} +} + +// Lock attempts to exclusively lock m. If the m is already locked, it will +// wait until it is unlocked. If ctx is canceled before the lock can be taken, +// Lock will not take the lock, and a non-nil error is returned. +func (m *Mutex) Lock(ctx context.Context) error { + m.init() + select { + case m.lockCh <- struct{}{}: + return nil + case <-ctx.Done(): + return errors.E(ctx.Err(), "waiting for lock") + } +} + +// Unlock unlocks m. It must be called exactly once iff Lock returns nil. 
+// Unlock panics if it is called while m is not locked. +func (m *Mutex) Unlock() { + m.init() + select { + case <-m.lockCh: + default: + panic("Unlock called on mutex that is not locked") + } +} + +func (m *Mutex) init() { + m.initOnce.Do(func() { + m.lockCh = make(chan struct{}, 1) + }) +} diff --git a/sync/ctxsync/mutex_test.go b/sync/ctxsync/mutex_test.go new file mode 100644 index 00000000..46b49871 --- /dev/null +++ b/sync/ctxsync/mutex_test.go @@ -0,0 +1,119 @@ +// Copyright 2022 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. + +package ctxsync_test + +import ( + "context" + "sync" + "testing" + "time" + + "github.com/grailbio/base/errors" + "github.com/grailbio/base/sync/ctxsync" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" +) + +// TestExclusion verifies that a mutex provides basic mutually exclusive +// access: only one goroutine can have it locked at a time. +func TestExclusion(t *testing.T) { + var ( + mu ctxsync.Mutex + wg sync.WaitGroup + x int + ) + require.NoError(t, mu.Lock(context.Background())) + wg.Add(1) + go func() { + defer wg.Done() + if err := mu.Lock(context.Background()); err != nil { + return + } + x = 100 + mu.Unlock() + }() + for i := 1; i <= 10; i++ { + // Verify that nothing penetrates our lock and changes x unexpectedly. + assert.Equal(t, i-1, x) + x = i + time.Sleep(1 * time.Millisecond) + } + mu.Unlock() + wg.Wait() + assert.Equal(t, 100, x) +} + +// TestOtherGoroutineUnlock verifies that locked mutexes can be unlocked by a +// different goroutine, and that the lock still provides mutual exclusion +// across them. +func TestOtherGoroutineUnlock(t *testing.T) { + const N = 100 + var ( + mu ctxsync.Mutex + g errgroup.Group + chLocked = make(chan struct{}) + x int + ) + // Run N goroutines each trying to lock the mutex. 
Run another N + // goroutines, one of which is selected to unlock the mutex after each time + // it is successfully locked. + for i := 0; i < N; i++ { + g.Go(func() error { + if err := mu.Lock(context.Background()); err != nil { + return err + } + x++ + chLocked <- struct{}{} + return nil + }) + g.Go(func() error { + <-chLocked + x++ + mu.Unlock() + return nil + }) + } + assert.NoError(t, g.Wait()) + // We run N*2 goroutines, each incrementing x by 1 while the lock is held. + assert.Equal(t, N*2, x) +} + +// TestCancel verifies that canceling the Lock context causes the attempt to +// lock the mutex to fail and return an error of kind errors.Canceled. +func TestCancel(t *testing.T) { + var ( + mu ctxsync.Mutex + wg sync.WaitGroup + errWaiter error + ) + require.NoError(t, mu.Lock(context.Background())) + ctx, cancel := context.WithCancel(context.Background()) + wg.Add(1) + go func() { + defer wg.Done() + if errWaiter = mu.Lock(ctx); errWaiter != nil { + return + } + mu.Unlock() + }() + cancel() + wg.Wait() + mu.Unlock() + // Verify that we can still lock and unlock after the canceled attempt. + if assert.NoError(t, mu.Lock(context.Background())) { + mu.Unlock() + } + // Verify that Lock returned the expected non-nil error from the canceled + // attempt. + assert.True(t, errors.Is(errors.Canceled, errWaiter), "expected errors.Canceled") +} + +// TestUnlockUnlocked verifies that unlocking a mutex that is not locked +// panics. 
+func TestUnlockUnlocked(t *testing.T) { + var mu ctxsync.Mutex + assert.Panics(t, func() { mu.Unlock() }) +} diff --git a/sync/loadingcache/ctxloadingcache/ctxloadingcache.go b/sync/loadingcache/ctxloadingcache/ctxloadingcache.go new file mode 100644 index 00000000..43334c17 --- /dev/null +++ b/sync/loadingcache/ctxloadingcache/ctxloadingcache.go @@ -0,0 +1,28 @@ +package ctxloadingcache + +import ( + "context" + + "github.com/grailbio/base/sync/loadingcache" +) + +type contextKeyType struct{} + +var contextKey contextKeyType + +// With returns a child context which, when passed to Value, gets and sets in m. +// Callers now control the lifetime of the returned context's cache and can clear it. +func With(ctx context.Context, m *loadingcache.Map) context.Context { + return context.WithValue(ctx, contextKey, m) +} + +// Value retrieves a *loadingcache.Value that's linked to a cache, if ctx was returned by With. +// If ctx didn't come from an earlier With call, there's no linked cache, so caching will be +// disabled by returning nil (which callers don't need to check, because nil is usable). +func Value(ctx context.Context, key interface{}) *loadingcache.Value { + var m *loadingcache.Map + if v := ctx.Value(contextKey); v != nil { + m = v.(*loadingcache.Map) + } + return m.GetOrCreate(key) +} diff --git a/sync/loadingcache/ctxloadingcache/ctxloadingcache_test.go b/sync/loadingcache/ctxloadingcache/ctxloadingcache_test.go new file mode 100644 index 00000000..499938fe --- /dev/null +++ b/sync/loadingcache/ctxloadingcache/ctxloadingcache_test.go @@ -0,0 +1,82 @@ +package ctxloadingcache_test + +import ( + "context" + "testing" + + "github.com/grailbio/base/sync/loadingcache" + "github.com/grailbio/base/sync/loadingcache/ctxloadingcache" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// These tests check same-ness of *Value and rely on Value's unit testing for complete coverage. 
+// Theoretically an implementation could be correct but not preserve same-ness; if we change +// our implementation to do that, we can update these tests. + +func TestBasic(t *testing.T) { + var cache loadingcache.Map + ctx := ctxloadingcache.With(context.Background(), &cache) + var i1 int + require.NoError(t, ctxloadingcache.Value(ctx, "key"). + GetOrLoad(ctx, &i1, func(ctx context.Context, opts *loadingcache.LoadOpts) error { + i1 = 1 + opts.CacheForever() + return nil + })) + assert.Equal(t, 1, i1) + var i2 int + require.NoError(t, ctxloadingcache.Value(ctx, "key"). + GetOrLoad(ctx, &i2, func(ctx context.Context, opts *loadingcache.LoadOpts) error { + panic("computing i2") + })) + assert.Equal(t, 1, i2) +} + +func TestKeys(t *testing.T) { + var cache loadingcache.Map + ctx := ctxloadingcache.With(context.Background(), &cache) + + type ( + testKeyA struct{} + testKeyB struct{} + ) + vA := ctxloadingcache.Value(ctx, testKeyA{}) + assert.NotSame(t, vA, ctxloadingcache.Value(ctx, testKeyB{})) + assert.Same(t, vA, ctxloadingcache.Value(ctx, testKeyA{})) +} + +func TestReuseKey(t *testing.T) { + var cache loadingcache.Map + ctx1 := ctxloadingcache.With(context.Background(), &cache) + + type testKey struct{} + cache1 := ctxloadingcache.Value(ctx1, testKey{}) + + ctx2 := context.WithValue(ctx1, testKey{}, "not a cache") + assert.Same(t, cache1, ctxloadingcache.Value(ctx2, testKey{})) +} + +func TestDeleteAll(t *testing.T) { + var cache loadingcache.Map + ctx1 := ctxloadingcache.With(context.Background(), &cache) + ctx2, cancel2 := context.WithCancel(ctx1) + defer cancel2() + + type ( + testKeyA struct{} + testKeyB struct{} + ) + vA := ctxloadingcache.Value(ctx1, testKeyA{}) + vB := ctxloadingcache.Value(ctx1, testKeyB{}) + assert.NotSame(t, vA, vB) + assert.Same(t, vA, ctxloadingcache.Value(ctx1, testKeyA{})) + assert.Same(t, vA, ctxloadingcache.Value(ctx2, testKeyA{})) + assert.Same(t, vB, ctxloadingcache.Value(ctx2, testKeyB{})) + + cache.DeleteAll() + + 
assert.NotSame(t, vA, ctxloadingcache.Value(ctx1, testKeyA{})) + assert.NotSame(t, vA, ctxloadingcache.Value(ctx2, testKeyA{})) + assert.NotSame(t, vB, ctxloadingcache.Value(ctx2, testKeyB{})) +} diff --git a/sync/loadingcache/map.go b/sync/loadingcache/map.go new file mode 100644 index 00000000..c1c8ae85 --- /dev/null +++ b/sync/loadingcache/map.go @@ -0,0 +1,42 @@ +package loadingcache + +import "sync" + +// Map is a keyed collection of Values. Map{} is ready to use. +// (*Map)(nil) is valid and never caches or shares results for any key. +// Maps are concurrency-safe. They must not be copied. +// +// Implementation notes: +// +// Compared to sync.Map, this Map is not sophisticated in terms of optimizing for high concurrency +// with disjoint sets of keys. It could probably be improved. +// +// Loading on-demand, without repeated value computation, is reminiscent of Guava's LoadingCache: +// https://github.com/google/guava/wiki/CachesExplained +type Map struct { + mu sync.Mutex + m map[interface{}]*Value +} + +// GetOrCreate returns an existing or new Value associated with key. +// Note: If m == nil, returns nil, a never-caching Value. +func (m *Map) GetOrCreate(key interface{}) *Value { + if m == nil { + return nil + } + m.mu.Lock() + defer m.mu.Unlock() + if m.m == nil { + m.m = make(map[interface{}]*Value) + } + if _, ok := m.m[key]; !ok { + m.m[key] = new(Value) + } + return m.m[key] +} + +func (m *Map) DeleteAll() { + m.mu.Lock() + defer m.mu.Unlock() + m.m = nil +} diff --git a/sync/loadingcache/value.go b/sync/loadingcache/value.go new file mode 100644 index 00000000..5460fa06 --- /dev/null +++ b/sync/loadingcache/value.go @@ -0,0 +1,150 @@ +package loadingcache + +import ( + "context" + "fmt" + "reflect" + "runtime/debug" + "sync" + "time" + + "github.com/grailbio/base/must" +) + +type ( + // Value manages the loading (calculation) and storing of a cache value. It's designed for use + // cases where loading is slow. Concurrency is well-supported: + // 1. 
Only one load is in progress at a time, even if concurrent callers request the value. + // 2. Cancellation is respected for loading: a caller's load function is invoked with their + // context. If it respects cancellation and returns an error immediately, the caller's + // GetOrLoad does, too. + // 3. Cancellation is respected for waiting: if a caller's context is canceled while they're + // waiting for another in-progress load (not their own), the caller's GetOrLoad returns + // immediately with the cancellation error. + // Simpler mechanisms (like just locking a sync.Mutex when starting computation) don't achieve + // all of these (in the mutex example, cancellation is not respected while waiting on Lock()). + // + // The original use case was reading rarely-changing data via RPC while letting users + // cancel the operation (Ctrl-C in their terminal). Very different uses (very fast computations + // or extremely high concurrency) may not work as well; they're at least not tested. + // Memory overhead may be quite large if small values are cached. + // + // Value{} is ready to use. (*Value)(nil) is valid and just never caches or shares a result + // (every get loads). Value must not be copied. + // + // Time-based expiration is optional. See LoadFunc and LoadOpts. + Value struct { + // init supports at-most-once initialization of subsequent fields. + init sync.Once + // c is both a semaphore (limit 1) and storage for cache state. + c chan state + // now is used for faking time in tests. + now func() time.Time + } + state struct { + // dataPtr is non-zero if there's a previously-computed value (which may be expired). + dataPtr reflect.Value + // expiresAt is the time of expiration (according to now) when dataPtr is non-zero. + // expiresAt.IsZero() means no expiration (infinite caching). + expiresAt time.Time + } + // LoadFunc computes a value. It should respect cancellation (return with cancellation error). 
+ LoadFunc func(context.Context, *LoadOpts) error + // LoadOpts configures how long a LoadFunc result should be cached. + // Cache settings overwrite each other; last write wins. Default is don't cache at all. + // Callers should synchronize their calls themselves if using multiple goroutines (this is + // not expected). + LoadOpts struct { + // validFor is cache time if > 0, disables cache if == 0, infinite cache time if < 0. + validFor time.Duration + } +) + +// GetOrLoad either copies a cached value to dataPtr or runs load and then copies dataPtr's value +// into the cache. A properly-written load writes dataPtr's value. Example: +// +// var result string +// err := value.GetOrLoad(ctx, &result, func(ctx context.Context, opts *loadingcache.LoadOpts) error { +// var err error +// result, err = doExpensiveThing(ctx) +// opts.CacheFor(time.Hour) +// return err +// }) +// +// dataPtr must be a pointer to a copyable value (slice, int, struct without Mutex, etc.). +// +// Value does not cache errors. Consider caching a value containing an error, like +// struct{result int; err error} if desired. +func (v *Value) GetOrLoad(ctx context.Context, dataPtr interface{}, load LoadFunc) error { + ptrVal := reflect.ValueOf(dataPtr) + must.True(ptrVal.Kind() == reflect.Ptr, "%v", dataPtr) + // TODO: Check copyable? + + if v == nil { + return runNoPanic(func() error { + var opts LoadOpts + return load(ctx, &opts) + }) + } + + v.init.Do(func() { + if v.c == nil { + v.c = make(chan state, 1) + v.c <- state{} + } + if v.now == nil { + v.now = time.Now + } + }) + + var state state + select { + case <-ctx.Done(): + return ctx.Err() + case state = <-v.c: + } + defer func() { v.c <- state }() + + if state.dataPtr.IsValid() { + if state.expiresAt.IsZero() || v.now().Before(state.expiresAt) { + ptrVal.Elem().Set(state.dataPtr.Elem()) + return nil + } + state.dataPtr = reflect.Value{} + } + + var opts LoadOpts + // TODO: Consider calling load() directly rather than via runNoPanic(). 
+ // A previous implementation needed to intercept panics to handle internal state correctly. + // That's no longer true, so we can avoid tampering with callers' panic traces. + err := runNoPanic(func() error { return load(ctx, &opts) }) + if err == nil && opts.validFor != 0 { + state.dataPtr = ptrVal + if opts.validFor > 0 { + state.expiresAt = v.now().Add(opts.validFor) + } else { + state.expiresAt = time.Time{} + } + } + return err +} + +func runNoPanic(f func() error) (err error) { + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("cache: recovered panic: %v, stack:\n%v", r, string(debug.Stack())) + } + }() + return f() +} + +// setClock is for testing. It must be called before any GetOrLoad and is not concurrency-safe. +func (v *Value) setClock(now func() time.Time) { + if v == nil { + return + } + v.now = now +} + +func (o *LoadOpts) CacheFor(d time.Duration) { o.validFor = d } +func (o *LoadOpts) CacheForever() { o.validFor = -1 } diff --git a/sync/loadingcache/value_test.go b/sync/loadingcache/value_test.go new file mode 100644 index 00000000..49e9fda1 --- /dev/null +++ b/sync/loadingcache/value_test.go @@ -0,0 +1,235 @@ +package loadingcache + +import ( + "context" + "sync" + "testing" + "time" + + "github.com/grailbio/base/errors" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const recentUnixTimestamp = 1600000000 // 2020-09-13 12:26:40 +0000 UTC + +func TestValueExpiration(t *testing.T) { + var ( + ctx = context.Background() + v Value + clock fakeClock + ) + v.setClock(clock.Now) + + clock.Set(time.Unix(recentUnixTimestamp, 0)) + var v1 int + require.NoError(t, v.GetOrLoad(ctx, &v1, func(_ context.Context, opts *LoadOpts) error { + clock.Add(2 * time.Hour) + v1 = 1 + opts.CacheFor(time.Hour) + return nil + })) + assert.Equal(t, 1, v1) + + clock.Add(5 * time.Minute) + var v2 int + require.NoError(t, v.GetOrLoad(ctx, &v2, loadFail)) + assert.Equal(t, 1, v1) + assert.Equal(t, 1, v2) + + 
clock.Add(time.Hour) + var v3 int + require.NoError(t, v.GetOrLoad(ctx, &v3, func(_ context.Context, opts *LoadOpts) error { + v3 = 3 + opts.CacheForever() + return nil + })) + assert.Equal(t, 1, v1) + assert.Equal(t, 1, v2) + assert.Equal(t, 3, v3) + + clock.Add(10000 * time.Hour) + var v4 int + assert.NoError(t, v.GetOrLoad(ctx, &v4, loadFail)) + assert.Equal(t, 1, v1) + assert.Equal(t, 1, v2) + assert.Equal(t, 3, v3) + assert.Equal(t, 3, v4) +} + +func TestValueExpiration0(t *testing.T) { + var ( + ctx = context.Background() + v Value + clock fakeClock + ) + v.setClock(clock.Now) + + clock.Set(time.Unix(recentUnixTimestamp, 0)) + var v1 int + require.NoError(t, v.GetOrLoad(ctx, &v1, func(_ context.Context, opts *LoadOpts) error { + v1 = 1 + return nil + })) + assert.Equal(t, 1, v1) + + // Run v2 at the same time as v1. It should not get a cached result because v1's cache time was 0. + var v2 int + require.NoError(t, v.GetOrLoad(ctx, &v2, func(_ context.Context, opts *LoadOpts) error { + v2 = 2 + opts.CacheFor(time.Hour) + return nil + })) + assert.Equal(t, 1, v1) + assert.Equal(t, 2, v2) +} + +func TestValueNil(t *testing.T) { + var ( + ctx = context.Background() + v *Value + clock fakeClock + ) + v.setClock(clock.Now) + + clock.Set(time.Unix(recentUnixTimestamp, 0)) + var v1 int + require.NoError(t, v.GetOrLoad(ctx, &v1, func(_ context.Context, opts *LoadOpts) error { + clock.Add(2 * time.Hour) + v1 = 1 + opts.CacheForever() + return nil + })) + assert.Equal(t, 1, v1) + + var v2 int + assert.Error(t, v.GetOrLoad(ctx, &v2, loadFail)) + assert.Equal(t, 1, v1) + + clock.Add(time.Hour) + var v3 int + require.NoError(t, v.GetOrLoad(ctx, &v3, func(_ context.Context, opts *LoadOpts) error { + v3 = 3 + opts.CacheForever() + return nil + })) + assert.Equal(t, 1, v1) + assert.Equal(t, 3, v3) +} + +func TestValueCancellation(t *testing.T) { + var ( + v Value + clock fakeClock + ) + v.setClock(clock.Now) + clock.Set(time.Unix(recentUnixTimestamp, 0)) + const cacheDuration 
= time.Minute + + type participant struct { + cancel context.CancelFunc + // participant waits for these before proceeding. + waitGet, waitLoad chan<- struct{} + // participant returns these signals of its progress. + loadStarted <-chan struct{} + result <-chan error + } + makeParticipant := func(dst *int, loaded int) participant { + ctx, cancel := context.WithCancel(context.Background()) + var ( + waitGet = make(chan struct{}) + waitLoad = make(chan struct{}) + loadStarted = make(chan struct{}) + result = make(chan error) + ) + go func() { + <-waitGet + result <- v.GetOrLoad(ctx, dst, func(ctx context.Context, opts *LoadOpts) error { + close(loadStarted) + select { + case <-ctx.Done(): + return ctx.Err() + case <-waitLoad: + *dst = loaded + opts.CacheFor(cacheDuration) + return nil + } + }) + }() + return participant{cancel, waitGet, waitLoad, loadStarted, result} + } + + // Start participant 1 and wait for its cache load to start. + var v1 int + p1 := makeParticipant(&v1, 1) + close(p1.waitGet) + <-p1.loadStarted + + // Start participant 2, then cancel its context and wait for its error. + var v2 int + p2 := makeParticipant(&v2, 2) + p2.waitGet <- struct{}{} + p2.cancel() + err2 := <-p2.result + assert.True(t, errors.Is(errors.Canceled, err2), "got: %v", err2) + + // Start participant 3, then cancel participant 1 and wait for 3 to start loading. + var v3 int + p3 := makeParticipant(&v3, 3) + p3.waitGet <- struct{}{} + p1.cancel() + <-p3.loadStarted + err1 := <-p1.result + assert.True(t, errors.Is(errors.Canceled, err1), "got: %v", err1) + + // Start participant 4 later (according to clock). + var v4 int + p4 := makeParticipant(&v4, 4) + clock.Add(time.Second) + p4.waitGet <- struct{}{} + + // Let participant 3 finish loading and wait for results. + close(p3.waitLoad) + require.NoError(t, <-p3.result) + require.NoError(t, <-p4.result) + assert.Equal(t, 3, v3) + assert.Equal(t, 3, v4) // Got cached result. 
+ + // Start participant 5 past cache time so it recomputes. + var v5 int + p5 := makeParticipant(&v5, 5) + clock.Add(cacheDuration * 2) + p5.waitGet <- struct{}{} + close(p5.waitLoad) + require.NoError(t, <-p5.result) + assert.Equal(t, 3, v3) + assert.Equal(t, 3, v4) + assert.Equal(t, 5, v5) +} + +type fakeClock struct { + mu sync.Mutex + now time.Time +} + +func (c *fakeClock) Now() time.Time { + c.mu.Lock() + defer c.mu.Unlock() + return c.now +} + +func (c *fakeClock) Set(now time.Time) { + c.mu.Lock() + defer c.mu.Unlock() + c.now = now +} + +func (c *fakeClock) Add(d time.Duration) { + c.mu.Lock() + defer c.mu.Unlock() + c.now = c.now.Add(d) +} + +func loadFail(context.Context, *LoadOpts) error { + panic("unexpected load") +} diff --git a/sync/multierror/multierror.go b/sync/multierror/multierror.go index 2dc81dc1..68a0d71f 100644 --- a/sync/multierror/multierror.go +++ b/sync/multierror/multierror.go @@ -6,102 +6,98 @@ import ( "sync" ) -// MultiError is a mechanism for capturing errors from parallel -// go-routines. Usage: -// -// errs := NewMultiError(3) -// do := func(f foo) error {...} -// for foo in range foos { -// go errs.capture(do(foo)) -// } -// // Wait for completion -// -// Will gather all errors returned in a MultiError, which in turn will -// behave as a normal error. -type MultiError struct { - errs []error - count int64 - mu sync.Mutex -} - -// NewMultiError creates a new MultiError struct. -func NewMultiError(max int) *MultiError { - return &MultiError{errs: make([]error, 0, max), mu: sync.Mutex{}} -} - -func (me *MultiError) add(err error) { - if len(me.errs) == cap(me.errs) { - me.count++ - return - } - - me.errs = append(me.errs, err) +type multiError struct { + errs []error + dropped int } -// Add captures an error from a go-routine and adds it to the MultiError. 
-func (me *MultiError) Add(err error) *MultiError { - if err == nil || me == nil { - return me - } - - me.mu.Lock() - defer me.mu.Unlock() - - multi, ok := err.(*MultiError) - if ok { - // Aggregate if it is a multierror. - for _, e := range multi.errs { - me.add(e) +// Error returns a string describing the multiple errors represented by e. +func (e multiError) Error() string { + switch len(e.errs) { + case 0, 1: + panic("invalid multiError") + default: + var b strings.Builder + b.WriteString("[") + for i, err := range e.errs { + if i > 0 { + b.WriteString("\n") + } + b.WriteString(err.Error()) + } + b.WriteString("]") + if e.dropped > 0 { + fmt.Fprintf(&b, " [plus %d other error(s)]", e.dropped) } - me.count += multi.count - return me + return b.String() } - - me.add(err) - - return me } -// Error returns a string version of the MultiError. This implements the error -// interface. -func (me *MultiError) Error() string { - if me == nil { - return "" - } - - me.mu.Lock() - defer me.mu.Unlock() +// Builder captures errors from parallel goroutines. +// +// Example usage: +// var ( +// errs = multierror.NewBuilder(3) +// wg sync.WaitGroup +// ) +// for _, foo := range foos { +// wg.Add(1) +// go func() { +// defer wg.Done() +// errs.Add(someWork(foo)) +// }() +// } +// wg.Wait() +// if err := errs.Err(); err != nil { +// // handle err +// } +type Builder struct { + mu sync.Mutex // mu guards all fields + errs []error + dropped int +} - if len(me.errs) == 0 { - return "" - } +func NewBuilder(max int) *Builder { + return &Builder{errs: make([]error, 0, max)} +} - if len(me.errs) == 1 { - return me.errs[0].Error() +// Add adds an error to b. b must be non-nil. 
+func (b *Builder) Add(err error) { + if err == nil { + return } - s := make([]string, len(me.errs)) - for i, e := range me.errs { - s[i] = e.Error() - } - errs := strings.Join(s, "\n") + b.mu.Lock() + defer b.mu.Unlock() - if me.count == 0 { - return fmt.Sprintf("[%s]", errs) + if len(b.errs) == cap(b.errs) { + b.dropped++ + return } - - return fmt.Sprintf("[%s] [plus %d other error(s)]", errs, me.count) + b.errs = append(b.errs, err) } -// ErrorOrNil returns nil if no errors were captured, itself otherwise. -func (me *MultiError) ErrorOrNil() error { - if me == nil { +// Err returns an error combining all the errors that were already Add-ed. Otherwise returns nil. +// b may be nil. +func (b *Builder) Err() error { + if b == nil { return nil } - - if len(me.errs) == 0 { + b.mu.Lock() + defer b.mu.Unlock() + switch len(b.errs) { + case 0: return nil + case 1: + // TODO: This silently ignores b.dropped which is bad because it may be non-zero. + // Maybe we should make multiError{*, 1} legal. Or, maybe forbid max < 2. + return b.errs[0] + default: + return multiError{ + // Sharing b.errs is ok because multiError doesn't mutate or append and Builder + // only appends. 
+ errs: b.errs, + dropped: b.dropped, + } } - - return me } diff --git a/sync/multierror/multierror_test.go b/sync/multierror/multierror_test.go index 32ebac69..f409f382 100644 --- a/sync/multierror/multierror_test.go +++ b/sync/multierror/multierror_test.go @@ -6,6 +6,12 @@ import ( ) func TestMultiError(t *testing.T) { + me2a := NewBuilder(2) + me2a.Add(errors.New("a")) + me1ab := NewBuilder(1) + me1ab.Add(errors.New("a")) + me1ab.Add(errors.New("b")) + for _, test := range []struct { errs []error expected error @@ -24,23 +30,18 @@ func TestMultiError(t *testing.T) { 2] [plus 1 other error(s)]`), }, { - []error{errors.New("1"), NewMultiError(2).Add(errors.New("a"))}, + []error{errors.New("1"), me2a.Err()}, errors.New(`[1 a]`), }, - { - []error{errors.New("1"), NewMultiError(1).Add(errors.New("a")).Add(errors.New("b"))}, - errors.New(`[1 -a] [plus 1 other error(s)]`), - }, } { - errs := NewMultiError(2) + errs := NewBuilder(2) for _, e := range test.errs { errs.Add(e) } - got := errs.ErrorOrNil() + got := errs.Err() if test.expected == nil && got == nil { continue @@ -51,7 +52,7 @@ a] [plus 1 other error(s)]`), } if test.expected.Error() != got.Error() { - t.Fatalf("error mismatch: %v vs %v", test.expected, got) + t.Fatalf("error mismatch: %q vs %q", test.expected, got) } } } diff --git a/sync/once/BUILD b/sync/once/BUILD new file mode 100644 index 00000000..2dda3d7c --- /dev/null +++ b/sync/once/BUILD @@ -0,0 +1,14 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "go_default_library", + srcs = ["once.go"], + importpath = "github.com/grailbio/base/sync/once", + visibility = ["//visibility:public"], +) + +go_test( + name = "go_default_test", + srcs = ["once_test.go"], + embed = [":go_default_library"], +) diff --git a/sync/once/once.go b/sync/once/once.go new file mode 100644 index 00000000..d800db6d --- /dev/null +++ b/sync/once/once.go @@ -0,0 +1,69 @@ +// Copyright 2018 GRAIL, Inc. All rights reserved. 
+// Use of this source code is governed by the Apache 2.0
+// license that can be found in the LICENSE file.
+
+// Package once contains utilities for managing actions that
+// must be performed exactly once.
+package once
+
+import (
+	"sync"
+	"sync/atomic"
+)
+
+// Task manages a computation that must be run at most once.
+// It's similar to sync.Once, except it also handles and returns errors.
+type Task struct {
+	mu   sync.Mutex
+	done uint32
+	err  error
+}
+
+// Do runs the function do at most once. Successive invocations of Do
+// guarantee exactly one invocation of the function do. Do returns
+// the error of do's invocation.
+func (o *Task) Do(do func() error) error {
+	if atomic.LoadUint32(&o.done) == 1 {
+		return o.err
+	}
+	o.mu.Lock()
+	defer o.mu.Unlock()
+	if atomic.LoadUint32(&o.done) == 0 {
+		o.err = do()
+		atomic.StoreUint32(&o.done, 1)
+	}
+	return o.err
+}
+
+// Done returns whether the task is done.
+func (o *Task) Done() bool {
+	o.mu.Lock()
+	defer o.mu.Unlock()
+	return 1 == atomic.LoadUint32(&o.done)
+}
+
+// Reset resets the task, effectively making it possible for `Do` to invoke the underlying do func again.
+// Reset will only reset the task if it was already completed.
+func (o *Task) Reset() {
+	o.mu.Lock()
+	defer o.mu.Unlock()
+	atomic.CompareAndSwapUint32(&o.done, 1, 0)
+}
+
+// Map coordinates actions that must happen exactly once, keyed
+// by user-defined keys.
+type Map sync.Map
+
+// Perform the provided action named by a key. Do invokes the action
+// exactly once for each key, and returns any errors produced by the
+// provided action.
+func (m *Map) Do(key interface{}, do func() error) error {
+	taskv, _ := (*sync.Map)(m).LoadOrStore(key, new(Task))
+	task := taskv.(*Task)
+	return task.Do(do)
+}
+
+// Forget forgets past computations associated with the provided key.
+func (m *Map) Forget(key interface{}) { + (*sync.Map)(m).Delete(key) +} diff --git a/sync/once/once_test.go b/sync/once/once_test.go new file mode 100644 index 00000000..d7274041 --- /dev/null +++ b/sync/once/once_test.go @@ -0,0 +1,84 @@ +// Copyright 2018 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. + +package once + +import ( + "errors" + "sync/atomic" + "testing" + + "github.com/grailbio/base/traverse" +) + +func TestTaskOnceConcurrency(t *testing.T) { + const ( + N = 10 + resets = 2 + ) + var ( + o Task + count int32 + ) + for r := 0; r < resets; r++ { + err := traverse.Each(N, func(_ int) error { + return o.Do(func() error { + atomic.AddInt32(&count, 1) + return nil + }) + }) + if err != nil { + t.Fatal(err) + } + if got, want := atomic.LoadInt32(&count), int32(r+1); got != want { + t.Errorf("got %v, want %v", got, want) + } + if got, want := o.Done(), true; got != want { + t.Errorf("got %v, want %v", got, want) + } + o.Reset() + if got, want := o.Done(), false; got != want { + t.Errorf("got %v, want %v", got, want) + } + } +} + +func TestMapOnceConcurrency(t *testing.T) { + const N = 10 + var ( + once Map + count uint32 + ) + err := traverse.Each(N, func(jobIdx int) error { + return once.Do(123, func() error { + atomic.AddUint32(&count, 1) + return nil + }) + }) + if err != nil { + t.Fatal(err) + } + if got, want := count, uint32(1); got != want { + t.Errorf("got %v, want %v", got, want) + } +} + +func TestTaskOnceError(t *testing.T) { + var ( + once Map + expected = errors.New("expected error") + ) + err := once.Do(123, func() error { return expected }) + if got, want := err, expected; got != want { + t.Errorf("got %v, want %v", got, want) + } + err = once.Do(123, func() error { panic("should not be called") }) + if got, want := err, expected; got != want { + t.Errorf("got %v, want %v", got, want) + } + err = once.Do(124, func() error { return nil }) + if err 
!= nil { + t.Errorf("unexpected error %v", err) + } +} diff --git a/sync/workerpool/limiter.go b/sync/workerpool/limiter.go deleted file mode 100644 index e51ca622..00000000 --- a/sync/workerpool/limiter.go +++ /dev/null @@ -1,21 +0,0 @@ -package workerpool - -// Limiter implements a counting semaphore based on a channel. Acquire is -// implemented by putting an integer in the channel and blocking if it -// is full; Release is implemented by removing an integer from the channel. -type Limiter chan int - -// NewLimiter creates a new semaphore with a given capacity. -func NewLimiter(count int) Limiter { - return make(Limiter, count) -} - -// Acquire adds an integer to the channel, blocking if it is full. -func (s Limiter) Acquire() { - s <- 1 -} - -// Release removes an integer from the channel. -func (s Limiter) Release() { - <-s -} diff --git a/sync/workerpool/workerpool.go b/sync/workerpool/workerpool.go index eca4cb8f..b65be556 100644 --- a/sync/workerpool/workerpool.go +++ b/sync/workerpool/workerpool.go @@ -70,7 +70,7 @@ func New(ctx context.Context, concurrency int) *WorkerPool { // specific subgroup of Tasks to Wait. type TaskGroup struct { Name string - ErrHandler *multierror.MultiError + ErrHandler *multierror.Builder Wp *WorkerPool activity sync.WaitGroup // Count active tasks } @@ -81,7 +81,7 @@ type TaskGroup struct { // separate from the WorkerPool context.Context. // // TODO(pknudsgaard): Should return a closure calling Wait. 
-func (wp *WorkerPool) NewTaskGroup(name string, errHandler *multierror.MultiError) *TaskGroup { +func (wp *WorkerPool) NewTaskGroup(name string, errHandler *multierror.Builder) *TaskGroup { vlog.VI(2).Infof("Creating TaskGroup: %s", name) grp := &TaskGroup{ diff --git a/syncqueue/ordered_queue_test.go b/syncqueue/ordered_queue_test.go index a6b99064..fb305d8e 100644 --- a/syncqueue/ordered_queue_test.go +++ b/syncqueue/ordered_queue_test.go @@ -2,11 +2,10 @@ package syncqueue_test import ( "fmt" - "sync" "testing" - "github.com/stretchr/testify/assert" "github.com/grailbio/base/syncqueue" + "github.com/stretchr/testify/assert" ) func checkNext(t *testing.T, q *syncqueue.OrderedQueue, value interface{}, ok bool) { @@ -73,28 +72,21 @@ func TestNoBlockWhenInsertNext(t *testing.T) { } func TestInsertBlockWithNextAvailable(t *testing.T) { - cond := sync.NewCond(&sync.Mutex{}) + resultChan := make(chan bool, 1) q := syncqueue.NewOrderedQueue(2) q.Insert(1, "one") q.Insert(0, "zero") - insertedTwo := false go func() { q.Insert(2, "two") - insertedTwo = true - cond.Signal() + resultChan <- true + close(resultChan) }() - assert.False(t, insertedTwo, "Expected insert(2, two) to block until there is space in the queue") checkNext(t, q, "zero", true) - cond.L.Lock() - for !insertedTwo { - cond.Wait() - } - assert.True(t, insertedTwo, "Expected insert(2, two) to complete after removing an item from the queue") - cond.L.Unlock() - + result := <-resultChan + assert.True(t, result, "Expected insert(2, two) to complete after removing an item from the queue") q.Close(nil) checkNext(t, q, "one", true) @@ -103,30 +95,23 @@ func TestInsertBlockWithNextAvailable(t *testing.T) { } func TestInsertBlockWithoutNextAvailable(t *testing.T) { - cond := sync.NewCond(&sync.Mutex{}) + resultChan := make(chan bool, 1) q := syncqueue.NewOrderedQueue(2) q.Insert(1, "one") - insertedTwo := false go func() { q.Insert(2, "two") - insertedTwo = true - cond.Signal() + resultChan <- true + 
close(resultChan) }() - assert.False(t, insertedTwo, "Expected insert(2, two) to block until there is space in the queue") q.Insert(0, "zero") checkNext(t, q, "zero", true) // Wait until insert two finishes - cond.L.Lock() - for !insertedTwo { - cond.Wait() - } - assert.True(t, insertedTwo, "Expected insert(2, two) to complete after removing an item from the queue") - cond.L.Unlock() - + result := <-resultChan + assert.True(t, result, "Expected insert(2, two) to complete after removing an item from the queue") q.Close(nil) checkNext(t, q, "one", true) @@ -135,71 +120,54 @@ func TestInsertBlockWithoutNextAvailable(t *testing.T) { } func TestNextBlockWhenEmpty(t *testing.T) { - cond := sync.NewCond(&sync.Mutex{}) + resultChan := make(chan bool, 1) q := syncqueue.NewOrderedQueue(2) - gotZero := false go func() { checkNext(t, q, "zero", true) - gotZero = true - cond.Signal() + resultChan <- true + close(resultChan) }() - assert.False(t, gotZero, "Expected Next block until there is something in the queue") - // Insert zero and then wait until Next returns q.Insert(0, "zero") - cond.L.Lock() - for !gotZero { - cond.Wait() - } - assert.True(t, gotZero, "Expected Next() to complete after inserting zero") - cond.L.Unlock() + result := <-resultChan + assert.True(t, result, "Expected Next() to complete after inserting zero") q.Close(nil) checkNext(t, q, nil, false) } func TestInsertGetsError(t *testing.T) { - cond := sync.NewCond(&sync.Mutex{}) + errors := make(chan error, 1) q := syncqueue.NewOrderedQueue(1) q.Insert(0, "zero") - var insertError error go func() { - insertError = q.Insert(1, "one") - cond.Signal() + errors <- q.Insert(1, "one") + close(errors) }() - assert.Nil(t, insertError, "Expected insert(1, one) to be nil until there is an error") + // Close q with an error. 
q.Close(fmt.Errorf("Foo error")) - cond.L.Lock() - for insertError == nil { - cond.Wait() - } - assert.Equal(t, "Foo error", insertError.Error()) - cond.L.Unlock() + // Wait for Insert to return with an error, and verify the value of the error. + e := <-errors + assert.Equal(t, "Foo error", e.Error()) } func TestNextGetsError(t *testing.T) { - cond := sync.NewCond(&sync.Mutex{}) + errorChan := make(chan error, 1) q := syncqueue.NewOrderedQueue(1) - var nextError error go func() { - _, _, nextError = q.Next() - cond.Signal() + _, _, err := q.Next() + errorChan <- err + close(errorChan) }() - assert.Nil(t, nextError, "Expected nextError to be nil until there is an error") q.Close(fmt.Errorf("Foo error")) - - cond.L.Lock() - for nextError == nil { - cond.Wait() - } - assert.Equal(t, "Foo error", nextError.Error()) - cond.L.Unlock() + err := <-errorChan + assert.Equal(t, "Foo error", err.Error()) } diff --git a/traverse/example_test.go b/traverse/example_test.go new file mode 100644 index 00000000..f78e5b75 --- /dev/null +++ b/traverse/example_test.go @@ -0,0 +1,23 @@ +// Copyright 2018 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. +package traverse_test + +import ( + "math/rand" + + "github.com/grailbio/base/traverse" +) + +func Example() { + // Compute N random numbers in parallel. + const N = 1e5 + out := make([]float64, N) + _ = traverse.Parallel.Range(len(out), func(start, end int) error { + r := rand.New(rand.NewSource(rand.Int63())) + for i := start; i < end; i++ { + out[i] = r.Float64() + } + return nil + }) +} diff --git a/traverse/reporter.go b/traverse/reporter.go new file mode 100644 index 00000000..96a5bc47 --- /dev/null +++ b/traverse/reporter.go @@ -0,0 +1,70 @@ +// Copyright 2018 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. 
+ +package traverse + +import ( + "fmt" + "os" + "sync" +) + +// A Reporter receives events from an ongoing traversal. Reporters +// can be passed as options into Traverse, and are used to monitor +// progress of long-running traversals. +type Reporter interface { + // Init is called when processing is about to begin. Parameter + // n indicates the number of tasks to be executed by the traversal. + Init(n int) + // Complete is called after the traversal has completed. + Complete() + + // Begin is called when task i is begun. + Begin(i int) + // End is called when task i has completed. + End(i int) +} + +// NewSimpleReporter returns a new reporter that prints the number +// of queued, running, and completed tasks to stderr. +func NewSimpleReporter(name string) Reporter { + return &simpleReporter{name: name} +} + +type simpleReporter struct { + name string + mu sync.Mutex + queued, running, done int +} + +func (r *simpleReporter) Init(n int) { + r.mu.Lock() + r.queued = n + r.update() + r.mu.Unlock() +} + +func (r *simpleReporter) Complete() { + fmt.Fprintf(os.Stderr, "\n") +} + +func (r *simpleReporter) Begin(i int) { + r.mu.Lock() + r.queued-- + r.running++ + r.update() + r.mu.Unlock() +} + +func (r *simpleReporter) End(i int) { + r.mu.Lock() + r.running-- + r.done++ + r.update() + r.mu.Unlock() +} + +func (r *simpleReporter) update() { + fmt.Fprintf(os.Stderr, "%s: (queued: %d -> running: %d -> done: %d) \r", r.name, r.queued, r.running, r.done) +} diff --git a/traverse/time_estimate_reporter.go b/traverse/time_estimate_reporter.go index c6b732f5..984a71bc 100644 --- a/traverse/time_estimate_reporter.go +++ b/traverse/time_estimate_reporter.go @@ -8,15 +8,10 @@ import ( "time" ) -// TimeEstimateReporter is a Reporter that prints to stderr the number of jobs queued, -// running, and done, as well as the running time of the Traverse and an estimate for -// the amount of time remaining. 
-// Note: for estimation, it assumes jobs have roughly equal running time and are FIFO-ish -// (that is, it does not try to account for the bias of shorter jobs finishing first and -// therefore skewing the average estimated job run time). -type TimeEstimateReporter struct { - // Name is the name of the job to display. - Name string +type timeEstimateReporter struct { + name string + + mu sync.Mutex numWorkers int32 numQueued int32 @@ -24,122 +19,118 @@ type TimeEstimateReporter struct { numDone int32 // start time of the Traverse startTime time.Time - // start times of the individual jobs that are currently running - startTimes timeQueue - cummulativeRuntime time.Duration - ticker *time.Ticker + + cumulativeRuntime time.Duration + ticker *time.Ticker + + startTimes map[int]time.Time // used to prevent race conditions with printStatus and startTimes queue - mut sync.Mutex } -// Report prints the number of jobs currently queued, running, and done. -func (reporter *TimeEstimateReporter) Report(queued, running, done int32) { - reporter.mut.Lock() - - currentTime := time.Now() - - if running == 0 && done == 0 { - reporter.startTime = currentTime - reporter.startTimes.init(queued) - reporter.numWorkers = 1 - reporter.numQueued = queued - reporter.ticker = time.NewTicker(1 * time.Second) - - go func(reporter *TimeEstimateReporter) { - for range reporter.ticker.C { - reporter.mut.Lock() - reporter.printStatus() - reporter.mut.Unlock() - } - }(reporter) +// NewTimeEstimateReporter returns a reporter that reports the number +// of jobs queued, running, and done, as well as the running time of +// the Traverse and an estimate for the amount of time remaining. +// Note: for estimation, it assumes jobs have roughly equal running +// time and are FIFO-ish (that is, it does not try to account for the +// bias of shorter jobs finishing first and therefore skewing the +// average estimated job run time). 
+func NewTimeEstimateReporter(name string) Reporter { + return &timeEstimateReporter{ + name: name, + startTimes: make(map[int]time.Time), } +} - if running > reporter.numWorkers { - reporter.numWorkers = running - } +func (r *timeEstimateReporter) Init(n int) { + r.numQueued = int32(n) + r.numWorkers = 1 + r.startTime = time.Now() + r.ticker = time.NewTicker(time.Second) + + go func() { + for range r.ticker.C { + r.mu.Lock() + r.printStatus() + r.mu.Unlock() + } + fmt.Fprintf(os.Stderr, "\n") + }() +} - // Job started - if reporter.numQueued-1 == queued && reporter.numRunning+1 == running && reporter.numDone == done { - reporter.startTimes.push(time.Now()) - } +func (r *timeEstimateReporter) Complete() { + r.ticker.Stop() +} - // Job finished - if reporter.numQueued == queued && reporter.numRunning-1 == running && reporter.numDone+1 == done { - reporter.cummulativeRuntime += time.Since(reporter.startTimes.pop()) +func (r *timeEstimateReporter) Begin(i int) { + r.mu.Lock() + defer r.mu.Unlock() + r.startTimes[i] = time.Now() + r.numQueued-- + r.numRunning++ + if r.numRunning > r.numWorkers { + r.numWorkers = r.numRunning } + r.printStatus() +} - reporter.numQueued = queued - reporter.numRunning = running - reporter.numDone = done - - reporter.printStatus() - if queued == 0 && running == 0 { - reporter.ticker.Stop() - fmt.Fprintf(os.Stderr, "\n") +func (r *timeEstimateReporter) End(i int) { + r.mu.Lock() + defer r.mu.Unlock() + start, ok := r.startTimes[i] + if !ok { + panic("end called without start") } - reporter.mut.Unlock() + delete(r.startTimes, i) + r.numRunning-- + r.numDone++ + r.cumulativeRuntime += time.Since(start) + + r.printStatus() } -func (reporter *TimeEstimateReporter) printStatus() { - timeLeftStr := reporter.buildTimeLeftStr(time.Now()) +func (r *timeEstimateReporter) printStatus() { + timeLeftStr := r.buildTimeLeftStr(time.Now()) fmt.Fprintf(os.Stderr, "%s: (queued: %d -> running: %d -> done: %d) %v %s \r", - reporter.Name, reporter.numQueued, 
reporter.numRunning, reporter.numDone, - time.Since(reporter.startTime).Round(time.Second), timeLeftStr) + r.name, r.numQueued, r.numRunning, r.numDone, + time.Since(r.startTime).Round(time.Second), timeLeftStr) } -func (reporter TimeEstimateReporter) buildTimeLeftStr(currentTime time.Time) string { +func (r *timeEstimateReporter) buildTimeLeftStr(currentTime time.Time) string { // If some jobs have finished, use their running time for the estimate. Otherwise, use the duration // that the first job has been running. var modifier string var avgRunTime time.Duration - if reporter.cummulativeRuntime > 0 { + if r.cumulativeRuntime > 0 { modifier = "~" - avgRunTime = reporter.cummulativeRuntime / time.Duration(reporter.numDone) - } else if reporter.numRunning > 0 { + avgRunTime = r.cumulativeRuntime / time.Duration(r.numDone) + } else if r.numRunning > 0 { modifier = ">" - avgRunTime = currentTime.Sub(reporter.startTimes.peek()) + for _, t := range r.startTimes { + avgRunTime += currentTime.Sub(t) + } + avgRunTime /= time.Duration(len(r.startTimes)) } - runningJobsTimeLeft := time.Duration(reporter.numRunning)*avgRunTime - reporter.sumCurrentRunningTimes(currentTime) - if reporter.numRunning > 0 { - runningJobsTimeLeft /= time.Duration(reporter.numRunning) + runningJobsTimeLeft := time.Duration(r.numRunning)*avgRunTime - r.sumCurrentRunningTimes(currentTime) + if r.numRunning > 0 { + runningJobsTimeLeft /= time.Duration(r.numRunning) } if runningJobsTimeLeft < 0 { runningJobsTimeLeft = time.Duration(0) } - queuedJobsTimeLeft := time.Duration(math.Ceil(float64(reporter.numQueued)/float64(reporter.numWorkers))) * avgRunTime + queuedJobsTimeLeft := time.Duration(math.Ceil(float64(r.numQueued)/float64(r.numWorkers))) * avgRunTime return fmt.Sprintf("(%s%v left %v avg)", modifier, (queuedJobsTimeLeft + runningJobsTimeLeft).Round(time.Second), avgRunTime.Round(time.Second)) } -func (reporter TimeEstimateReporter) sumCurrentRunningTimes(currentTime time.Time) time.Duration { 
+func (r *timeEstimateReporter) sumCurrentRunningTimes(currentTime time.Time) time.Duration { var totalRunningTime time.Duration - for _, startTime := range reporter.startTimes { + for _, startTime := range r.startTimes { totalRunningTime += currentTime.Sub(startTime) } return totalRunningTime } - -type timeQueue []time.Time - -func (q *timeQueue) init(capacity int32) { - *q = make([]time.Time, 0, capacity) -} - -func (q *timeQueue) push(t time.Time) { - (*q) = append((*q), t) -} - -func (q *timeQueue) pop() time.Time { - elem := (*q)[0] - *q = (*q)[1:] - return elem -} - -func (q timeQueue) peek() time.Time { - return q[0] -} diff --git a/traverse/time_estimate_reporter_test.go b/traverse/time_estimate_reporter_test.go index 34809fd8..231fe647 100644 --- a/traverse/time_estimate_reporter_test.go +++ b/traverse/time_estimate_reporter_test.go @@ -9,176 +9,141 @@ import ( "time" ) -func TestTimeQueue(t *testing.T) { - var queue timeQueue - queue.init(5) - - t1 := time.Now() - t2 := t1.Add(1) - - expectQueue(queue, []time.Time{}, t) - queue.push(t1) - expectQueue(queue, []time.Time{t1}, t) - queue.push(t2) - expectQueue(queue, []time.Time{t1, t2}, t) - result := queue.peek() - if result != t1 { - t.Errorf("Expected Peek result: %v actual %v", t1, result) - } - expectQueue(queue, []time.Time{t1, t2}, t) - result = queue.pop() - if result != t1 { - t.Errorf("Expected Pop resule: %v actual %v", t1, result) - } - expectQueue(queue, []time.Time{t2}, t) -} - -func expectQueue(queue timeQueue, times []time.Time, t *testing.T) { - if len(queue) != len(times) { - t.Errorf("Expected queue: %v actual %v", times, queue) - } - for i, elem := range queue { - if elem != times[i] { - t.Errorf("Expected queue: %v actual %v", times, queue) - } - } -} - func TestBuildTimeLeftStr(t *testing.T) { currentTime := time.Now() tests := []struct { - reporter TimeEstimateReporter + reporter *timeEstimateReporter expected string }{ { - reporter: TimeEstimateReporter{ - numWorkers: 1, - numQueued: 
10, - numRunning: 0, - numDone: 0, - startTime: currentTime, - startTimes: []time.Time{}, - cummulativeRuntime: time.Duration(0)}, + reporter: &timeEstimateReporter{ + numWorkers: 1, + numQueued: 10, + numRunning: 0, + numDone: 0, + startTime: currentTime, + startTimes: map[int]time.Time{}, + cumulativeRuntime: time.Duration(0)}, expected: "(0s left 0s avg)", }, { - reporter: TimeEstimateReporter{ - numWorkers: 1, - numQueued: 9, - numRunning: 1, - numDone: 0, - startTime: currentTime.Add(-1 * time.Second), - startTimes: []time.Time{currentTime.Add(-1 * time.Second)}, - cummulativeRuntime: time.Duration(0), + reporter: &timeEstimateReporter{ + numWorkers: 1, + numQueued: 9, + numRunning: 1, + numDone: 0, + startTime: currentTime.Add(-1 * time.Second), + startTimes: map[int]time.Time{0: currentTime.Add(-1 * time.Second)}, + cumulativeRuntime: time.Duration(0), }, expected: "(>9s left 1s avg)", }, { - reporter: TimeEstimateReporter{ - numWorkers: 1, - numQueued: 9, - numRunning: 0, - numDone: 1, - startTime: currentTime.Add(-5 * time.Second), - startTimes: []time.Time{}, - cummulativeRuntime: time.Duration(5 * time.Second), + reporter: &timeEstimateReporter{ + numWorkers: 1, + numQueued: 9, + numRunning: 0, + numDone: 1, + startTime: currentTime.Add(-5 * time.Second), + startTimes: map[int]time.Time{}, + cumulativeRuntime: time.Duration(5 * time.Second), }, expected: "(~45s left 5s avg)", }, { - reporter: TimeEstimateReporter{ - numWorkers: 1, - numQueued: 8, - numRunning: 1, - numDone: 1, - startTime: currentTime.Add(-10 * time.Second), - startTimes: []time.Time{currentTime.Add(-4 * time.Second)}, - cummulativeRuntime: time.Duration(5 * time.Second), + reporter: &timeEstimateReporter{ + numWorkers: 1, + numQueued: 8, + numRunning: 1, + numDone: 1, + startTime: currentTime.Add(-10 * time.Second), + startTimes: map[int]time.Time{0: currentTime.Add(-4 * time.Second)}, + cumulativeRuntime: time.Duration(5 * time.Second), }, expected: "(~41s left 5s avg)", }, { - 
reporter: TimeEstimateReporter{ - numWorkers: 1, - numQueued: 0, - numRunning: 1, - numDone: 9, - startTime: currentTime.Add(-45 * time.Second), - startTimes: []time.Time{currentTime.Add(-1 * time.Second)}, - cummulativeRuntime: time.Duration(9 * 5 * time.Second), + reporter: &timeEstimateReporter{ + numWorkers: 1, + numQueued: 0, + numRunning: 1, + numDone: 9, + startTime: currentTime.Add(-45 * time.Second), + startTimes: map[int]time.Time{0: currentTime.Add(-1 * time.Second)}, + cumulativeRuntime: time.Duration(9 * 5 * time.Second), }, expected: "(~4s left 5s avg)", }, { - reporter: TimeEstimateReporter{ - numWorkers: 2, - numQueued: 8, - numRunning: 2, - numDone: 0, - startTime: currentTime.Add(-2 * time.Second), - startTimes: []time.Time{currentTime.Add(-2 * time.Second), currentTime.Add(-1 * time.Second)}, - cummulativeRuntime: time.Duration(0), + reporter: &timeEstimateReporter{ + numWorkers: 2, + numQueued: 8, + numRunning: 2, + numDone: 0, + startTime: currentTime.Add(-2 * time.Second), + startTimes: map[int]time.Time{0: currentTime.Add(-2 * time.Second), 1: currentTime.Add(-1 * time.Second)}, + cumulativeRuntime: time.Duration(0), }, - expected: "(>9s left 2s avg)", + expected: "(>6s left 2s avg)", }, { - reporter: TimeEstimateReporter{ - numWorkers: 2, - numQueued: 6, - numRunning: 2, - numDone: 2, - startTime: currentTime.Add(-14 * time.Second), - startTimes: []time.Time{currentTime.Add(-4 * time.Second), currentTime.Add(-2 * time.Second)}, - cummulativeRuntime: time.Duration(2 * 5 * time.Second), + reporter: &timeEstimateReporter{ + numWorkers: 2, + numQueued: 6, + numRunning: 2, + numDone: 2, + startTime: currentTime.Add(-14 * time.Second), + startTimes: map[int]time.Time{0: currentTime.Add(-4 * time.Second), 1: currentTime.Add(-2 * time.Second)}, + cumulativeRuntime: time.Duration(2 * 5 * time.Second), }, expected: "(~17s left 5s avg)", }, { - reporter: TimeEstimateReporter{ - numWorkers: 2, - numQueued: 2, - numRunning: 0, - numDone: 8, - startTime: 
currentTime.Add(-45 * time.Second), - startTimes: []time.Time{}, - cummulativeRuntime: time.Duration(8 * 5 * time.Second), + reporter: &timeEstimateReporter{ + numWorkers: 2, + numQueued: 2, + numRunning: 0, + numDone: 8, + startTime: currentTime.Add(-45 * time.Second), + startTimes: map[int]time.Time{}, + cumulativeRuntime: time.Duration(8 * 5 * time.Second), }, expected: "(~5s left 5s avg)", }, { // Note even though we have 2 workers, only one can process the single queued job, so expected time left is 5s. - reporter: TimeEstimateReporter{ - numWorkers: 2, - numQueued: 1, - numRunning: 0, - numDone: 9, - startTime: currentTime.Add(-45 * time.Second), - startTimes: []time.Time{}, - cummulativeRuntime: time.Duration(9 * 5 * time.Second), + reporter: &timeEstimateReporter{ + numWorkers: 2, + numQueued: 1, + numRunning: 0, + numDone: 9, + startTime: currentTime.Add(-45 * time.Second), + startTimes: map[int]time.Time{}, + cumulativeRuntime: time.Duration(9 * 5 * time.Second), }, expected: "(~5s left 5s avg)", }, { - reporter: TimeEstimateReporter{ - numWorkers: 2, - numQueued: 0, - numRunning: 1, - numDone: 9, - startTime: currentTime.Add(-48 * time.Second), - startTimes: []time.Time{currentTime.Add(-3 * time.Second)}, - cummulativeRuntime: time.Duration(9 * 5 * time.Second), + reporter: &timeEstimateReporter{ + numWorkers: 2, + numQueued: 0, + numRunning: 1, + numDone: 9, + startTime: currentTime.Add(-48 * time.Second), + startTimes: map[int]time.Time{0: currentTime.Add(-3 * time.Second)}, + cumulativeRuntime: time.Duration(9 * 5 * time.Second), }, expected: "(~2s left 5s avg)", }, { // Last job is taking longer than average to run. 
- reporter: TimeEstimateReporter{ - numWorkers: 2, - numQueued: 0, - numRunning: 1, - numDone: 9, - startTime: currentTime.Add(-52 * time.Second), - startTimes: []time.Time{currentTime.Add(-7 * time.Second)}, - cummulativeRuntime: time.Duration(9 * 5 * time.Second), + reporter: &timeEstimateReporter{ + numWorkers: 2, + numQueued: 0, + numRunning: 1, + numDone: 9, + startTime: currentTime.Add(-52 * time.Second), + startTimes: map[int]time.Time{0: currentTime.Add(-7 * time.Second)}, + cumulativeRuntime: time.Duration(9 * 5 * time.Second), }, expected: "(~0s left 5s avg)", }, diff --git a/traverse/traverse.go b/traverse/traverse.go index 2d617576..73f81f8d 100644 --- a/traverse/traverse.go +++ b/traverse/traverse.go @@ -2,221 +2,292 @@ // Use of this source code is governed by the Apache-2.0 // license that can be found in the LICENSE file. -// Package traverse provides facilities for concurrent and parallel slice traversal. +// Package traverse provides primitives for concurrent and parallel +// traversal of slices or user-defined collections. package traverse import ( "fmt" - "github.com/grailbio/base/errorreporter" - "os" + "log" "runtime" "runtime/debug" "sync" "sync/atomic" -) -type panicErr struct { - v interface{} - stack []byte -} - -func (p panicErr) Error() string { return fmt.Sprint(p.v) } - -// Traverse is a traversal of a given length. Traverse instances -// should be instantiated with Each and Parallel. -type Traverse struct { - n, maxConcurrent, nshards int - debugStatus *status -} + "github.com/grailbio/base/errors" +) -// Each creates a new traversal of length n appropriate for -// concurrent traversal. -func Each(n int) Traverse { - return Traverse{n, n, 0, nil} -} +const cachelineSize = 64 -// Parallel creates a new traversal of length n appropriate for -// parallel traversal. 
-func Parallel(n int) Traverse { - return Each(n).Limit(runtime.NumCPU()) +// A T is a traverser: it provides facilities for concurrently +// invoking functions that traverse collections of data. +type T struct { + // Limit is the traverser's concurrency limit: there will be no more + // than Limit concurrent invocations per traversal. A limit value of + // zero (the default value) denotes no limit. + Limit int + // Sequential indicates that early indexes should be handled before later + // ones. E.g. if there are 40000 tasks and Limit == 40, the initial + // assignment is usually + // worker 0 <- tasks 0-999 + // worker 1 <- tasks 1000-1999 + // ... + // worker 39 <- tasks 39000-39999 + // but when Sequential == true, only tasks 0-39 are initially assigned, then + // task 40 goes to the first worker to finish, etc. + // Note that this increases synchronization overhead. It should not be used + // with e.g. > 1 billion tiny tasks; in that scenario, the caller should + // organize such tasks into e.g. 10000-task chunks and perform a + // sequential-traverse on the chunks. + // This scheduling algorithm does perform well when tasks are sorted in order + // of decreasing size. + Sequential bool + // Reporter receives status reports for each traversal. It is + // intended for users who wish to monitor the progress of large + // traversal jobs. + Reporter Reporter } -// Limit limits the concurrency of the traversal to maxConcurrent. -func (t Traverse) Limit(maxConcurrent int) Traverse { - t.maxConcurrent = maxConcurrent - return t +// Limit returns a traverser with limit n. +func Limit(n int) T { + if n <= 0 { + log.Panicf("traverse.Limit: invalid limit: %d", n) + } + return T{Limit: n} } -// Sharded sets the number of shards we want to use for the traverse -// (for traverse of large number of elements where processing each element -// is very fast, it will be more efficient to shard the processing, -// rather than use a separate goroutine for each element). 
-// If not set, by default the number of shards is equal to the number -// of elements to traverse (shard size 1). -// When using a Reporter, each shard will be reported as a single job. -func (t Traverse) Sharded(nshards int) Traverse { - t.nshards = nshards - return t +// LimitSequential returns a sequential traverser with limit n. +func LimitSequential(n int) T { + if n <= 0 { + log.Panicf("traverse.LimitSequential: invalid limit: %d", n) + } + return T{Limit: n, Sequential: true} } -// WithReporter will use the given reporter to report the progress on the jobs. -// Ex. traverse.Each(9).WithReporter(traverse.DefaultReporter{Name: "Processing:"}).Do(func(i int) error { ... -func (t Traverse) WithReporter(reporter Reporter) Traverse { - t.debugStatus = &status{&sync.Mutex{}, reporter, 0, 0, 0} - return t -} +// Parallel is the default traverser for parallel traversal, intended +// CPU-intensive parallel computing. Parallel limits the number of +// concurrent invocations to a small multiple of the runtime's +// available processors. +var Parallel = T{Limit: 2 * runtime.GOMAXPROCS(0)} -// Do performs a traversal, invoking function op for each index, 0 <= -// i < t.n. Do returns the first error returned by any invoked op, or -// nil when all ops succeed. Traversal is terminated early on error. -// Panics are recovered in ops and propagated to the calling -// goroutine, printing the original stack trace. Do guarantees that, -// after it returns, no more ops will be invoked. -func (t Traverse) Do(op func(i int) error) (err error) { - return t.DoRange(func(start, end int) error { - for i := start; i < end && err == nil; i++ { - err = op(i) - } - return err - }) -} - -// DoRange is similar to Do above, except it accepts a function that runs -// over a block of indices [start, end). This can be more efficient if -// running sharded traverse on an input where the operation for each index -// if very simple/fast. 
For example, to add 1 to each element of a []int -// traverse.Each(len(slice)).Limit(10).Sharded(10).DoRange(func(start, end int) error { -// for i := start; i < end; i++ { -// slice[i]++ -// } -// return nil -// } -func (t Traverse) DoRange(op func(start, end int) error) error { - if t.n == 0 { +// Each performs a traversal on fn. Specifically, Each invokes fn(i) +// for 0 <= i < n, managing concurrency and error propagation. Each +// returns when the all invocations have completed, or after the +// first invocation fails, in which case the first invocation error +// is returned. Each also propagates panics from underlying invocations +// to the caller. Note that if a function panics and doesn't release +// shared resources that fn might need in a traverse child, this could +// lead to deadlock. +func (t T) Each(n int, fn func(i int) error) error { + if t.Reporter != nil { + t.Reporter.Init(n) + defer t.Reporter.Complete() + } + var err error + if t.Limit == 1 || n == 1 { + err = t.eachSerial(n, fn) + } else if t.Limit == 0 || t.Limit >= n { + err = t.each(n, fn) + } else if t.Sequential { + err = t.eachSequential(n, fn) + } else { + err = t.eachLimit(n, fn) + } + if err == nil { return nil } - - numShards := t.n - shardSize := 1 - if t.nshards > 0 { - numShards = min(t.nshards, t.n) - shardSize = (t.n + t.nshards - 1) / t.nshards + // Propagate panics. 
+ if err, ok := err.(panicErr); ok { + panic(fmt.Sprintf("traverse child: %v\n%s", err.v, string(err.stack))) } + return err +} - if numShards < t.maxConcurrent { - t.maxConcurrent = numShards +func (t T) each(n int, fn func(i int) error) error { + var ( + errors errors.Once + wg sync.WaitGroup + ) + wg.Add(n) + for i := 0; i < n; i++ { + go func(i int) { + if t.Reporter != nil { + t.Reporter.Begin(i) + } + if err := apply(fn, i); err != nil { + errors.Set(err) + } + if t.Reporter != nil { + t.Reporter.End(i) + } + wg.Done() + }(i) } + wg.Wait() + return errors.Err() +} - var errorReporter errorreporter.T - apply := func(i int) (err error) { - defer func() { - if perr := recover(); perr != nil { - err = panicErr{perr, debug.Stack()} - } - }() - start := i * shardSize - return op(start, min(start+shardSize, t.n)) +// eachSerial runs on the local thread using a conventional for loop. +// all invocations will be run in numerical order. +func (t T) eachSerial(n int, fn func(i int) error) error { + for i := 0; i < n; i++ { + if t.Reporter != nil { + t.Reporter.Begin(i) + } + if err := apply(fn, i); err != nil { + return err + } + if t.Reporter != nil { + t.Reporter.End(i) + } } - var wg sync.WaitGroup - wg.Add(t.maxConcurrent) - t.debugStatus.queueJobs(int32(numShards)) + return nil +} - var x int64 = -1 // x is treated with atomic operations and accessed from multiple go routines - for i := 0; i < t.maxConcurrent; i++ { +// eachSequential performs a concurrent run where tasks are assigned in strict +// numerical order. Unlike eachLimit(), it can be used when the traversal must +// be done sequentially. 
+func (t T) eachSequential(n int, fn func(i int) error) error { + var ( + errors errors.Once + wg sync.WaitGroup + syncStruct struct { + _ [cachelineSize - 8]byte // cache padding + N int64 + _ [cachelineSize - 8]byte // cache padding + } + ) + syncStruct.N = -1 + wg.Add(t.Limit) + for i := 0; i < t.Limit; i++ { go func() { - defer wg.Done() - for { - i := int(atomic.AddInt64(&x, 1)) // the first iteration will return 0. - if i >= numShards || errorReporter.Err() != nil { - return + for errors.Err() == nil { + idx := int(atomic.AddInt64(&syncStruct.N, 1)) + if idx >= n { + break + } + if t.Reporter != nil { + t.Reporter.Begin(idx) + } + if err := apply(fn, idx); err != nil { + errors.Set(err) } - t.debugStatus.startJob() - err := apply(i) - t.debugStatus.finishJob() - if err != nil { - errorReporter.Set(err) - return + if t.Reporter != nil { + t.Reporter.End(idx) } } + wg.Done() }() } - wg.Wait() - // read the first errors that may have occurred - foundError := errorReporter.Err() - if foundError != nil { - if err, ok := foundError.(panicErr); ok { - panic(fmt.Sprintf("traverse child: %s\n%s", err.v, string(err.stack))) - } - return foundError - } - return nil + return errors.Err() } -// Reporter is the interface for reporting the progress on traverse jobs. -type Reporter interface { - // Report is called every time the number of jobs queued, running, or done changes. - Report(queued, running, done int32) -} - -// DefaultReporter is a simple Reporter that prints to stderr the number of -// jobs queued, running, and done -type DefaultReporter struct { - Name string +// eachLimit performs a concurrent run where tasks can be assigned in any +// order. 
+func (t T) eachLimit(n int, fn func(i int) error) error { + var ( + errors errors.Once + wg sync.WaitGroup + next = make([]struct { + N int64 + _ [cachelineSize - 8]byte // cache padding + }, t.Limit) + size = (n + t.Limit - 1) / t.Limit + ) + wg.Add(t.Limit) + for i := 0; i < t.Limit; i++ { + go func(w int) { + orig := w + for errors.Err() == nil { + // Each worker traverses contiguous segments since there is + // often usable data locality associated with index locality. + idx := int(atomic.AddInt64(&next[w].N, 1) - 1) + which := w*size + idx + if idx >= size || which >= n { + w = (w + 1) % t.Limit + if w == orig { + break + } + continue + } + if t.Reporter != nil { + t.Reporter.Begin(which) + } + if err := apply(fn, which); err != nil { + errors.Set(err) + } + if t.Reporter != nil { + t.Reporter.End(which) + } + } + wg.Done() + }(i) + } + wg.Wait() + return errors.Err() } -// Report prints the number of jobs currently queued, running, and done. -func (reporter DefaultReporter) Report(queued, running, done int32) { - fmt.Fprintf(os.Stderr, "%s: (queued: %d -> running: %d -> done: %d) \r", reporter.Name, queued, running, done) - if queued == 0 && running == 0 { - fmt.Fprintf(os.Stderr, "\n") +// Range performs ranged traversal on fn: n is split into +// contiguous ranges, and fn is invoked for each range. The range +// sizes are determined by the traverser's concurrency limits. Range +// allows the caller to amortize function call costs, and is +// typically used when limit is small and n is large, for example on +// parallel traversal over large collections, where each item's +// processing time is comparatively small. +func (t T) Range(n int, fn func(start, end int) error) error { + if t.Sequential { + // interface for this should take a chunk size. 
+		log.Panicf("traverse.Range: sequential traversal unsupported")
+	}
+	m := n
+	if t.Limit > 0 && t.Limit < n {
+		m = t.Limit
+	}
+	// TODO: consider splitting ranges into smaller chunks so that we can
+	// take better advantage of the load balancing underneath.
+	return t.Each(m, func(i int) error {
+		var (
+			size  = float64(n) / float64(m)
+			start = int(float64(i) * size)
+			end   = int(float64(i+1) * size)
+		)
+		if start >= n {
+			return nil
+		}
+		if i == m-1 {
+			end = n
+		}
+		return fn(start, end)
+	})
 }
 
-// status keeps track of how many jobs are queued, running, and done.
-type status struct {
-	mu       *sync.Mutex
-	reporter Reporter
-	queued   int32
-	done     int32
-	running  int32
-}
+var defaultT = T{}
 
-func (s *status) queueJobs(numjobs int32) {
-	if s == nil {
-		return
-	}
-	s.mu.Lock()
-	s.queued += numjobs
-	s.reporter.Report(s.queued, s.running, s.done)
-	s.mu.Unlock()
+// Each performs concurrent traversal over n elements. It is a
+// shorthand for (T{}).Each.
+func Each(n int, fn func(i int) error) error {
+	return defaultT.Each(n, fn)
 }
 
-func (s *status) startJob() {
-	if s == nil {
-		return
-	}
-	s.mu.Lock()
-	s.queued--
-	s.running++
-	s.reporter.Report(s.queued, s.running, s.done)
-	s.mu.Unlock()
+// CPU calls the function fn for each available system CPU. CPU
+// returns when all calls have completed or on first error. 
+func CPU(fn func() error) error { + return Each(runtime.NumCPU(), func(int) error { return fn() }) } -func (s *status) finishJob() { - if s == nil { - return - } - s.mu.Lock() - s.running-- - s.done++ - s.reporter.Report(s.queued, s.running, s.done) - s.mu.Unlock() +func apply(fn func(i int) error, i int) (err error) { + defer func() { + if perr := recover(); perr != nil { + err = panicErr{perr, debug.Stack()} + } + }() + return fn(i) } -func min(x, y int) int { - if x < y { - return x - } - return y +type panicErr struct { + v interface{} + stack []byte } + +func (p panicErr) Error() string { return fmt.Sprint(p.v) } diff --git a/traverse/traverse_test.go b/traverse/traverse_test.go index 176f0901..837f5f28 100644 --- a/traverse/traverse_test.go +++ b/traverse/traverse_test.go @@ -2,14 +2,20 @@ // Use of this source code is governed by the Apache-2.0 // license that can be found in the LICENSE file. -package traverse +package traverse_test import ( "errors" "fmt" + "math/rand" "reflect" "strings" + "sync" + "sync/atomic" "testing" + "time" + + "github.com/grailbio/base/traverse" ) func recovered(f func()) (v interface{}) { @@ -20,7 +26,7 @@ func recovered(f func()) (v interface{}) { func TestTraverse(t *testing.T) { list := make([]int, 5) - err := Each(5).Do(func(i int) error { + err := traverse.Each(5, func(i int) error { list[i] += i return nil }) @@ -31,7 +37,7 @@ func TestTraverse(t *testing.T) { t.Errorf("got %v, want %v", got, want) } expectedErr := errors.New("test error") - err = Each(5).Do(func(i int) error { + err = traverse.Each(5, func(i int) error { if i == 3 { return expectedErr } @@ -42,104 +48,147 @@ func TestTraverse(t *testing.T) { } } -func TestPanic(t *testing.T) { - expectedPanic := "panic in the disco!!" 
- f := func() { - Each(5).Do(func(i int) error { - if i == 3 { - panic(expectedPanic) - } - return nil - }) - } - v := recovered(f) - s, ok := v.(string) - if !ok { - t.Fatal("expected string") - } - if got, want := s, fmt.Sprintf("traverse child: %s", expectedPanic); !strings.HasPrefix(got, want) { - t.Errorf("got %q, want %q", got, want) - } -} - -func TestSharding(t *testing.T) { +func TestTraverseLarge(t *testing.T) { tests := []struct { - n int - nshards int + N int + Limit int }{ { - n: 5, - nshards: 5, + N: 1, + Limit: 1, }, { - n: 5, - nshards: 10, + N: 10, + Limit: 2, }, { - n: 5, - nshards: 2, + N: 2999999, + Limit: 5, }, { - n: 15, - nshards: 3, + N: 3000001, + Limit: 5, }, } - - for _, test := range tests { - expectedList := make([]int, test.n) - for i := range expectedList { - expectedList[i] = i + for testId, test := range tests { + data := make([]int32, test.N) + _ = traverse.Limit(test.Limit).Each(test.N, func(i int) error { + atomic.AddInt32(&data[i], 1) + return nil + }) + for i, d := range data { + if d != 1 { + t.Errorf("Test %d - Each. element %d is %d. Expected 1", testId, i, d) + break + } } - list := make([]int, test.n) - err := Each(test.n).Sharded(test.nshards).Do(func(i int) error { - list[i] += i + data = make([]int32, test.N) + _ = traverse.Limit(test.Limit).Range(test.N, func(i, j int) error { + for k := i; k < j; k++ { + atomic.AddInt32(&data[k], 1) + } return nil }) - if err != nil { - t.Fatal(err) + for i, d := range data { + if d != 1 { + t.Errorf("Test %d - Range. element %d is %d. Expected 1", testId, i, d) + break + } } - if !reflect.DeepEqual(list, expectedList) { - t.Errorf("got %v, want %v", list, expectedList) + + // Emulate a sequential writer. + // The test still passes if LimitSequential is replaced with Limit, but it + // should take noticeably longer to execute. + // (Note that we can't just e.g. guard 'data' with a mutex. 
Just because + // tasks are launched in numerical order does not mean that they will be + // completed in numerical order.) + data = data[:0] + const cachelineSize = 64 + var nextWriteIndex struct { + _ [cachelineSize - 8]byte + N int64 + _ [cachelineSize - 8]byte + } + _ = traverse.LimitSequential(test.Limit).Each(test.N, func(i int) error { + time.Sleep(50 * time.Nanosecond) + for { + j := atomic.LoadInt64(&nextWriteIndex.N) + if int(j) == i { + break + } + } + data = append(data, int32(i)) + _ = atomic.AddInt64(&nextWriteIndex.N, 1) + return nil + }) + for i, d := range data { + if int(d) != i { + t.Errorf("Test %d - LimitSequential. element %d is %d. Expected %d", testId, i, d, i) + break + } } - rangeList := make([]int, test.n) - err = Each(test.n).Sharded(test.nshards).DoRange(func(start, end int) error { + } +} + +func TestRange(t *testing.T) { + const N = 5000 + var ( + counts = make([]int64, N) + invocations int64 + ) + var tr traverse.T + for i := 0; i < N; i++ { + tr.Limit = rand.Intn(N*2) + 1 + err := tr.Range(N, func(start, end int) error { + if start < 0 || end > N || end < start { + return fmt.Errorf("invalid range [%d,%d)", start, end) + } + atomic.AddInt64(&invocations, 1) for i := start; i < end; i++ { - rangeList[i] += i + atomic.AddInt64(&counts[i], 1) } return nil }) if err != nil { - t.Fatal(err) + t.Errorf("limit %d: %v", tr.Limit, err) + continue } - if !reflect.DeepEqual(rangeList, expectedList) { - t.Errorf("DoRange failed: got %v, want %v", rangeList, expectedList) + expect := int64(tr.Limit) + if expect > N { + expect = N } - - // test error propagation - expectedErr := errors.New("test error") - err = Each(test.n).Sharded(test.nshards).Do(func(i int) error { - if i == test.n/2 { - return expectedErr + if got, want := invocations, expect; got != want { + t.Errorf("got %v, want %v", got, want) + } + invocations = 0 + for i := range counts { + if got, want := counts[i], int64(1); got != want { + t.Errorf("counts[%d,%d]: got %v, want %v", 
i, tr.Limit, got, want) } - return nil - }) - if got, want := err, expectedErr; got != want { - t.Errorf("got %v want %v", got, want) + counts[i] = 0 } + } +} - err = Each(test.n).Sharded(test.nshards).DoRange(func(start, end int) error { - for i := start; i < end; i++ { - if i == test.n/2 { - return expectedErr - } +func TestPanic(t *testing.T) { + expectedPanic := "panic in the disco!!" + f := func() { + _ = traverse.Each(5, func(i int) error { + if i == 3 { + panic(expectedPanic) } return nil }) - if got, want := err, expectedErr; got != want { - t.Errorf("got %v want %v", got, want) - } + } + v := recovered(f) + s, ok := v.(string) + if !ok { + t.Fatal("expected string") + } + if got, want := s, fmt.Sprintf("traverse child: %s", expectedPanic); !strings.HasPrefix(got, want) { + t.Errorf("got %q, want %q", got, want) } } @@ -148,20 +197,40 @@ type testStatus struct { } type testReporter struct { - statusHistory []testStatus + mu sync.Mutex + statusHistory []testStatus + queued, running, done int32 +} + +func (r *testReporter) Init(n int) { + r.update(int32(n), 0, 0) +} + +func (r *testReporter) Complete() {} + +func (r *testReporter) Begin(i int) { + r.update(-1, 1, 0) } -func (reporter *testReporter) Report(queued, running, done int32) { - reporter.statusHistory = - append(reporter.statusHistory, testStatus{queued: queued, running: running, done: done}) +func (r *testReporter) End(i int) { + r.update(0, -1, 1) +} + +func (r *testReporter) update(queued, running, done int32) { + r.mu.Lock() + defer r.mu.Unlock() + r.queued += queued + r.running += running + r.done += done + r.statusHistory = + append(r.statusHistory, testStatus{queued: r.queued, running: r.running, done: r.done}) } func TestReportingSingleJob(t *testing.T) { - reporter := testReporter{} + reporter := new(testReporter) - Each(5).Limit(1).WithReporter(&reporter).Do(func(i int) error { - return nil - }) + tr := traverse.T{Reporter: reporter, Limit: 1} + _ = tr.Each(5, func(i int) error { return nil 
}) expectedStatuses := []testStatus{ testStatus{queued: 5, running: 0, done: 0}, @@ -186,14 +255,13 @@ func TestReportingSingleJob(t *testing.T) { } func TestReportingManyJobs(t *testing.T) { - reporter := testReporter{} + reporter := new(testReporter) numJobs := 50 numConcurrent := 5 - Each(numJobs).Limit(numConcurrent).WithReporter(&reporter).Do(func(i int) error { - return nil - }) + tr := traverse.T{Limit: numConcurrent, Reporter: reporter} + _ = tr.Each(numJobs, func(i int) error { return nil }) // first status should be all jobs queued if (reporter.statusHistory[0] != testStatus{queued: int32(numJobs), running: 0, done: 0}) { @@ -228,7 +296,7 @@ func TestReportingManyJobs(t *testing.T) { } if status.queued > previousStatus.queued { - t.Errorf("Can't have queued jobs count increase - status: %v, prevoius status: %v", + t.Errorf("Can't have queued jobs count increase - status: %v, previous status: %v", status, previousStatus) } @@ -241,74 +309,27 @@ func TestReportingManyJobs(t *testing.T) { } func BenchmarkDo(b *testing.B) { - arr := make([]int, b.N) - for n := 0; n < b.N; n++ { - err := Each(n).Do(func(i int) error { - arr[i]++ - return nil + for _, n := range []int{1, 1e6, 1e8} { + b.Run(fmt.Sprintf("n=%d", n), func(b *testing.B) { + for k := 0; k < b.N; k++ { + err := traverse.Parallel.Each(n, func(i int) error { + return nil + }) + if err != nil { + b.Error(err) + } + } }) - if err != nil { - b.Error(err) - } } } -func benchmarkDo(n int, nshards int, b *testing.B) { - arr := make([]int, n) - for k := 0; k < b.N; k++ { - err := Each(n).Sharded(nshards).Do(func(i int) error { - arr[i]++ - return nil - }) - if err != nil { - b.Error(err) - } - } +//go:noinline +func fn(i int) error { + return nil } -func benchmarkDoRange(n int, nshards int, b *testing.B) { - arr := make([]int, n) +func BenchmarkInvoke(b *testing.B) { for k := 0; k < b.N; k++ { - err := Each(n).Sharded(nshards).DoRange(func(start, end int) error { - for i := start; i < end; i++ { - arr[i]++ - 
} - return nil - }) - if err != nil { - b.Error(err) - } + _ = fn(k) } } - -func BenchmarkDoShardSize1(b *testing.B) { - benchmarkDo(1000, 1000, b) -} - -func BenchmarkDoRangeShardSize1(b *testing.B) { - benchmarkDoRange(1000, 1000, b) -} - -func BenchmarkDoShardSize10(b *testing.B) { - benchmarkDo(1000, 100, b) -} - -func BenchmarkDoRangeShardSize10(b *testing.B) { - benchmarkDoRange(1000, 100, b) -} - -func BenchmarkDoShardSize100(b *testing.B) { - benchmarkDo(1000, 10, b) -} - -func BenchmarkDoRangeShardSize100(b *testing.B) { - benchmarkDoRange(1000, 10, b) -} - -func BenchmarkDoShardSize1000(b *testing.B) { - benchmarkDo(1000, 1, b) -} - -func BenchmarkDoRangeShardSize1000(b *testing.B) { - benchmarkDoRange(1000, 1, b) -} diff --git a/tsv/doc.go b/tsv/doc.go new file mode 100644 index 00000000..fb1c2912 --- /dev/null +++ b/tsv/doc.go @@ -0,0 +1,13 @@ +// Copyright 2018 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache-2.0 +// license that can be found in the LICENSE file. + +// Package tsv provides a simple TSV writer which takes care of number->string +// conversions and tabs, and is far more performant than fmt.Fprintf (thanks to +// use of strconv.Append{Uint,Float}). +// +// Usage is similar to bufio.Writer, except that in place of the usual Write() +// method, there are typed WriteString(), WriteUint32(), etc. methods which +// append one field at a time to the current line, and an EndLine() method to +// finish the line. +package tsv diff --git a/tsv/reader.go b/tsv/reader.go new file mode 100644 index 00000000..bc0d350a --- /dev/null +++ b/tsv/reader.go @@ -0,0 +1,422 @@ +package tsv + +import ( + "encoding/csv" + "fmt" + "io" + "reflect" + "sort" + "strconv" + "strings" + "unsafe" + + "github.com/grailbio/base/errors" +) + +type columnFormat struct { + fieldName string // Go struct field name. + columnName string // expected column name in TSV. Defaults to fieldName unless `tsv:"colname"` tag is set. 
+	typ        reflect.Type // Go type information of the column.
+	kind       reflect.Kind // type of the column.
+	fmt        string       // Optional format directive for writing this value.
+	index      int          // index of this column in a row, 0-based.
+	offset     uintptr      // byte offset of this field within the Go struct.
+}
+
+type rowFormat []columnFormat
+
+// Reader reads a TSV file. It wraps around the standard csv.Reader and allows
+// parsing row contents into a Go struct directly. Thread compatible.
+//
+// TODO(saito) Support passing a custom bool parser.
+//
+// TODO(saito) Support a custom "NA" detector.
+type Reader struct {
+	*csv.Reader
+
+	// HasHeaderRow should be set to true to indicate that the input contains a
+	// single header row that lists column names of the rows that follow. It must
+	// be set before reading any data.
+	HasHeaderRow bool
+
+	// UseHeaderNames causes the reader to set struct fields by matching column
+	// names to struct field names (or `tsv` tag). It must be set before reading
+	// any data.
+	//
+	// If not set, struct fields are filled in order, EVEN IF HasHeaderRow=true.
+	// If set, all struct fields must have a corresponding column in the file or
+	// IgnoreMissingColumns must also be set. An error will be reported through
+	// Read().
+	//
+	// REQUIRES: HasHeaderRow=true
+	UseHeaderNames bool
+
+	// RequireParseAllColumns causes Read() to report an error if there are columns
+	// not listed in the passed-in struct. It must be set before reading any data.
+	//
+	// REQUIRES: HasHeaderRow=true
+	RequireParseAllColumns bool
+
+	// IgnoreMissingColumns causes the reader to ignore any struct fields that are
+	// not present as columns in the file. It must be set before reading any
+	// data.
+	//
+	// REQUIRES: HasHeaderRow=true AND UseHeaderNames=true
+	IgnoreMissingColumns bool
+
+	nRow int // # of rows read so far, excluding the header.
+
+	// columnIndex maps colname -> colindex (0-based). Filled from the header
+	// line.
+ columnIndex map[string]int + + cachedRowType reflect.Type + cachedRowFormat rowFormat +} + +// NewReader creates a new TSV reader that reads from the given input. +func NewReader(in io.Reader) *Reader { + r := &Reader{ + Reader: csv.NewReader(in), + } + r.Reader.Comma = '\t' + r.ReuseRecord = true + return r +} + +// Filter columns from the row format that are not present in the file being read. +func (r *Reader) filterRowFormat(format rowFormat) rowFormat { + var filtered rowFormat + for _, f := range format { + if _, ok := r.columnIndex[f.columnName]; ok { + filtered = append(filtered, f) + } + } + return filtered +} + +// Validates and canonicalizes the given row format object when column names +// are being used from the header row. This method may modify the input. +func (r *Reader) validateRowFormat(format rowFormat) (rowFormat, error) { + if r.IgnoreMissingColumns { + format = r.filterRowFormat(format) + } + if r.RequireParseAllColumns && len(format) != len(r.columnIndex) { + return format, fmt.Errorf("number of columns found in %+v does not match format %v", r.columnIndex, format) + } + for i := range format { + col := &format[i] + var ok bool + if col.index, ok = r.columnIndex[col.columnName]; !ok { + return format, fmt.Errorf("column %s does not appear in the header: %+v", col.columnName, r.columnIndex) + } + } + sort.Slice(format, func(i, j int) bool { + return format[i].index < format[j].index + }) + return format, nil +} + +func parseRowFormat(typ reflect.Type) (rowFormat, error) { + var format rowFormat + if typ.Kind() != reflect.Ptr || typ.Elem().Kind() != reflect.Struct { + return nil, fmt.Errorf("destination must be a pointer to struct, but found %v", typ) + } + typ = typ.Elem() + nField := typ.NumField() + for i := 0; i < nField; i++ { + f := typ.Field(i) + if f.PkgPath != "" { // Unexported field. 
+ if tag := f.Tag.Get("tsv"); tag != "" { + return nil, fmt.Errorf("unexported field '%s' should not have a tsv tag '%s'", f.Name, tag) + } + // Unexported embedded (anonymous) struct is OK, but skip other fields. + if !f.Anonymous { + continue + } + } + // Fields from embedded structs are parsed recursively. + if f.Anonymous && f.Type.Kind() == reflect.Struct { + embeddedFormat, err := parseRowFormat(reflect.PtrTo(f.Type)) + if err != nil { + return nil, err + } + for _, col := range embeddedFormat { + col.offset += f.Offset // Shift offsets to be relative to the outer struct. + col.index = len(format) // Reset column index. + format = append(format, col) + } + continue + } + columnName := f.Name + var fmt string + if tag := f.Tag.Get("tsv"); tag != "" { + if tag == "-" { + continue + } + tagArray := strings.Split(tag, ",") + if tagArray[0] != "" { + columnName = tagArray[0] + } + for _, tag := range tagArray[1:] { + if strings.HasPrefix(tag, "fmt=") { + fmt = tag[4:] + } + } + } + format = append(format, columnFormat{ + fieldName: f.Name, + columnName: columnName, + typ: f.Type, + kind: f.Type.Kind(), + fmt: fmt, + index: len(format), + offset: f.Offset, + }) + } + return format, nil +} + +func (r *Reader) wrapError(err error, col columnFormat) error { + var name string + if col.columnName != col.fieldName { + name = fmt.Sprintf("'%s' (Go field '%s')", col.columnName, col.fieldName) + } else { + name = fmt.Sprintf("'%s'", col.columnName) + } + return errors.E(err, fmt.Sprintf("line %d, column %d, %s", r.nRow, col.index, name)) +} + +// fillRow fills Go struct fields from the TSV row. dest is the pointer to the +// struct, and format defines the struct format. 
+func (r *Reader) fillRow(val interface{}, row []string) error { + p := unsafe.Pointer(reflect.ValueOf(val).Pointer()) + if r.RequireParseAllColumns && len(r.cachedRowFormat) != len(row) { // check this for headerless TSVs + return fmt.Errorf("extra columns found in %+v", r.cachedRowFormat) + } + + for _, col := range r.cachedRowFormat { + if len(row) < col.index { + return r.wrapError(fmt.Errorf("row has only %d columns", len(row)), col) + } + colVal := row[col.index] + if col.fmt != "" { + // Not all format directives are recognized while scanning. Try to + // standardize some of the common options. + colfmt := col.fmt + if strings.ContainsAny(colfmt, "efg") { + // Standardize all base 10 floating point number formats to 'g', and + // drop precision and width which are not supported while scanning. + colfmt = "g" + } + if len(strings.Fields(colVal)) != 1 { + // Scanf functions tokenize by space. + return r.wrapError(fmt.Errorf("value with fmt option can not have whitespace"), col) + } + var ( + typ1 = col.typ + p1 = unsafe.Pointer(uintptr(p) + col.offset) + v = reflect.NewAt(typ1, p1).Interface() + n, err = fmt.Sscanf(colVal, "%"+colfmt, v) + ) + if err != nil { + return r.wrapError(err, col) + } + if n != 1 { + return r.wrapError(fmt.Errorf("%d objects scanned for %s; expected 1", n, colVal), col) + } + continue + } + switch col.kind { + case reflect.Bool: + var v bool + switch colVal { + case "Y", "yes": + v = true + case "N", "no": + v = false + default: + var err error + if v, err = strconv.ParseBool(colVal); err != nil { + return r.wrapError(err, col) + } + } + *(*bool)(unsafe.Pointer(uintptr(p) + col.offset)) = v + case reflect.String: + *(*string)(unsafe.Pointer(uintptr(p) + col.offset)) = colVal + case reflect.Int8: + v, err := strconv.ParseInt(colVal, 0, 8) + if err != nil { + return r.wrapError(err, col) + } + *(*int8)(unsafe.Pointer(uintptr(p) + col.offset)) = int8(v) + case reflect.Int16: + v, err := strconv.ParseInt(colVal, 0, 16) + if err != nil { + 
return r.wrapError(err, col) + } + *(*int16)(unsafe.Pointer(uintptr(p) + col.offset)) = int16(v) + case reflect.Int32: + v, err := strconv.ParseInt(colVal, 0, 32) + if err != nil { + return r.wrapError(err, col) + } + *(*int32)(unsafe.Pointer(uintptr(p) + col.offset)) = int32(v) + case reflect.Int64: + v, err := strconv.ParseInt(colVal, 0, 64) + if err != nil { + return r.wrapError(err, col) + } + *(*int64)(unsafe.Pointer(uintptr(p) + col.offset)) = v + case reflect.Int: + v, err := strconv.ParseInt(colVal, 0, 64) + if err != nil { + return r.wrapError(err, col) + } + *(*int)(unsafe.Pointer(uintptr(p) + col.offset)) = int(v) + case reflect.Uint8: + v, err := strconv.ParseUint(colVal, 0, 8) + if err != nil { + return r.wrapError(err, col) + } + *(*uint8)(unsafe.Pointer(uintptr(p) + col.offset)) = uint8(v) + case reflect.Uint16: + v, err := strconv.ParseUint(colVal, 0, 16) + if err != nil { + return r.wrapError(err, col) + } + *(*uint16)(unsafe.Pointer(uintptr(p) + col.offset)) = uint16(v) + case reflect.Uint32: + v, err := strconv.ParseUint(colVal, 0, 32) + if err != nil { + return r.wrapError(err, col) + + } + *(*uint32)(unsafe.Pointer(uintptr(p) + col.offset)) = uint32(v) + case reflect.Uint64: + v, err := strconv.ParseUint(colVal, 0, 64) + if err != nil { + return r.wrapError(err, col) + + } + *(*uint64)(unsafe.Pointer(uintptr(p) + col.offset)) = v + case reflect.Uint: + v, err := strconv.ParseUint(colVal, 0, 64) + if err != nil { + return r.wrapError(err, col) + } + *(*uint)(unsafe.Pointer(uintptr(p) + col.offset)) = uint(v) + + case reflect.Float32: + v, err := strconv.ParseFloat(colVal, 32) + if err != nil { + return r.wrapError(err, col) + + } + *(*float32)(unsafe.Pointer(uintptr(p) + col.offset)) = float32(v) + case reflect.Float64: + v, err := strconv.ParseFloat(colVal, 64) + if err != nil { + return r.wrapError(err, col) + + } + *(*float64)(unsafe.Pointer(uintptr(p) + col.offset)) = v + default: + return r.wrapError(fmt.Errorf("unsupported type %v", 
col.kind), col)
+		}
+	}
+	return nil
+}
+
+// EmptyReadErrStr is the error-string returned by Read() when the file is
+// empty, and at least a header line was expected.
+const EmptyReadErrStr = "empty file: could not read the header row"
+
+// Read reads the next TSV row into a go struct. The argument must be a pointer
+// to a struct. It parses each column in the row into the matching struct
+// fields.
+//
+// Example:
+//   r := tsv.NewReader(...)
+//   ...
+//   type row struct {
+//     Col0 string
+//     Col1 int
+//     Float int
+//   }
+//   var v row
+//   err := r.Read(&v)
+//
+//
+// If !Reader.HasHeaderRow or !Reader.UseHeaderNames, the N-th column (base
+// zero) will be parsed into the N-th field in the struct.
+//
+// If Reader.HasHeaderRow and Reader.UseHeaderNames, then the struct's field
+// name must match one of the column names listed in the first row in the TSV
+// input. The contents of the column with the matching name will be parsed
+// into the struct field.
+//
+// By default, the column name is the struct's field name, but you can override
+// it by setting `tsv:"columnname"` tag in the field. The struct tag may also
+// take an fmt option to specify how to parse the value using the fmt package.
+// This is useful for parsing numbers written in a different base. Note that
+// not all verbs are supported with the scanning functions in the fmt package.
+// Using the fmt option may lead to slower performance.
+// Imagine the following row type:
+//
+//   type row struct {
+//     Chr    string `tsv:"chromo"`
+//     Start  int    `tsv:"pos"`
+//     Length int
+//     Score  int    `tsv:"score,fmt=x"`
+//   }
+//
+// and the following TSV file:
+//
+//   | chromo | Length | pos | score
+//   | chr1   | 1000   | 10  | 0a
+//   | chr2   | 950    | 20  | ff
+//
+// The first Read() will return row{"chr1", 10, 1000, 10}.
+//
+// The second Read() will return row{"chr2", 20, 950, 255}. 
+// +// Embedded structs are supported, and the default column name for nested +// fields will be the unqualified name of the field. +func (r *Reader) Read(v interface{}) error { + if r.nRow == 0 && r.HasHeaderRow { + headerRow, err := r.Reader.Read() + if err != nil { + if err == io.EOF { + err = errors.E(EmptyReadErrStr) + } + return err + } + r.nRow++ + r.columnIndex = map[string]int{} + for i, colName := range headerRow { + r.columnIndex[colName] = i + } + } + row, err := r.Reader.Read() + if err != nil { + return err + } + r.nRow++ + typ := reflect.TypeOf(v) + if typ != r.cachedRowType { + format, err := parseRowFormat(typ) + if err != nil { + return err + } + if r.UseHeaderNames { + format, err = r.validateRowFormat(format) + if err != nil { + return err + } + } + r.cachedRowType = typ + r.cachedRowFormat = format + } + return r.fillRow(v, row) +} diff --git a/tsv/reader_test.go b/tsv/reader_test.go new file mode 100644 index 00000000..0a53d038 --- /dev/null +++ b/tsv/reader_test.go @@ -0,0 +1,462 @@ +package tsv_test + +import ( + "bytes" + "fmt" + "io" + "testing" + + "github.com/grailbio/base/tsv" + "github.com/grailbio/testutil/assert" + "github.com/grailbio/testutil/expect" +) + +func TestReadBool(t *testing.T) { + read := func(data string) bool { + type row struct { + Col0 bool + } + r := tsv.NewReader(bytes.NewReader([]byte("col0\n" + data))) + r.HasHeaderRow = true + var v row + expect.NoError(t, r.Read(&v)) + return v.Col0 + } + + expect.True(t, read("true")) + expect.False(t, read("false")) + expect.True(t, read("Y")) + expect.True(t, read("yes")) + expect.False(t, read("N")) + expect.False(t, read("no")) +} + +func TestReadInt(t *testing.T) { + newReader := func() *tsv.Reader { + r := tsv.NewReader(bytes.NewReader([]byte(`col0 col1 +0 0.5 +`))) + r.HasHeaderRow = true + return r + } + + { + type row struct { + Col0 int8 + Col1 float32 + } + r := newReader() + var v row + expect.NoError(t, r.Read(&v)) + expect.EQ(t, v, row{0, 0.5}) + } + + { + type 
row struct { + Col0 int16 + Col1 float64 + } + r := newReader() + var v row + expect.NoError(t, r.Read(&v)) + expect.EQ(t, v, row{0, 0.5}) + } + + { + type row struct { + Col0 int32 + Col1 float64 + } + r := newReader() + var v row + expect.NoError(t, r.Read(&v)) + expect.EQ(t, v, row{0, 0.5}) + } + { + type row struct { + Col0 int64 + Col1 float64 + } + r := newReader() + var v row + expect.NoError(t, r.Read(&v)) + expect.EQ(t, v, row{0, 0.5}) + } + { + type row struct { + Col0 int + Col1 float64 + } + r := newReader() + var v row + expect.NoError(t, r.Read(&v)) + expect.EQ(t, v, row{0, 0.5}) + } + { + type row struct { + Col0 uint8 + Col1 float32 + } + r := newReader() + var v row + expect.NoError(t, r.Read(&v)) + expect.EQ(t, v, row{0, 0.5}) + } + + { + type row struct { + Col0 uint16 + Col1 float64 + } + r := newReader() + var v row + expect.NoError(t, r.Read(&v)) + expect.EQ(t, v, row{0, 0.5}) + } + + { + type row struct { + Col0 uint32 + Col1 float64 + } + r := newReader() + var v row + expect.NoError(t, r.Read(&v)) + expect.EQ(t, v, row{0, 0.5}) + } +} + +func TestReadFmt(t *testing.T) { + r := tsv.NewReader(bytes.NewReader([]byte(`"""helloworld""" 05.20 true 0a`))) + type row struct { + ColA string `tsv:",fmt=q"` + ColB float64 `tsv:",fmt=1.2f"` + ColC bool `tsv:",fmt=t"` + ColD int `tsv:",fmt=x"` + } + var v row + assert.NoError(t, r.Read(&v)) + assert.EQ(t, v, row{`helloworld`, 5.2, true, 10}) +} + +func TestReadFmtWithSpace(t *testing.T) { + r := tsv.NewReader(bytes.NewReader([]byte(`"hello world"`))) + type row struct { + ColA string `tsv:",fmt=s"` + } + var v row + expect.Regexp(t, r.Read(&v), "value with fmt option can not have whitespace") +} + +func TestReadWithoutHeader(t *testing.T) { + type row struct { + ColA string + ColB int + } + r := tsv.NewReader(bytes.NewReader([]byte(`key1 2 +key2 3 +`))) + var v row + assert.NoError(t, r.Read(&v)) + expect.EQ(t, v, row{"key1", 2}) + assert.NoError(t, r.Read(&v)) + expect.EQ(t, v, row{"key2", 3}) + 
assert.EQ(t, r.Read(&v), io.EOF) +} + +func TestReadSkipUnexportedFields(t *testing.T) { + type row struct { + colA string + colB int + ColC int `tsv:"col0"` + } + r := tsv.NewReader(bytes.NewReader([]byte(`key col0 col1 +key0 1 0.5 +key1 2 1.5 +`))) + r.HasHeaderRow = true + r.UseHeaderNames = true + var v row + assert.NoError(t, r.Read(&v)) + expect.EQ(t, v, row{"", 0, 1}) + assert.NoError(t, r.Read(&v)) + expect.EQ(t, v, row{"", 0, 2}) + assert.EQ(t, r.Read(&v), io.EOF) +} + +func TestReadEmbeddedStruct(t *testing.T) { + type embedded1 struct { + Col1 int `tsv:"col1"` + Col2 float64 `tsv:"col2_2,fmt=0.3f"` + } + type embedded2 struct { + Col2 float32 `tsv:"col2_1"` + } + type row struct { + Key string `tsv:"key"` + embedded1 + embedded2 + } + r := tsv.NewReader(bytes.NewReader([]byte(`key col2_1 col1 col2_2 +key0 0.5 1 0.123 +key1 1.5 2 0.789 +`))) + r.HasHeaderRow = true + r.UseHeaderNames = true + var v row + assert.NoError(t, r.Read(&v)) + expect.EQ(t, v, row{"key0", embedded1{1, 0.123}, embedded2{0.5}}) + assert.NoError(t, r.Read(&v)) + expect.EQ(t, v, row{"key1", embedded1{2, 0.789}, embedded2{1.5}}) + assert.EQ(t, r.Read(&v), io.EOF) +} + +func TestReadExtraColumns(t *testing.T) { + type row struct { + ColA string + ColB int + } + r := tsv.NewReader(bytes.NewReader([]byte(`key1 2 22 +key2 3 33 +`))) + r.RequireParseAllColumns = true + var v row + expect.Regexp(t, r.Read(&v), "extra columns found") +} + +func TestReadDisallowExtraNamedColumns(t *testing.T) { + type row struct { + ColA string + ColB int + } + r := tsv.NewReader(bytes.NewReader([]byte(`ColA ColB ColC +key1 2 22 +key2 3 33 +`))) + r.HasHeaderRow = true + r.UseHeaderNames = true + r.RequireParseAllColumns = true + var v row + expect.Regexp(t, r.Read(&v), "number of columns found") +} + +func TestReadMissingColumns(t *testing.T) { + type row struct { + ColA string + ColB int + } + r := tsv.NewReader(bytes.NewReader([]byte(`ColA +key1 +key2 +`))) + r.HasHeaderRow = true + r.UseHeaderNames = true 
+ r.RequireParseAllColumns = true + var v row + expect.Regexp(t, r.Read(&v), "number of columns found") +} + +func TestReadMismatchedColumns(t *testing.T) { + type row struct { + ColA string + ColB int + } + r := tsv.NewReader(bytes.NewReader([]byte(`ColA ColC +key1 2 +key2 3 +`))) + r.HasHeaderRow = true + r.UseHeaderNames = true + r.RequireParseAllColumns = true + var v row + expect.Regexp(t, r.Read(&v), "does not appear in the header") +} + +func TestReadPartialStruct(t *testing.T) { + type row struct { + ColA string + ColB int + } + r := tsv.NewReader(bytes.NewReader([]byte(`ColA +key1 +key2 +`))) + r.HasHeaderRow = true + r.UseHeaderNames = true + r.RequireParseAllColumns = true + r.IgnoreMissingColumns = true + var v row + assert.NoError(t, r.Read(&v)) + expect.EQ(t, v, row{"key1", 0}) + assert.NoError(t, r.Read(&v)) + expect.EQ(t, v, row{"key2", 0}) + assert.EQ(t, r.Read(&v), io.EOF) +} + +func TestReadAllowExtraNamedColumns(t *testing.T) { + type row struct { + ColB int + ColA string + } + r := tsv.NewReader(bytes.NewReader([]byte(`ColA ColB ColC +key1 2 22 +key2 3 33 +`))) + r.HasHeaderRow = true + r.UseHeaderNames = true + var v row + expect.NoError(t, r.Read(&v)) + expect.EQ(t, v, row{2, "key1"}) + expect.NoError(t, r.Read(&v)) + expect.EQ(t, v, row{3, "key2"}) +} + +func TestReadParseError(t *testing.T) { + type row struct { + ColA int `tsv:"cola"` + ColB string `tsv:"colb"` + } + r := tsv.NewReader(bytes.NewReader([]byte(`key1 2 +`))) + var v row + expect.Regexp(t, r.Read(&v), `line 1, column 0, 'cola' \(Go field 'ColA'\):`) +} + +func TestReadValueError(t *testing.T) { + type row struct { + ColA string + ColB int + } + r := tsv.NewReader(bytes.NewReader([]byte(`key1 2 +key2 3 +`))) + var v int + expect.Regexp(t, r.Read(&v), `destination must be a pointer to struct, but found \*int`) + expect.Regexp(t, r.Read(v), `destination must be a pointer to struct, but found int`) +} + +func TestReadMultipleRowTypes(t *testing.T) { + r := 
tsv.NewReader(bytes.NewReader([]byte(`key1 2 +3 key2 +`))) + { + type row struct { + ColA string + ColB int + } + var v row + assert.NoError(t, r.Read(&v)) + expect.EQ(t, v, row{"key1", 2}) + } + { + type row struct { + ColA int + ColB string + } + var v row + assert.NoError(t, r.Read(&v)) + expect.EQ(t, v, row{3, "key2"}) + } +} + +func ExampleReader() { + type row struct { + Key string + Col0 uint + Col1 float64 + } + + readRow := func(r *tsv.Reader) row { + var v row + if err := r.Read(&v); err != nil { + panic(err) + } + return v + } + + r := tsv.NewReader(bytes.NewReader([]byte(`Key Col0 Col1 +key0 0 0.5 +key1 1 1.5 +`))) + r.HasHeaderRow = true + r.UseHeaderNames = true + fmt.Printf("%+v\n", readRow(r)) + fmt.Printf("%+v\n", readRow(r)) + + var v row + if err := r.Read(&v); err != io.EOF { + panic(err) + } + // Output: + // {Key:key0 Col0:0 Col1:0.5} + // {Key:key1 Col0:1 Col1:1.5} +} + +func ExampleReader_withTag() { + type row struct { + ColA string `tsv:"key"` + ColB float64 `tsv:"col1"` + Skipped int `tsv:"-"` + ColC int `tsv:"col0,fmt=d"` + Hex int `tsv:",fmt=x"` + Hyphen int `tsv:"-,"` + } + readRow := func(r *tsv.Reader) row { + var v row + if err := r.Read(&v); err != nil { + panic(err) + } + return v + } + + r := tsv.NewReader(bytes.NewReader([]byte(`key col0 col1 Hex - +key0 0 0.5 a 1 +key1 1 1.5 f 2 +`))) + r.HasHeaderRow = true + r.UseHeaderNames = true + fmt.Printf("%+v\n", readRow(r)) + fmt.Printf("%+v\n", readRow(r)) + + var v row + if err := r.Read(&v); err != io.EOF { + panic(err) + } + // Output: + // {ColA:key0 ColB:0.5 Skipped:0 ColC:0 Hex:10 Hyphen:1} + // {ColA:key1 ColB:1.5 Skipped:0 ColC:1 Hex:15 Hyphen:2} +} + +func BenchmarkReader(b *testing.B) { + b.StopTimer() + const nRow = 10000 + data := bytes.Buffer{} + for i := 0; i < nRow; i++ { + data.WriteString(fmt.Sprintf("key%d\t%d\t%f\n", i, i, float64(i)+0.5)) + } + b.StartTimer() + + type row struct { + Key string + Int int + Float float64 + } + for i := 0; i < b.N; i++ { + r := 
tsv.NewReader(bytes.NewReader(data.Bytes())) + var ( + val row + n int + ) + for { + err := r.Read(&val) + if err != nil { + if err == io.EOF { + break + } + panic(err) + } + n++ + } + assert.EQ(b, n, nRow) + } +} diff --git a/tsv/row_writer.go b/tsv/row_writer.go new file mode 100644 index 00000000..a6e86176 --- /dev/null +++ b/tsv/row_writer.go @@ -0,0 +1,144 @@ +// Copyright 2019 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache-2.0 +// license that can be found in the LICENSE file. + +package tsv + +import ( + "fmt" + "io" + "reflect" + "unsafe" +) + +// RowWriter writes structs to TSV files using field names or "tsv" tags +// as TSV column headers. +// +// TODO: Consider letting the caller filter or reorder columns. +type RowWriter struct { + w Writer + headerDone bool + cachedRowType reflect.Type + cachedRowFormat rowFormat +} + +// NewRowWriter constructs a writer. +// +// User must call Flush() after last Write(). +func NewRowWriter(w io.Writer) *RowWriter { + return &RowWriter{w: *NewWriter(w)} +} + +// Write writes a TSV row containing the values of v's exported fields. +// v must be a pointer to a struct. +// +// On first Write, a TSV header row is written using v's type. +// Subsequent Write()s may pass v of different type, but no guarantees are made +// about consistent column ordering with different types. +// +// By default, the column name is the struct's field name, but you can +// override it by setting `tsv:"columnname"` tag in the field. +// +// You can optionally specify an fmt option in the tag which will control how +// to format the value using the fmt package. Note that the reader may not +// support all the verbs. Without the fmt option, formatting options are preset +// for each type. Using the fmt option may lead to slower performance. +// +// Embedded structs are supported, and the default column name for nested +// fields will be the unqualified name of the field. 
+func (w *RowWriter) Write(v interface{}) error { + typ := reflect.TypeOf(v) + if typ != w.cachedRowType { + rowFormat, err := parseRowFormat(typ) + if err != nil { + return err + } + w.cachedRowType = typ + w.cachedRowFormat = rowFormat + } + if !w.headerDone { + if err := w.writeHeader(); err != nil { + return err + } + w.headerDone = true + } + return w.writeRow(v) +} + +// Flush flushes all previously-written rows. +func (w *RowWriter) Flush() error { + return w.w.Flush() +} + +func (w *RowWriter) writeHeader() error { + for _, col := range w.cachedRowFormat { + w.w.WriteString(col.columnName) + } + return w.w.EndLine() +} + +func (w *RowWriter) writeRow(v interface{}) error { + p := unsafe.Pointer(reflect.ValueOf(v).Pointer()) + for _, col := range w.cachedRowFormat { + if col.fmt != "" { + var ( + typ1 = col.typ + p1 = unsafe.Pointer(uintptr(p) + col.offset) + v = reflect.Indirect(reflect.NewAt(typ1, p1)) + ) + w.w.WriteString(fmt.Sprintf("%"+col.fmt, v)) + continue + } + switch col.kind { + case reflect.Bool: + v := *(*bool)(unsafe.Pointer(uintptr(p) + col.offset)) + if v { + w.w.WriteString("true") + } else { + w.w.WriteString("false") + } + case reflect.String: + v := *(*string)(unsafe.Pointer(uintptr(p) + col.offset)) + w.w.WriteString(v) + case reflect.Int8: + v := *(*int8)(unsafe.Pointer(uintptr(p) + col.offset)) + w.w.WriteInt64(int64(v)) + case reflect.Int16: + v := *(*int16)(unsafe.Pointer(uintptr(p) + col.offset)) + w.w.WriteInt64(int64(v)) + case reflect.Int32: + v := *(*int32)(unsafe.Pointer(uintptr(p) + col.offset)) + w.w.WriteInt64(int64(v)) + case reflect.Int64: + v := *(*int64)(unsafe.Pointer(uintptr(p) + col.offset)) + w.w.WriteInt64(int64(v)) + case reflect.Int: + v := *(*int)(unsafe.Pointer(uintptr(p) + col.offset)) + w.w.WriteInt64(int64(v)) + case reflect.Uint8: + v := *(*uint8)(unsafe.Pointer(uintptr(p) + col.offset)) + w.w.WriteUint64(uint64(v)) + case reflect.Uint16: + v := *(*uint16)(unsafe.Pointer(uintptr(p) + col.offset)) + 
w.w.WriteUint64(uint64(v)) + case reflect.Uint32: + v := *(*uint32)(unsafe.Pointer(uintptr(p) + col.offset)) + w.w.WriteUint64(uint64(v)) + case reflect.Uint64: + v := *(*uint64)(unsafe.Pointer(uintptr(p) + col.offset)) + w.w.WriteUint64(uint64(v)) + case reflect.Uint: + v := *(*uint)(unsafe.Pointer(uintptr(p) + col.offset)) + w.w.WriteUint64(uint64(v)) + case reflect.Float32: + v := *(*float32)(unsafe.Pointer(uintptr(p) + col.offset)) + w.w.WriteFloat64(float64(v), 'g', -1) + case reflect.Float64: + v := *(*float64)(unsafe.Pointer(uintptr(p) + col.offset)) + w.w.WriteFloat64(v, 'g', -1) + default: + return fmt.Errorf("unsupported type %v", col.kind) + } + } + return w.w.EndLine() +} diff --git a/tsv/row_writer_test.go b/tsv/row_writer_test.go new file mode 100644 index 00000000..48998957 --- /dev/null +++ b/tsv/row_writer_test.go @@ -0,0 +1,94 @@ +// Copyright 2019 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. 
+ +package tsv_test + +import ( + "bytes" + "fmt" + "testing" + + "github.com/grailbio/base/tsv" +) + +func TestRowWriter(t *testing.T) { + var buf bytes.Buffer + rw := tsv.NewRowWriter(&buf) + type embedded struct { + EmbeddedString string `tsv:"estring"` + EmbeddedFloat float64 `tsv:"efloat,fmt=0.3f"` + } + var row struct { + Bool bool `tsv:"true_or_false"` + String string `tsv:"name"` + Int8 int8 + Int16 int16 + Int32 int32 + Int64 int64 + Int int + Uint8 uint8 + Uint16 uint16 + Uint32 uint32 + Uint64 uint64 + Uint uint + Float32 float32 + Float64 float64 + embedded + skippedString string + skippedFunc func() + } + row.String = "abc" + row.Float32 = -3 + row.Float64 = 1e300 + if err := rw.Write(&row); err != nil { + t.Error(err) + } + row.String = "def" + row.Int = 2 + row.Float32 = 0 + row.EmbeddedString = "estring" + row.EmbeddedFloat = 0.123456 + if err := rw.Write(&row); err != nil { + t.Error(err) + } + if err := rw.Flush(); err != nil { + t.Error(err) + } + got := buf.String() + want := `true_or_false name Int8 Int16 Int32 Int64 Int Uint8 Uint16 Uint32 Uint64 Uint Float32 Float64 estring efloat +false abc 0 0 0 0 0 0 0 0 0 0 -3 1e+300 0.000 +false def 0 0 0 0 2 0 0 0 0 0 0 1e+300 estring 0.123 +` + if got != want { + t.Errorf("got: %q, want %q", got, want) + } +} + +func ExampleRowWriter() { + type rowTyp struct { + Foo float64 `tsv:"foo,fmt=.2f"` + Bar float64 `tsv:"bar,fmt=.3f"` + Baz float64 + } + rows := []rowTyp{ + {Foo: 0.1234, Bar: 0.4567, Baz: 0.9876}, + {Foo: 1.1234, Bar: 1.4567, Baz: 1.9876}, + } + var buf bytes.Buffer + w := tsv.NewRowWriter(&buf) + for i := range rows { + if err := w.Write(&rows[i]); err != nil { + panic(err) + } + } + if err := w.Flush(); err != nil { + panic(err) + } + fmt.Print(string(buf.Bytes())) + + // Output: + // foo bar Baz + // 0.12 0.457 0.9876 + // 1.12 1.457 1.9876 +} diff --git a/tsv/writer.go b/tsv/writer.go new file mode 100644 index 00000000..d7b85576 --- /dev/null +++ b/tsv/writer.go @@ -0,0 +1,156 @@ +// 
Copyright 2018 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache-2.0 +// license that can be found in the LICENSE file. + +package tsv + +import ( + "bufio" + "fmt" + "io" + "strconv" +) + +// Writer provides an efficient and concise way to append a field at a time to +// a TSV. However, note that it does NOT have a Write() method; the interface +// is deliberately restricted. +// +// We force this to fill at least one cacheline to prevent false sharing when +// make([]Writer, parallelism) is used. +type Writer struct { + w *bufio.Writer + line []byte + padding [32]byte // nolint: megacheck, structcheck, staticcheck +} + +// NewWriter creates a new tsv.Writer from an io.Writer. +func NewWriter(w io.Writer) (tw *Writer) { + return &Writer{ + w: bufio.NewWriter(w), + line: make([]byte, 0, 256), + } +} + +// WriteString appends the given string and a tab to the current line. (It is +// safe to use this to write multiple fields at a time.) +func (w *Writer) WriteString(s string) { + w.line = append(w.line, s...) + w.line = append(w.line, '\t') +} + +// WriteBytes appends the given []byte and a tab to the current line. +func (w *Writer) WriteBytes(s []byte) { + w.line = append(w.line, s...) + w.line = append(w.line, '\t') +} + +// WriteUint32 converts the given uint32 to a string, and appends that and a +// tab to the current line. +func (w *Writer) WriteUint32(ui uint32) { + w.WriteUint64(uint64(ui)) +} + +// WriteInt64 converts the given int64 to a string, and appends that and a +// tab to the current line. +func (w *Writer) WriteInt64(i int64) { + w.line = strconv.AppendInt(w.line, i, 10) + w.line = append(w.line, '\t') +} + +// WriteUint64 converts the given uint64 to a string, and appends that and a +// tab to the current line. 
+func (w *Writer) WriteUint64(ui uint64) { + w.line = strconv.AppendUint(w.line, ui, 10) + w.line = append(w.line, '\t') +} + +// WriteFloat64 converts the given float64 to a string with the given +// strconv.AppendFloat parameters, and appends that and a tab to the current +// line. +func (w *Writer) WriteFloat64(f float64, fmt byte, prec int) { + w.line = strconv.AppendFloat(w.line, f, fmt, prec, 64) + w.line = append(w.line, '\t') +} + +// WriteByte appends the given literal byte (no number->string conversion) and +// a tab to the current line. +func (w *Writer) WriteByte(b byte) { // "go vet" complaint expected + w.line = append(w.line, b) + w.line = append(w.line, '\t') +} + +// WritePartialString appends a string WITHOUT the usual subsequent tab. It +// must be followed by a non-Partial Write at some point to end the field; +// otherwise EndLine will clobber the last character. +func (w *Writer) WritePartialString(s string) { + w.line = append(w.line, s...) +} + +// WritePartialBytes appends a []byte WITHOUT the usual subsequent tab. It +// must be followed by a non-Partial Write at some point to end the field; +// otherwise EndLine will clobber the last character. +func (w *Writer) WritePartialBytes(s []byte) { + w.line = append(w.line, s...) +} + +// WritePartialUint32 converts the given uint32 to a string, and appends that +// WITHOUT the usual subsequent tab. It must be followed by a non-Partial +// Write at some point to end the field; otherwise EndLine will clobber the +// last character. +func (w *Writer) WritePartialUint32(ui uint32) { + w.line = strconv.AppendUint(w.line, uint64(ui), 10) +} + +// WritePartialByte appends the given literal byte (no number->string +// conversion) WITHOUT the usual subsequent tab. It must be followed by a +// non-Partial Write at some point to end the field; otherwise EndLine will +// clobber the last character. 
+func (w *Writer) WritePartialByte(b byte) {
+	w.line = append(w.line, b)
+}
+
+// WriteCsvUint32 converts the given uint32 to a string, and appends that and a
+// comma to the current line.
+func (w *Writer) WriteCsvUint32(ui uint32) {
+	w.line = strconv.AppendUint(w.line, uint64(ui), 10)
+	w.line = append(w.line, ',')
+}
+
+// WriteCsvByte appends the given literal byte (no number->string conversion)
+// and a comma to the current line.
+func (w *Writer) WriteCsvByte(b byte) {
+	w.line = append(w.line, b)
+	w.line = append(w.line, ',')
+}
+
+// (Other Csv functions will be added as they're needed.)
+
+// EndCsv finishes the current comma-separated field, converting the last comma
+// to a tab. The current line must be nonempty, or this panics.
+func (w *Writer) EndCsv() {
+	w.line[len(w.line)-1] = '\t'
+}
+
+// EndLine finishes the current line and writes it out; the line must be nonempty, or this panics.
+func (w *Writer) EndLine() (err error) {
+	w.line[len(w.line)-1] = '\n'
+	// Tried making less frequent Write calls, doesn't seem to help.
+	_, err = w.w.Write(w.line)
+	w.line = w.line[:0]
+	return
+}
+
+// Flush flushes all finished lines to the underlying io.Writer.
+func (w *Writer) Flush() error {
+	return w.w.Flush()
+}
+
+// Copy appends the entire contents of the given io.Reader (assumed to be
+// another TSV file). The current line must be empty (finished with EndLine).
+func (w *Writer) Copy(r io.Reader) error {
+	if len(w.line) != 0 {
+		return fmt.Errorf("Writer.Copy: current line is nonempty")
+	}
+	_, err := io.Copy(w.w, r)
+	return err
+}
diff --git a/tsv/writer_test.go b/tsv/writer_test.go
new file mode 100644
index 00000000..2124e6d5
--- /dev/null
+++ b/tsv/writer_test.go
@@ -0,0 +1,34 @@
+// Copyright 2018 GRAIL, Inc. All rights reserved.
+// Use of this source code is governed by the Apache 2.0
+// license that can be found in the LICENSE file.
+ +package tsv_test + +import ( + "bytes" + "testing" + + "github.com/grailbio/base/tsv" +) + +func TestWriter(t *testing.T) { + var buf bytes.Buffer + tw := tsv.NewWriter(&buf) + tw.WriteString("field1") + tw.WriteUint32(2) + tw.WritePartialString("field") + tw.WriteByte('3') + tw.WriteBytes([]byte{'f', 'i', 'e', 'l', 'd', '4'}) + tw.WriteFloat64(1.2345, 'G', 6) + tw.WriteInt64(123456) + err := tw.EndLine() + if err != nil { + t.Errorf("Error while adding end of line") + } + tw.Flush() + got := buf.String() + want := "field1\t2\tfield3\tfield4\t1.2345\t123456\n" + if got != want { + t.Errorf("got: %q, want %q", got, want) + } +} diff --git a/unsafe/unsafe.go b/unsafe/unsafe.go index 8a1132b8..2e93f36a 100644 --- a/unsafe/unsafe.go +++ b/unsafe/unsafe.go @@ -29,17 +29,3 @@ func StringToBytes(src string) (d []byte) { dh.Cap = sh.Len return d } - -// ExtendBytes extends the given byte slice, without zero-initializing the new -// storage space. The caller must guarantee that cap(d) >= newLen (using e.g. -// a Grow() call on the parent buffer). -func ExtendBytes(dptr *[]byte, newLen int) { - // An earlier version of this function returned a new byte slice. However, I - // don't see a use case where you'd want to keep the old slice object, so - // I've changed the function to modify the slice object in-place. 
- if cap(*dptr) < newLen { - panic(newLen) - } - dh := (*reflect.SliceHeader)(unsafe.Pointer(dptr)) - dh.Len = newLen -} diff --git a/unsafe/unsafe_test.go b/unsafe/unsafe_test.go index 1d45e8dd..5c807420 100644 --- a/unsafe/unsafe_test.go +++ b/unsafe/unsafe_test.go @@ -38,22 +38,3 @@ func ExampleStringToBytes() { fmt.Println(unsafe.StringToBytes("AbC")) // Output: [65 98 67] } - -func TestExtendBytes(t *testing.T) { - for _, src := range []string{"aceg", "abcdefghi"} { - d := []byte(src) - dExt := d[:3] - unsafe.ExtendBytes(&dExt, len(src)) - if string(dExt) != src { - t.Error(dExt) - } - } -} - -func ExampleExtendBytes() { - d := []byte{'A', 'b', 'C'} - d = d[:1] - unsafe.ExtendBytes(&d, 2) - fmt.Println(d) - // Output: [65 98] -} diff --git a/vcontext/vcontext.go b/vcontext/vcontext.go index 58f927bf..72d7c49e 100644 --- a/vcontext/vcontext.go +++ b/vcontext/vcontext.go @@ -8,10 +8,12 @@ package vcontext import ( "sync" - "github.com/grailbio/base/grail" - "v.io/v23" + "github.com/grailbio/base/backgroundcontext" + "github.com/grailbio/base/shutdown" + _ "github.com/grailbio/v23/factories/grail" // Needed to initialize v23 + v23 "v.io/v23" "v.io/v23/context" - _ "v.io/x/ref/runtime/factories/grail" // Needed to initialize v23 + "v.io/x/ref/runtime/factories/library" ) var ( @@ -19,6 +21,10 @@ var ( ctx *context.T ) +func init() { + library.AllowMultipleInitializations = true +} + // Background returns the singleton Vanadium context for v23. It initializes v23 // on the first call. GRAIL applications should always use this function to // initialize and create a context instead of calling v23.Init() manually. @@ -27,9 +33,10 @@ var ( // production pipeline controller. Be extremely careful when changing it. 
func Background() *context.T { once.Do(func() { - var shutdown v23.Shutdown - ctx, shutdown = v23.Init() - grail.RegisterShutdownCallback(grail.Shutdown(shutdown)) + var done v23.Shutdown + ctx, done = v23.Init() + shutdown.Register(shutdown.Func(done)) + backgroundcontext.Set(ctx) }) return ctx } diff --git a/web/webutil/browser.go b/web/webutil/browser.go index 3e4c9e09..6892cb12 100644 --- a/web/webutil/browser.go +++ b/web/webutil/browser.go @@ -7,6 +7,7 @@ package webutil import ( + "os" "os/exec" "runtime" ) @@ -16,9 +17,14 @@ import ( func StartBrowser(url string) bool { // try to start the browser var args []string + aws_env := os.Getenv("AWS_ENV") switch runtime.GOOS { case "darwin": - args = []string{"open"} + if aws_env != "" { + args = []string{"open", "-na", "Google Chrome", "--args", "--profile-directory=" + aws_env, "--new-window"} + } else { + args = []string{"open"} + } case "windows": args = []string{"cmd", "/c", "start"} default: diff --git a/writehash/writehash.go b/writehash/writehash.go new file mode 100644 index 00000000..e10efd1e --- /dev/null +++ b/writehash/writehash.go @@ -0,0 +1,105 @@ +// Copyright 2019 GRAIL, Inc. All rights reserved. +// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. + +// Package writehash provides a set of utility functions to hash +// common types into hashes. +package writehash + +import ( + "encoding/binary" + "fmt" + "hash" + "io" + "math" +) + +func must(n int, err error) { + if err != nil { + panic(fmt.Sprintf("writehash: hash.Write returned unexpected error: %v", err)) + } +} + +// String encodes the string s into writer w. +func String(h hash.Hash, s string) { + must(io.WriteString(h, s)) +} + +// Int encodes the integer v into writer w. +func Int(h hash.Hash, v int) { + Uint64(h, uint64(v)) +} + +// Int16 encodes the 16-bit integer v into writer w. 
+func Int16(h hash.Hash, v int) { + Uint16(h, uint16(v)) +} + +// Int32 encodes the 32-bit integer v into writer w. +func Int32(h hash.Hash, v int32) { + Uint32(h, uint32(v)) +} + +// Int64 encodes the 64-bit integer v into writer w. +func Int64(h hash.Hash, v int64) { + Uint64(h, uint64(v)) +} + +// Uint encodes the unsigned integer v into writer w. +func Uint(h hash.Hash, v uint) { + Uint64(h, uint64(v)) +} + +// Uint16 encodes the unsigned 16-bit integer v into writer w. +func Uint16(h hash.Hash, v uint16) { + var buf [2]byte + binary.LittleEndian.PutUint16(buf[:], v) + must(h.Write(buf[:])) +} + +// Uint32 encodes the unsigned 32-bit integer v into writer w. +func Uint32(h hash.Hash, v uint32) { + var buf [4]byte + binary.LittleEndian.PutUint32(buf[:], v) + must(h.Write(buf[:])) +} + +// Uint64 encodes the unsigned 64-bit integer v into writer w. +func Uint64(h hash.Hash, v uint64) { + var buf [8]byte + binary.LittleEndian.PutUint64(buf[:], v) + must(h.Write(buf[:])) +} + +// Float32 encodes the 32-bit floating point number v into writer w. +func Float32(h hash.Hash, v float32) { + Uint32(h, math.Float32bits(v)) +} + +// Float64 encodes the 64-bit floating point number v into writer w. +func Float64(h hash.Hash, v float64) { + Uint64(h, math.Float64bits(v)) +} + +// Bool encodes the boolean v into writer w. +func Bool(h hash.Hash, v bool) { + if v { + Byte(h, 1) + } else { + Byte(h, 0) + } +} + +// Byte writes the byte b into writer w. +func Byte(h hash.Hash, b byte) { + if w, ok := h.(io.ByteWriter); ok { + must(0, w.WriteByte(b)) + return + } + must(h.Write([]byte{b})) +} + +// Run encodes the rune r into writer w. +func Rune(h hash.Hash, r rune) { + Int(h, int(r)) +} diff --git a/writehash/writehash_test.go b/writehash/writehash_test.go new file mode 100644 index 00000000..33d96c31 --- /dev/null +++ b/writehash/writehash_test.go @@ -0,0 +1,62 @@ +// Copyright 2019 GRAIL, Inc. All rights reserved. 
+// Use of this source code is governed by the Apache 2.0 +// license that can be found in the LICENSE file. + +package writehash_test + +import ( + "bytes" + "io" + "testing" + + "github.com/grailbio/base/writehash" +) + +type fakeHasher struct{ io.Writer } + +func (fakeHasher) Sum([]byte) []byte { panic("sum") } +func (fakeHasher) Reset() { panic("reset") } +func (fakeHasher) Size() int { panic("size") } +func (fakeHasher) BlockSize() int { panic("blocksize") } + +func TestWritehash(t *testing.T) { + b := new(bytes.Buffer) + var lastLen int + check := func(n int) { + t.Helper() + if got, want := b.Len()-lastLen, n; got != want { + t.Fatalf("got %v, want %v", got, want) + } + if bytes.Equal(b.Bytes()[lastLen:], make([]byte, n)) { + t.Error("wrote zeros") + } + lastLen = b.Len() + } + h := fakeHasher{b} + writehash.String(h, "hello world") + check(11) + writehash.Int(h, 1) + check(8) + writehash.Int16(h, 1) + check(2) + writehash.Int32(h, 1) + check(4) + writehash.Int64(h, 1) + check(8) + writehash.Uint(h, 1) + check(8) + writehash.Uint16(h, 1) + check(2) + writehash.Uint32(h, 1) + check(4) + writehash.Float32(h, 1) + check(4) + writehash.Float64(h, 1) + check(8) + writehash.Bool(h, true) + check(1) + writehash.Byte(h, 1) + check(1) + writehash.Rune(h, 'x') + check(8) +}