diff --git a/clihttp/client.go b/clihttp/client.go index b87ad0a7..29f96582 100644 --- a/clihttp/client.go +++ b/clihttp/client.go @@ -73,66 +73,72 @@ func NewClient(tracer opentracing.Tracer, options ...Option) *Client { // Do sends the request. func (c *Client) Do(req *http.Request) (*http.Response, error) { - req, tracer := nethttp.TraceRequest(c.tracer, req) - defer tracer.Finish() + clientSpan := c.tracer.StartSpan("HTTP Client") + defer clientSpan.Finish() - c.logRequest(req, tracer) + ext.SpanKindRPCClient.Set(clientSpan) + ext.HTTPUrl.Set(clientSpan, req.RequestURI) + ext.HTTPMethod.Set(clientSpan, req.Method) + // Inject the client span context into the headers + c.logRequest(req, clientSpan) + + c.tracer.Inject(clientSpan.Context(), opentracing.HTTPHeaders, opentracing.HTTPHeadersCarrier(req.Header)) response, err := c.underlying.Do(req) if err != nil { return response, err } - c.logResponse(response, tracer) + c.logResponse(response, clientSpan) return response, err } -func (c *Client) logRequest(req *http.Request, tracer *nethttp.Tracer) { +func (c *Client) logRequest(req *http.Request, span opentracing.Span) { if req.Body == nil { return } body, err := req.GetBody() if err != nil { - ext.Error.Set(tracer.Span(), true) - tracer.Span().LogKV("error", errors.Wrap(err, "cannot get request body")) + ext.Error.Set(span, true) + span.LogKV("error", errors.Wrap(err, "cannot get request body")) return } length, _ := strconv.Atoi(req.Header.Get(http.CanonicalHeaderKey("Content-Length"))) if length > c.requestLogThreshold { - ext.Error.Set(tracer.Span(), true) - tracer.Span().LogKV("request", "elided: Content-Length too large") + ext.Error.Set(span, true) + span.LogKV("request", "elided: Content-Length too large") return } byt, err := ioutil.ReadAll(body) if err != nil { - ext.Error.Set(tracer.Span(), true) - tracer.Span().LogKV("error", errors.Wrap(err, "cannot read request body")) + ext.Error.Set(span, true) + span.LogKV("error", errors.Wrap(err, "cannot read request body")) return } - if tracer.Span() != nil { - tracer.Span().LogKV("request", string(byt)) + if span != nil { + span.LogKV("request", string(byt)) } } -func (c *Client) logResponse(response *http.Response, tracer *nethttp.Tracer) { +func (c *Client) logResponse(response *http.Response, span opentracing.Span) { if response.Body == nil { return } length, _ := strconv.Atoi(response.Header.Get(http.CanonicalHeaderKey("Content-Length"))) if length > c.responseLogThreshold { - tracer.Span().LogKV("response", "elided: Content-Length too large") + span.LogKV("response", "elided: Content-Length too large") return } var buf bytes.Buffer byt, err := ioutil.ReadAll(response.Body) if err != nil { - ext.Error.Set(tracer.Span(), true) - tracer.Span().LogFields(log.Error(err)) + ext.Error.Set(span, true) + span.LogFields(log.Error(err)) } - if tracer.Span() != nil { - tracer.Span().LogKV("response", string(byt)) + if span != nil { + span.LogKV("response", string(byt)) } buf.Write(byt) response.Body = ioutil.NopCloser(&buf) diff --git a/clihttp/client_test.go b/clihttp/client_test.go index 390071af..e6e03145 100644 --- a/clihttp/client_test.go +++ b/clihttp/client_test.go @@ -1,12 +1,13 @@ package clihttp import ( - "github.com/opentracing/opentracing-go" - "github.com/stretchr/testify/assert" "net/http" "strings" "testing" + "github.com/opentracing/opentracing-go" + "github.com/stretchr/testify/assert" + "github.com/opentracing/opentracing-go/mocktracer" ) @@ -49,27 +50,16 @@ func TestClient_Do(t *testing.T) { } func TestClient_race(t *testing.T) { - cases := []struct { - name string - request *http.Request - Option []Option - }{ - { - "normal", - func() *http.Request { r, _ := http.NewRequest("GET", "https://baidu.com", nil); return r }(), - []Option{}, - }, - } - for _, c := range cases { - c := c - // the mock tracer is not concurrent safe. - tracer := opentracing.GlobalTracer() - client := NewClient(tracer, c.Option...) - for i := 0; i < 10; i++ { - t.Run(c.name, func(t *testing.T) { - t.Parallel() - _, _ = client.Do(c.request) - }) - } + // the mock tracer is not concurrent safe. + //tracer := opentracing.GlobalTracer() + tracer := opentracing.NoopTracer{} + client := NewClient(tracer) + for i := 0; i < 100; i++ { + t.Run("", func(t *testing.T) { + t.Parallel() + r, _ := http.NewRequest("GET", "https://baidu.com", nil) + _, _ = client.Do(r) + }) } + } diff --git a/config/config_test.go b/config/config_test.go index 7792f777..216c49c2 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -52,16 +52,22 @@ func TestKoanfAdapter_Watch(t *gotesting.T) { assert.Equal(t, "baz", ka.String("foo")) ctx, cancel := context.WithCancel(context.Background()) defer cancel() + + var ch = make(chan struct{}) go func() { ka.watcher.Watch(ctx, func() error { + assert.NoError(t, ka.Reload(), "reload should be successful") err := ka.Reload() fmt.Println(err) + ch <- struct{}{} return nil }) }() time.Sleep(time.Second) ioutil.WriteFile(f.Name(), []byte("foo: bar"), 0644) ioutil.WriteFile(f.Name(), []byte("foo: bar"), 0644) + <-ch + // The following test is flaky on CI. Unable to reproduce locally. /* diff --git a/dtx/correlation_id.go b/dtx/correlation_id.go new file mode 100644 index 00000000..007b0835 --- /dev/null +++ b/dtx/correlation_id.go @@ -0,0 +1,6 @@ +package dtx + +type correlationIDType string + +// CorrelationID is an identifier to correlate transactions in context. +const CorrelationID correlationIDType = "CorrelationID" diff --git a/dtx/doc.go b/dtx/doc.go new file mode 100644 index 00000000..74f2c2cf --- /dev/null +++ b/dtx/doc.go @@ -0,0 +1,43 @@ +/* +Package dtx contains common utilities in the context of distributed transaction. + +Context Passing + +It is curial for all parties in the distributed transaction to share an +transaction id. This package provides utility to pass this id across services. + + HTTPToContext() http.RequestFunc + ContextToHTTP() http.RequestFunc + GRPCToContext() grpc.ServerRequestFunc + ContextToGRPC() grpc.ClientRequestFunc + +Idempotency + +Certain operations will be retried by the client more than once. A middleware is +provided for the server to shield against repeated request in the same +transaction. + + func MakeIdempotence(s Oncer) endpoint.Middleware + +Lock + +Certain resource in transaction cannot be concurrently accessed. A middleware is +provided to lock such resources. + + func MakeLock(l Locker) endpoint.Middleware + +Allow Null Compensation and Prevent Resource Suspension + +Transaction participants may receive the compensation +order before performing normal operations due to network exceptions. In this +case, null compensation is required. + +If the forward operation arrives later than the compensating operation due to +network exceptions, the forward operation must be discarded. Otherwise, resource +suspension occurs. + + func MakeAttempt(s Sequencer) endpoint.Middleware + func MakeCancel(s Sequencer) endpoint.Middleware + +*/ +package dtx diff --git a/dtx/middleware.go b/dtx/middleware.go new file mode 100644 index 00000000..e820dcb4 --- /dev/null +++ b/dtx/middleware.go @@ -0,0 +1,110 @@ +package dtx + +import ( + "context" + + "github.com/go-kit/kit/endpoint" + "github.com/pkg/errors" +) + +// ErrNonIdempotent is returned when an endpoint is requested more than once with the same CorrelationID. +var ErrNonIdempotent = errors.New("rejected repeated request") + +// ErrNoLock is returned when the endpoint fail to fetch the distributed lock under the same CorrelationID. +var ErrNoLock = errors.New("failed to grab lock") + +// Oncer should return true if the key has been observed before. +type Oncer interface { + Once(ctx context.Context, key string) bool +} + +// MakeIdempotence returns a middleware that ensures the next endpoint can only be executed once per CorrelationID. +func MakeIdempotence(s Oncer) endpoint.Middleware { + return func(e endpoint.Endpoint) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (response interface{}, err error) { + correlationID, ok := ctx.Value(CorrelationID).(string) + if !ok { + return e(ctx, request) + } + if s.Once(ctx, correlationID) { + return nil, ErrNonIdempotent + } + return e(ctx, request) + } + } +} + +// Locker is an interface for the distributed lock. +type Locker interface { + // Lock should return true only when it successfully grabs the lock. + Lock(ctx context.Context, key string) bool + Unlock(ctx context.Context, key string) +} + +// MakeLock returns a middleware that ensures the next endpoint is never concurrently accessed. +func MakeLock(l Locker) endpoint.Middleware { + return func(e endpoint.Endpoint) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (response interface{}, err error) { + correlationID, ok := ctx.Value(CorrelationID).(string) + if !ok { + return e(ctx, request) + } + if l.Lock(ctx, correlationID) { + defer l.Unlock(ctx, correlationID) + return e(ctx, request) + } + return nil, ErrNoLock + } + } +} + +// Sequencer is an interface that shields against the disordering of +// attempt and cancel in a transactional context. +type Sequencer interface { + MarkCancelledCheckAttempted(context.Context, string) bool + MarkAttemptedCheckCancelled(context.Context, string) bool +} + +// MakeAttempt returns a middleware that wraps around an attempt endpoint. If the +// this segment of the distributed transaction is already cancelled, the next +// endpoint will never be executed. +// +// If the forward operation arrives later than the compensating operation due to +// network exceptions, the forward operation must be discarded. Otherwise, +// resource suspension occurs. +func MakeAttempt(s Sequencer) endpoint.Middleware { + return func(e endpoint.Endpoint) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (response interface{}, err error) { + correlationID, ok := ctx.Value(CorrelationID).(string) + if !ok { + return e(ctx, request) + } + if s.MarkAttemptedCheckCancelled(ctx, correlationID) { + return nil, nil + } + return e(ctx, request) + } + } +} + +// MakeCancel returns a middleware that wraps around the cancellation endpoint. +// It guarantees if this segment of the distributed transaction is never attempted, +// the cancellation endpoint will not be executed. +// +// Transaction participants may receive the compensation order before performing +// normal operations due to network exceptions. In this case, null compensation +// is required. +func MakeCancel(s Sequencer) endpoint.Middleware { + return func(e endpoint.Endpoint) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (response interface{}, err error) { + correlationID, ok := ctx.Value(CorrelationID).(string) + if !ok { + return e(ctx, request) + } + if !s.MarkCancelledCheckAttempted(ctx, correlationID) { + return nil, nil + } + return e(ctx, request) + } + } +} diff --git a/dtx/middleware_test.go b/dtx/middleware_test.go new file mode 100644 index 00000000..311b2e6f --- /dev/null +++ b/dtx/middleware_test.go @@ -0,0 +1,223 @@ +package dtx + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" +) + +func ep(ctx context.Context, req interface{}) (resp interface{}, err error) { + return req, nil +} + +type oncer func(ctx context.Context, key string) bool + +func (o oncer) Once(ctx context.Context, key string) bool { + return o(ctx, key) +} + +type locker struct { + Lockf func(ctx context.Context, key string) bool + Unlockf func(ctx context.Context, key string) +} + +func (l locker) Lock(ctx context.Context, key string) bool { + return l.Lockf(ctx, key) +} + +func (l locker) Unlock(ctx context.Context, key string) { + l.Unlockf(ctx, key) +} + +type transactioner struct { + attempt func(ctx context.Context, s string) bool + cancel func(ctx context.Context, s string) bool +} + +func (t transactioner) MarkCancelledCheckAttempted(ctx context.Context, s string) bool { + return t.attempt(ctx, s) +} + +func (t transactioner) MarkAttemptedCheckCancelled(ctx context.Context, s string) bool { + return t.cancel(ctx, s) +} + +func TestMakeIdempotence(t *testing.T) { + t.Run("without context", func(t *testing.T) { + t.Parallel() + var attempt = 0 + var s = func(ctx context.Context, key string) bool { + if attempt == 0 { + attempt++ + return false + } + return true + } + m := MakeIdempotence(oncer(s)) + f := m(ep) + resp, err := f(context.Background(), 1) + assert.NoError(t, err) + assert.Equal(t, 1, resp) + }) + + t.Run("with context", func(t *testing.T) { + t.Parallel() + var attempt = 0 + var s = func(ctx context.Context, key string) bool { + if attempt == 0 { + attempt++ + return false + } + return true + } + m := MakeIdempotence(oncer(s)) + f := m(ep) + ctx := context.WithValue(context.Background(), CorrelationID, "foobar") + resp, err := f(ctx, 1) + assert.NoError(t, err) + assert.Equal(t, 1, resp) + resp, err = f(ctx, 2) + assert.Error(t, err) + assert.Equal(t, nil, resp) + }) +} + +func TestMakeLock(t *testing.T) { + t.Run("no context", func(t *testing.T) { + t.Parallel() + var lock = locker{ + Lockf: func(ctx context.Context, key string) bool { + return true + }, + Unlockf: func(ctx context.Context, key string) { + }, + } + m := MakeLock(lock) + f := m(ep) + resp, err := f(context.Background(), 1) + assert.NoError(t, err) + assert.Equal(t, 1, resp) + }) + + t.Run("with context", func(t *testing.T) { + t.Parallel() + var lock = locker{ + Lockf: func(ctx context.Context, key string) bool { + return true + }, + Unlockf: func(ctx context.Context, key string) { + }, + } + m := MakeLock(lock) + f := m(ep) + ctx := context.WithValue(context.Background(), CorrelationID, "foobar") + resp, err := f(ctx, 1) + assert.NoError(t, err) + assert.Equal(t, 1, resp) + }) + + t.Run("failed to grab lock", func(t *testing.T) { + t.Parallel() + var lock = locker{ + Lockf: func(ctx context.Context, key string) bool { + return false + }, + Unlockf: func(ctx context.Context, key string) { + }, + } + m := MakeLock(lock) + f := m(ep) + ctx := context.WithValue(context.Background(), CorrelationID, "foobar") + + resp, err := f(ctx, 2) + assert.Error(t, err) + assert.Equal(t, nil, resp) + }) +} + +func TestMakeAttempt(t *testing.T) { + t.Run("no context", func(t *testing.T) { + t.Parallel() + var tr = transactioner{ + attempt: func(ctx context.Context, key string) bool { + return false + }, + cancel: func(ctx context.Context, key string) bool { + return false + }, + } + f := MakeAttempt(tr)(ep) + g := MakeCancel(tr)(ep) + resp, err := f(context.Background(), 1) + assert.NoError(t, err) + assert.Equal(t, 1, resp) + resp, err = g(context.Background(), 1) + assert.NoError(t, err) + assert.Equal(t, 1, resp) + }) + + t.Run("with context, attempted", func(t *testing.T) { + t.Parallel() + ctx := context.WithValue(context.Background(), CorrelationID, "foobar") + var tr = transactioner{ + attempt: func(ctx context.Context, key string) bool { + return true + }, + cancel: func(ctx context.Context, key string) bool { + return false + }, + } + f := MakeAttempt(tr)(ep) + g := MakeCancel(tr)(ep) + + resp, err := f(ctx, 1) + assert.NoError(t, err) + assert.Equal(t, 1, resp) + + resp, err = g(ctx, 1) + assert.NoError(t, err) + assert.Equal(t, 1, resp) + }) + + t.Run("with context, not attempted", func(t *testing.T) { + t.Parallel() + ctx := context.WithValue(context.Background(), CorrelationID, "foobar") + var tr = transactioner{ + attempt: func(ctx context.Context, key string) bool { + return false + }, + cancel: func(ctx context.Context, key string) bool { + return false + }, + } + f := MakeAttempt(tr)(ep) + g := MakeCancel(tr)(ep) + + resp, err := f(ctx, 1) + assert.NoError(t, err) + assert.Equal(t, 1, resp) + + resp, err = g(ctx, 1) + assert.NoError(t, err) + assert.Equal(t, nil, resp) + }) + + t.Run("with context, cancelled", func(t *testing.T) { + t.Parallel() + ctx := context.WithValue(context.Background(), CorrelationID, "foobar") + var tr = transactioner{ + attempt: func(ctx context.Context, key string) bool { + return false + }, + cancel: func(ctx context.Context, key string) bool { + return true + }, + } + f := MakeAttempt(tr)(ep) + + resp, err := f(ctx, 1) + assert.NoError(t, err) + assert.Equal(t, nil, resp) + }) +} diff --git a/dtx/redis_store.go b/dtx/redis_store.go new file mode 100644 index 00000000..d2ae409f --- /dev/null +++ b/dtx/redis_store.go @@ -0,0 +1,79 @@ +package dtx + +import ( + "context" + "time" + + "github.com/DoNewsCode/core/contract" + "github.com/go-redis/redis/v8" +) + +// RedisStore is an implementation of Oncer, Locker and Sequencer. +type RedisStore struct { + keyer contract.Keyer + client redis.UniversalClient +} + +// MarkCancelledCheckAttempted returns true if the CorrelationID has been attempted before. +// It also marks the CorrelationID as cancelled. +func (r RedisStore) MarkCancelledCheckAttempted(ctx context.Context, s string) bool { + b, _ := r.client.Eval(ctx, ` +redis.call('SET', KEYS[1], "1", "EX", "86400") +if redis.call('EXISTS', KEYS[2]) == 1 then + return 1 +end +return 0 +`, []string{r.keyer.Key(":", "cancel", s), r.keyer.Key(":", "attempt", s)}).Bool() + return b +} + +// MarkAttemptedCheckCancelled returns true if the CorrelationID has been cancelled before. +// It also marks the CorrelationID as attempted. +func (r RedisStore) MarkAttemptedCheckCancelled(ctx context.Context, s string) bool { + b, _ := r.client.Eval(ctx, ` +redis.call('SET', KEYS[1], "1", "EX", "86400") +if redis.call('EXISTS', KEYS[2]) == 1 then + return 1 +end +return 0 +`, []string{r.keyer.Key(":", "attempt", s), r.keyer.Key(":", "cancel", s)}).Bool() + return b +} + +// Lock grabs the lock for the given key. It returns true if the lock is +// successfully acquired. If the lock is not available, this method will block until +// the lock is released or the context expired. In latter case, false is +// returned. +func (r RedisStore) Lock(ctx context.Context, key string) bool { + var expiration = time.Minute + if deadline, ok := ctx.Deadline(); ok { + expiration = deadline.Sub(time.Now()) + } + for { + ok, err := r.client.SetNX(ctx, r.keyer.Key(":", "lock", key), "1", expiration).Result() + if err == nil && ok { + return true + } + if ctx.Err() != nil { + return false + } + select { + case <-time.After(time.Second): + case <-ctx.Done(): + return false + } + } +} + +// Unlock unlocks the lock named by key. +func (r RedisStore) Unlock(ctx context.Context, key string) { + r.client.Del(ctx, r.keyer.Key(":", "lock", key)) +} + +// Once returns true if this method has been called before with the given key. If +// not, it internally set the key as called and +// returns false. +func (r RedisStore) Once(ctx context.Context, key string) bool { + _, err := r.client.GetSet(ctx, r.keyer.Key(":", "once", key), "1").Result() + return err != redis.Nil +} diff --git a/dtx/redis_store_test.go b/dtx/redis_store_test.go new file mode 100644 index 00000000..9fa1b384 --- /dev/null +++ b/dtx/redis_store_test.go @@ -0,0 +1,63 @@ +// +build integration + +package dtx + +import ( + "context" + "testing" + "time" + + "github.com/DoNewsCode/core/key" + "github.com/go-redis/redis/v8" + "github.com/stretchr/testify/assert" +) + +func TestOnce(t *testing.T) { + s := RedisStore{ + keyer: key.New(), + client: redis.NewUniversalClient(&redis.UniversalOptions{}), + } + ctx := context.Background() + defer s.client.Del(ctx, "once:foobar") + + assert.False(t, s.Once(ctx, "foobar")) + assert.True(t, s.Once(ctx, "foobar")) + assert.True(t, s.Once(ctx, "foobar")) + +} + +func TestLock(t *testing.T) { + s := RedisStore{ + keyer: key.New(), + client: redis.NewUniversalClient(&redis.UniversalOptions{}), + } + ctx := context.Background() + defer s.client.Del(ctx, "lock:foobar") + + assert.True(t, s.Lock(ctx, "foobar")) + + ctx, cancel := context.WithTimeout(ctx, time.Millisecond) + defer cancel() + assert.False(t, s.Lock(ctx, "foobar")) + + s.Unlock(ctx, "foobar") +} + +func TestRedisStore_MarkAttemptedCheckCancelled(t *testing.T) { + s := RedisStore{ + keyer: key.New(), + client: redis.NewUniversalClient(&redis.UniversalOptions{}), + } + ctx := context.Background() + defer s.client.Del(ctx, "attempt:foobar") + defer s.client.Del(ctx, "cancel:foobar") + + assert.False(t, s.MarkCancelledCheckAttempted(ctx, "foobar")) + assert.True(t, s.MarkAttemptedCheckCancelled(ctx, "foobar")) + + s.client.Del(ctx, "attempt:foobar") + s.client.Del(ctx, "cancel:foobar") + + assert.False(t, s.MarkAttemptedCheckCancelled(ctx, "foobar")) + assert.True(t, s.MarkCancelledCheckAttempted(ctx, "foobar")) +} diff --git a/dtx/sagas/dependency.go b/dtx/sagas/dependency.go new file mode 100644 index 00000000..b43e3735 --- /dev/null +++ b/dtx/sagas/dependency.go @@ -0,0 +1,118 @@ +package sagas + +import ( + "context" + "time" + + "github.com/DoNewsCode/core/config" + "github.com/DoNewsCode/core/contract" + "github.com/DoNewsCode/core/di" + "github.com/go-kit/kit/endpoint" + "github.com/go-kit/kit/log" + "github.com/oklog/run" +) + +/* +Providers returns a set of dependency providers. + Depends On: + contract.ConfigAccessor + log.Logger + Store `optional:"true"` + []*Step `group:"saga"` + Provide: + *Registry + SagaEndpoints +*/ +func Providers() di.Deps { + return []interface{}{provide, provideConfig} +} + +// in is the injection parameter for saga module. +type in struct { + di.In + + Conf contract.ConfigAccessor + Logger log.Logger + Store Store `optional:"true"` + Steps []*Step `group:"saga"` +} + +type recoverInterval time.Duration + +// SagaEndpoints is a collection of all registered endpoint in the saga registry +type SagaEndpoints map[string]endpoint.Endpoint + +type out struct { + di.Out + di.Module + + Registry *Registry + Interval recoverInterval + SagaEndpoints SagaEndpoints +} + +// provide creates a new saga module. +func provide(in in) out { + if in.Store == nil { + in.Store = NewInProcessStore() + } + timeoutSec := in.Conf.Float64("sagas.sagaTimeoutSecond") + if timeoutSec == 0 { + timeoutSec = 600 + } + registry := NewRegistry( + in.Store, + WithLogger(in.Logger), + WithTimeout(time.Duration(timeoutSec)*time.Second), + ) + eps := make(SagaEndpoints) + + for i := range in.Steps { + eps[in.Steps[i].Name] = registry.AddStep(in.Steps[i]) + } + + recoverSec := in.Conf.Float64("sagas.recoverIntervalSecond") + if recoverSec == 0 { + recoverSec = 60 + } + return out{Registry: registry, Interval: recoverInterval(time.Duration(recoverSec) * time.Second), SagaEndpoints: eps} +} + +// ProvideRunGroup implements the RunProvider. +func (m out) ProvideRunGroup(group *run.Group) { + ctx, cancel := context.WithCancel(context.Background()) + ticker := time.NewTicker(time.Duration(m.Interval)) + group.Add(func() error { + m.Registry.Recover(ctx) + for { + select { + case <-ticker.C: + m.Registry.Recover(ctx) + case <-ctx.Done(): + return nil + } + } + }, func(err error) { + cancel() + ticker.Stop() + }) +} + +type configOut struct { + Config []config.ExportedConfig +} + +func provideConfig() configOut { + return configOut{Config: []config.ExportedConfig{ + { + Owner: "sagas", + Data: map[string]interface{}{ + "sagas": map[string]interface{}{ + "sagaTimeoutSecond": "600", + "recoverIntervalSecond": "60", + }, + }, + Comment: "The saga config", + }, + }} +} diff --git a/dtx/sagas/dependency_test.go b/dtx/sagas/dependency_test.go new file mode 100644 index 00000000..806124ab --- /dev/null +++ b/dtx/sagas/dependency_test.go @@ -0,0 +1,64 @@ +package sagas + +import ( + "context" + "testing" + "time" + + "github.com/DoNewsCode/core" + "github.com/DoNewsCode/core/di" + "github.com/ghodss/yaml" + "github.com/oklog/run" + "github.com/stretchr/testify/assert" +) + +type sagas struct { + di.Out + + Step *Step `group:"saga"` +} + +func TestNew(t *testing.T) { + t.Parallel() + var g run.Group + c := core.Default() + c.Provide(Providers()) + c.Provide(di.Deps{func() sagas { + return sagas{ + Step: &Step{ + Name: "bar", + Do: func(ctx context.Context, request interface{}) (response interface{}, err error) { + return 1, nil + }, + Undo: func(ctx context.Context, req interface{}) (response interface{}, err error) { + return nil, nil + }, + }, + } + }}) + c.Invoke(func(r *Registry, endpoints SagaEndpoints) { + tx, ctx := r.StartTX(context.Background()) + resp, _ := endpoints["bar"](ctx, nil) + assert.Equal(t, 1, resp) + tx.Commit(ctx) + c.ApplyRunGroup(&g) + timeout(time.Second, &g) + assert.NoError(t, g.Run()) + }) +} + +func TestExportedConfigs(t *testing.T) { + conf := provideConfig() + _, err := yaml.Marshal(conf) + assert.NoError(t, err) +} + +func timeout(duration time.Duration, g *run.Group) { + ctx, cancel := context.WithTimeout(context.Background(), duration) + g.Add(func() error { + <-ctx.Done() + return nil + }, func(err error) { + cancel() + }) +} diff --git a/dtx/sagas/doc.go b/dtx/sagas/doc.go new file mode 100644 index 00000000..64607b2c --- /dev/null +++ b/dtx/sagas/doc.go @@ -0,0 +1,91 @@ +/* +Package sagas implements the orchestration based saga pattern. +See https://microservices.io/patterns/data/saga.html + +Introduction + +A saga is a sequence of local transactions. Each local transaction updates the +database and publishes a message or event to trigger the next local +transaction in the saga. If a local transaction fails because it violates a +business rule then the saga executes a series of compensating transactions +that undo the changes that were made by the preceding local transactions. + +Usage + +The saga is managed by sagas.Registry. Each saga step has an forward operation +and a rollback counterpart. They must be registered beforehand by calling +Registry.AddStep. A new endpoint will be returned to the caller. Use the +returned endpoint to perform transactional operation. + + store := sagas.NewInProcessStore() + registry := sagas.NewRegistry(store) + addOrder := registry.AddStep(&sagas.Step{ + Name: "Add Order", + Do: func(ctx context.Context, request interface{}) (response interface{}, err error) { + resp, err := orderEndpoint(ctx, request.(OrderRequest)) + if err != nil { + return nil, err + } + return resp, nil + }, + Undo: func(ctx context.Context, req interface{}) (response interface{}, err error) { + return orderCancelEndpoint(ctx, req) + }, + }) + makePayment := registry.AddStep(&sagas.Step{ + Name: "Make Payment", + Do: func(ctx context.Context, request interface{}) (response interface{}, err error) { + resp, err := paymentEndpoint(ctx, request.(PaymentRequest)) + if err != nil { + return nil, err + } + return resp, nil + }, + Undo: func(ctx context.Context, req interface{}) (response interface{}, err error) { + return paymentCancelEndpoint(ctx) + }, + }) + +Initiate the transaction by calling registry.StartTX. Pass the context returned +to the transaction branches. You can rollback or commit at your will. If the +TX.Rollback is called, the previously registered rollback operations will be +applied automatically, on condition that the forward operation is indeed +executed within the transaction. + + tx, ctx := registry.StartTX(context.Background()) + resp, err := addOrder(ctx, OrderRequest{Sku: "1"}) + if err != nil { + tx.Rollback(ctx) + } + resp, err = makePayment(ctx, PaymentRequest{}) + if err != nil { + tx.Rollback(ctx) + } + tx.Commit(ctx) + +Integration + +The package leader exports configuration in this format: + + saga: + sagaTimeoutSecond: 600 + recoverIntervalSecond: 60 + +To use package sagas with package core: + + var c *core.C = core.Default() + c.Provide(sagas.Providers) + c.Invoke(func(registry *sagas.Registry) { + tx, ctx := registry.StartTX(context.Background()) + resp, err := addOrder(ctx, OrderRequest{Sku: "1"}) + if err != nil { + tx.Rollback(ctx) + } + resp, err = makePayment(ctx, PaymentRequest{}) + if err != nil { + tx.Rollback(ctx) + } + tx.Commit(ctx) + }) +*/ +package sagas diff --git a/dtx/sagas/example_test.go b/dtx/sagas/example_test.go new file mode 100644 index 00000000..804edd08 --- /dev/null +++ b/dtx/sagas/example_test.go @@ -0,0 +1,119 @@ +package sagas_test + +import ( + "context" + "fmt" + + "github.com/DoNewsCode/core/dtx" + "github.com/DoNewsCode/core/dtx/sagas" +) + +var orderTable = make(map[string]interface{}) +var paymentTable = make(map[string]interface{}) + +type OrderRequest struct { + Sku string +} + +type OrderResponse struct { + OrderID string + Sku string + Cost float64 +} + +type PaymentRequest struct { + OrderID string + Sku string + Cost float64 +} + +type PaymentResponse struct { + Success bool +} + +func orderEndpoint(ctx context.Context, request interface{}) (response interface{}, err error) { + correlationID := ctx.Value(dtx.CorrelationID).(string) + orderTable[correlationID] = request + return OrderResponse{ + OrderID: "1", + Sku: "1", + Cost: 10.0, + }, nil +} + +func orderCancelEndpoint(ctx context.Context, request interface{}) (response interface{}, err error) { + correlationID := ctx.Value(dtx.CorrelationID).(string) + delete(orderTable, correlationID) + return nil, nil +} + +func paymentEndpoint(ctx context.Context, request interface{}) (response interface{}, err error) { + correlationID := ctx.Value(dtx.CorrelationID).(string) + paymentTable[correlationID] = request + if request.(PaymentRequest).Cost < 20 { + return PaymentResponse{ + Success: true, + }, nil + } + return PaymentResponse{ + Success: false, + }, nil +} + +func paymentCancelEndpoint(ctx context.Context) (response interface{}, err error) { + correlationID := ctx.Value(dtx.CorrelationID).(string) + delete(paymentTable, correlationID) + return nil, nil +} + +func Example() { + store := sagas.NewInProcessStore() + registry := sagas.NewRegistry(store) + addOrder := registry.AddStep(&sagas.Step{ + Name: "Add Order", + Do: func(ctx context.Context, request interface{}) (response interface{}, err error) { + resp, err := orderEndpoint(ctx, request.(OrderRequest)) + if err != nil { + return nil, err + } + // Convert the response to next request + return resp, nil + }, + Undo: func(ctx context.Context, req interface{}) (response interface{}, err error) { + return orderCancelEndpoint(ctx, req) + }, + }) + makePayment := registry.AddStep(&sagas.Step{ + Name: "Make Payment", + Do: func(ctx context.Context, request interface{}) (response interface{}, err error) { + resp, err := paymentEndpoint(ctx, request.(PaymentRequest)) + if err != nil { + return nil, err + } + return resp, nil + }, + Undo: func(ctx context.Context, req interface{}) (response interface{}, err error) { + return paymentCancelEndpoint(ctx) + }, + }) + + tx, ctx := registry.StartTX(context.Background()) + resp, err := addOrder(ctx, OrderRequest{Sku: "1"}) + if err != nil { + tx.Rollback(ctx) + } + resp, err = makePayment(ctx, PaymentRequest{ + OrderID: resp.(OrderResponse).OrderID, + Sku: resp.(OrderResponse).Sku, + Cost: resp.(OrderResponse).Cost, + }) + if err != nil { + tx.Rollback(ctx) + } + tx.Commit(ctx) + fmt.Println(resp.(PaymentResponse).Success) + + // Output: + // true + +} diff --git a/dtx/sagas/in_process_store.go b/dtx/sagas/in_process_store.go new file mode 100644 index 00000000..c886dd48 --- /dev/null +++ b/dtx/sagas/in_process_store.go @@ -0,0 +1,109 @@ +package sagas + +import ( + "context" + "sync" + "time" + + "github.com/DoNewsCode/core/dtx" +) + +// InProcessStore creates an in process storage that implements Store. +type InProcessStore struct { + lock sync.Mutex + transactions map[string][]Log +} + +// NewInProcessStore creates a InProcessStore. +func NewInProcessStore() *InProcessStore { + return &InProcessStore{ + transactions: make(map[string][]Log), + } +} + +// Ack marks the log entry as acknowledged, either with an error or not. It is +// safe to call ack to the same log entry more than once. +func (i *InProcessStore) Ack(ctx context.Context, logID string, err error) error { + co := ctx.Value(dtx.CorrelationID).(string) + i.lock.Lock() + defer i.lock.Unlock() + + logs := i.transactions[co] + for k := 0; k < len(logs); k++ { + if logs[k].ID == logID { + if i.transactions[co][k].LogType == Session && err == nil { + delete(i.transactions, co) + return nil + } + i.transactions[co][k].StepError = err + i.transactions[co][k].FinishedAt = time.Now() + } + } + return nil +} + +// Log appends a new unacknowledged log entry to the store. +func (i *InProcessStore) Log(ctx context.Context, log Log) error { + i.lock.Lock() + defer i.lock.Unlock() + + if i.transactions == nil { + i.transactions = make(map[string][]Log) + } + i.transactions[log.correlationID] = append(i.transactions[log.correlationID], log) + return nil +} + +// UnacknowledgedSteps searches the InProcessStore for unacknowledged steps under the given correlationID. +func (i *InProcessStore) UnacknowledgedSteps(ctx context.Context, correlationID string) ([]Log, error) { + i.lock.Lock() + defer i.lock.Unlock() + + return i.unacknowledgedSteps(ctx, correlationID) +} + +// UncommittedSagas searches the store for all uncommitted sagas, and return log entries under the matching sagas. +func (i *InProcessStore) UncommittedSagas(ctx context.Context) ([]Log, error) { + i.lock.Lock() + defer i.lock.Unlock() + + var logs []Log + for k := range i.transactions { + // For safety only. Memory store will not persist successfully finished transactions. + if i.transactions[k][0].LogType == Session && !i.transactions[k][0].FinishedAt.IsZero() { + return []Log{}, nil + } + + parts, err := i.unacknowledgedSteps(ctx, k) + + if err != nil { + return nil, err + } + logs = append(logs, parts...) + } + return logs, nil +} + +func (i *InProcessStore) unacknowledgedSteps(ctx context.Context, correlationID string) ([]Log, error) { + + var ( + stepStates = make(map[string]Log) + ) + + for _, l := range i.transactions[correlationID] { + if l.LogType == Do { + stepStates[l.StepName] = l + } + if l.LogType == Undo && (!l.FinishedAt.IsZero()) && l.StepError == nil { + delete(stepStates, l.StepName) + } + } + var steps []Log + for k := range stepStates { + steps = append(steps, stepStates[k]) + } + if len(steps) == 0 { + delete(i.transactions, correlationID) + } + return steps, nil +} diff --git a/dtx/sagas/in_process_store_test.go b/dtx/sagas/in_process_store_test.go new file mode 100644 index 00000000..d3d9743f --- /dev/null +++ b/dtx/sagas/in_process_store_test.go @@ -0,0 +1,140 @@ +package sagas + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/DoNewsCode/core/dtx" + "github.com/stretchr/testify/assert" +) + +func TestInProcessStore_Ack(t *testing.T) { + cases := []struct { + name string + log Log + err error + asserts func(t *testing.T, log Log, s *InProcessStore) + }{ + { + "session without error", + Log{ + ID: "1", + correlationID: "2", + LogType: Session, + StartedAt: time.Now(), + }, + nil, + func(t *testing.T, log Log, s *InProcessStore) { + assert.Len(t, s.transactions, 0) + }, + }, + { + "session with error", + Log{ + ID: "1", + correlationID: "2", + LogType: Session, + StartedAt: time.Now(), + }, + errors.New("foo"), + func(t *testing.T, log Log, s *InProcessStore) { + assert.Len(t, s.transactions, 1) + assert.Error(t, s.transactions[log.correlationID][0].StepError) + }, + }, + { + "do without error", + Log{ + ID: "1", + correlationID: "2", + LogType: Do, + StartedAt: time.Now(), + }, + nil, + func(t *testing.T, log Log, s *InProcessStore) { + assert.Len(t, s.transactions, 1) + assert.False(t, s.transactions[log.correlationID][0].FinishedAt.IsZero()) + assert.NoError(t, s.transactions[log.correlationID][0].StepError) + }, + }, + { + "do with error", + Log{ + ID: "1", + correlationID: "2", + LogType: Do, + StartedAt: time.Now(), + }, + errors.New("foo"), + func(t *testing.T, log Log, s *InProcessStore) { + assert.Len(t, s.transactions, 1) + assert.False(t, s.transactions[log.correlationID][0].FinishedAt.IsZero()) + assert.Error(t, s.transactions[log.correlationID][0].StepError) + }, + }, + } + + for _, c := range cases { + c := c + t.Run(c.name, func(t *testing.T) { + t.Parallel() + store := NewInProcessStore() + ctx := context.WithValue(context.Background(), dtx.CorrelationID, c.log.correlationID) + store.Log(ctx, c.log) + store.Ack(ctx, c.log.ID, c.err) + c.asserts(t, c.log, store) + }) + } +} + +func TestInProcessStore_UncommittedSteps(t *testing.T) { + store := NewInProcessStore() + ctx := context.WithValue(context.Background(), dtx.CorrelationID, "2") + store.Log(ctx, Log{ + ID: "1", + correlationID: "2", + StartedAt: time.Now(), + LogType: Session, + }) + store.Log(ctx, Log{ + ID: "2", + correlationID: "2", + StartedAt: time.Now(), + LogType: Do, + }) + logs, err := store.UnacknowledgedSteps(context.Background(), "2") + assert.NoError(t, err) + assert.Len(t, logs, 1) + + store.Log(ctx, Log{ + ID: "2", + correlationID: "2", + StartedAt: time.Now(), + LogType: Undo, + }) + + logs, err = store.UnacknowledgedSteps(context.Background(), "2") + assert.NoError(t, err) + assert.Len(t, logs, 1) + + store.Ack(ctx, "2", nil) + logs, err = store.UnacknowledgedSteps(context.Background(), "2") + assert.NoError(t, err) + assert.Len(t, logs, 0) +} + +func TestInProcessStore_UncommittedSagas(t *testing.T) { + store := NewInProcessStore() + store.transactions["test"] = []Log{{ + ID: "1", + correlationID: "test", + FinishedAt: time.Now(), + LogType: Session, + StepError: nil, + }} + logs, err := store.UncommittedSagas(context.Background()) + assert.NoError(t, err) + assert.Len(t, logs, 0) +} diff --git a/dtx/sagas/log.go b/dtx/sagas/log.go new file mode 100644 index 00000000..de13ae9b --- /dev/null +++ b/dtx/sagas/log.go @@ -0,0 +1,30 @@ +package sagas + +import ( + "time" +) + +// LogType is a type enum that describes the types of Log. +type LogType uint + +const ( + // Session type logs the occurrence of a new distributed transaction. + Session LogType = iota + // Do type logs an incremental action in the distributed saga step. + Do + // Undo type logs a compensation action in the distributed saga step. + Undo +) + +// Log is the structural Log type of the distributed saga. +type Log struct { + ID string + correlationID string + StartedAt time.Time + FinishedAt time.Time + LogType LogType + StepNumber int + StepParam interface{} + StepName string + StepError error +} diff --git a/dtx/sagas/registery.go b/dtx/sagas/registery.go new file mode 100644 index 00000000..530e8ed6 --- /dev/null +++ b/dtx/sagas/registery.go @@ -0,0 +1,158 @@ +package sagas + +import ( + "context" + "fmt" + "time" + + "github.com/DoNewsCode/core/dtx" + "github.com/go-kit/kit/endpoint" + "github.com/go-kit/kit/log" + "github.com/go-kit/kit/log/level" + "github.com/rs/xid" +) + +// Step is a step in the Saga. +type Step struct { + Name string + Do endpoint.Endpoint + Undo endpoint.Endpoint +} + +// Registry holds all transaction sagas in this process. It should be populated during the initialization of the application. +type Registry struct { + logger log.Logger + Store Store + steps map[string]*Step + timeout time.Duration +} + +// Option is the functional option for NewRegistry. +type Option func(registry *Registry) + +// WithLogger is an option that adds a logger to the registry. +func WithLogger(logger log.Logger) Option { + return func(registry *Registry) { + registry.logger = logger + } +} + +// WithTimeout is an option that configures when the unacknowledged steps +// should be marked as stale and become candidates for rollback. +func WithTimeout(duration time.Duration) Option { + return func(registry *Registry) { + registry.timeout = duration + } +} + +// NewRegistry creates a new Registry. +func NewRegistry(store Store, opts ...Option) *Registry { + r := &Registry{ + logger: log.NewNopLogger(), + Store: store, + timeout: 10 * time.Minute, + steps: make(map[string]*Step), + } + for _, f := range opts { + f(r) + } + return r +} + +// StartTX starts a transaction using saga pattern. +func (r *Registry) StartTX(ctx context.Context) (*TX, context.Context) { + cid := xid.New().String() + tx := &TX{ + session: Log{ + ID: xid.New().String(), + correlationID: cid, + StartedAt: time.Now(), + LogType: Session, + }, + store: r.Store, + correlationID: cid, + rollbacks: make(map[string]endpoint.Endpoint), + } + ctx = context.WithValue(ctx, dtx.CorrelationID, cid) + ctx = context.WithValue(ctx, TxContextKey, tx) + must(tx.store.Log(ctx, tx.session)) + return tx, ctx +} + +// AddStep registers the saga steps in the registry. The registration should be done +// during the bootstrapping of application. +func (r *Registry) AddStep(step *Step) endpoint.Endpoint { + r.steps[step.Name] = step + return func(ctx context.Context, request interface{}) (response interface{}, err error) { + logID := xid.New().String() + tx := TxFromContext(ctx) + if tx.completed { + panic("re-executing a completed transaction") + } + stepLog := Log{ + ID: logID, + correlationID: tx.correlationID, + StartedAt: time.Now(), + LogType: Do, + StepName: step.Name, + StepParam: request, + } + must(tx.store.Log(ctx, stepLog)) + tx.rollbacks[step.Name] = func(ctx context.Context, _ interface{}) (response interface{}, err error) { + logID := xid.New().String() + compensateLog := Log{ + ID: logID, + correlationID: tx.correlationID, + StartedAt: time.Now(), + LogType: Undo, + StepName: step.Name, + StepParam: request, + } + must(tx.store.Log(ctx, compensateLog)) + resp, err := step.Undo(ctx, request) + must(tx.store.Ack(ctx, logID, err)) + + return resp, err + } + response, err = step.Do(ctx, request) + must(tx.store.Ack(ctx, logID, err)) + return response, err + } +} + +// Recover rollbacks all uncommitted sagas by retrieving them in the store. +func (r *Registry) Recover(ctx context.Context) { + logs, err := r.Store.UncommittedSagas(ctx) + if err != nil { + panic(err) + } + for _, log := range logs { + if log.StartedAt.Add(r.timeout).After(time.Now()) { + continue + } + if _, ok := r.steps[log.StepName]; !ok { + level.Warn(r.logger).Log( + "msg", + fmt.Sprintf("saga step %s not registered", log.StepName), + ) + } + tx := TX{ + correlationID: log.correlationID, + store: r.Store, + } + ctx = context.WithValue(ctx, dtx.CorrelationID, tx.correlationID) + logID := xid.New().String() + compensateLog := Log{ + ID: logID, + correlationID: tx.correlationID, + StartedAt: time.Now(), + LogType: Undo, + StepName: log.StepName, + StepParam: log.StepParam, + } + + must(tx.store.Log(ctx, compensateLog)) + _, err := r.steps[log.StepName].Undo(ctx, log.StepParam) + must(tx.store.Ack(ctx, logID, err)) + } +} diff --git a/dtx/sagas/registry_test.go b/dtx/sagas/registry_test.go new file mode 100644 index 00000000..b9526b70 --- /dev/null +++ b/dtx/sagas/registry_test.go @@ -0,0 +1,62 @@ +package sagas + +import ( + "context" + "testing" + "time" + + "github.com/go-kit/kit/log" +) + +func TestRegistry_Recover(t *testing.T) { + store := NewInProcessStore() + store.transactions["test"] = []Log{{ + ID: "0", + correlationID: "2", + StartedAt: time.Now(), + LogType: Session, + StepNumber: 0, + }, { + ID: "1", + correlationID: "2", + StartedAt: time.Now(), + FinishedAt: time.Time{}, + StepNumber: 1, + LogType: Do, + StepError: nil, + }} + reg := NewRegistry(store, WithLogger(log.NewNopLogger())) + reg.Recover(context.Background()) +} + +func TestRegistry_RecoverWithTimeout(t *testing.T) { + store := NewInProcessStore() + store.transactions["test"] = []Log{{ + ID: "0", + correlationID: "2", + StartedAt: time.Now(), + LogType: Session, + }, { + ID: "1", + correlationID: "2", + StartedAt: time.Now(), + FinishedAt: time.Time{}, + StepNumber: 0, + LogType: Do, + StepError: nil, + StepName: "foo", + }} + reg := NewRegistry(store, WithLogger(log.NewNopLogger())) + reg.AddStep(&Step{ + Name: "foo", + Do: func(ctx context.Context, request interface{}) (response interface{}, err error) { + t.Fatal("should not be called") + return nil, nil + }, + Undo: func(ctx context.Context, req interface{}) (response interface{}, err error) { + t.Fatal("should not be called") + return nil, nil + }, + }) + reg.Recover(context.Background()) +} diff --git a/dtx/sagas/tx.go b/dtx/sagas/tx.go new file mode 100644 index 00000000..d1ca4097 --- /dev/null +++ b/dtx/sagas/tx.go @@ -0,0 +1,61 @@ +package sagas + +import ( + "context" + + "github.com/go-kit/kit/endpoint" + "github.com/hashicorp/go-multierror" +) + +type contextKey string + +// TxContextKey is the context key for TX. +const TxContextKey contextKey = "coordinator" + +// Store is the interface to persist logs of transactions. +type Store interface { + Log(ctx context.Context, log Log) error + Ack(ctx context.Context, id string, err error) error + UnacknowledgedSteps(ctx context.Context, correlationID string) ([]Log, error) + UncommittedSagas(ctx context.Context) ([]Log, error) +} + +// TX is a distributed transaction coordinator. It should be initialized +// by directly assigning its public members. +type TX struct { + store Store + correlationID string + session Log + rollbacks map[string]endpoint.Endpoint + undoErr *multierror.Error + completed bool +} + +// Commit commits the current transaction. +func (tx *TX) Commit(ctx context.Context) error { + tx.completed = true + return tx.store.Ack(ctx, tx.session.ID, nil) +} + +// Rollback rollbacks the current transaction. +func (tx *TX) Rollback(ctx context.Context) error { + for _, call := range tx.rollbacks { + _, err := call(ctx, nil) + if err != nil { + tx.undoErr = multierror.Append(tx.undoErr, err) + } + } + tx.completed = true + return tx.undoErr.ErrorOrNil() +} + +// TxFromContext returns the tx instance from context. +func TxFromContext(ctx context.Context) *TX { + return ctx.Value(TxContextKey).(*TX) +} + +func must(err error) { + if err != nil { + panic(err) + } +} diff --git a/dtx/sagas/tx_test.go b/dtx/sagas/tx_test.go new file mode 100644 index 00000000..66676f36 --- /dev/null +++ b/dtx/sagas/tx_test.go @@ -0,0 +1,240 @@ +package sagas + +import ( + "context" + "errors" + "testing" + + "github.com/hashicorp/go-multierror" + "github.com/stretchr/testify/assert" +) + +func TestSaga_success(t *testing.T) { + var value int + store := NewInProcessStore() + r := NewRegistry(store) + + ep1 := r.AddStep(&Step{ + "one", + func(ctx context.Context, req interface{}) (interface{}, error) { + value++ + return nil, nil + }, + func(ctx context.Context, req interface{}) (interface{}, error) { + value-- + return nil, nil + }, + }) + ep2 := r.AddStep(&Step{ + "two", + func(ctx context.Context, req interface{}) (interface{}, error) { + value++ + return nil, nil + }, + func(ctx context.Context, req interface{}) (interface{}, error) { + value-- + return nil, nil + }, + }) + + var c, ctx = r.StartTX(context.Background()) + ep1(ctx, nil) + ep2(ctx, nil) + c.Commit(ctx) + assert.Equal(t, 2, value) +} + +func TestSaga_failure(t *testing.T) { + var value int + store := NewInProcessStore() + r := NewRegistry(store) + + ep1 := r.AddStep(&Step{ + "one", + func(ctx context.Context, req interface{}) (interface{}, error) { + value++ + return nil, nil + }, + func(ctx context.Context, req interface{}) (interface{}, error) { + value-- + return nil, nil + }, + }) + + ep2 := r.AddStep(&Step{ + "two", + func(ctx context.Context, req interface{}) (interface{}, error) { + value++ + return nil, errors.New("") + }, + func(ctx context.Context, req interface{}) (interface{}, error) { + value-- + return nil, nil + }, + }) + + var c, ctx = r.StartTX(context.Background()) + ep1(ctx, nil) + ep2(ctx, nil) + c.Rollback(ctx) + assert.Equal(t, 0, value) +} + +func TestSaga_recovery(t *testing.T) { + var attempt int + var value int + var store = &InProcessStore{} + var r = NewRegistry(store, WithTimeout(0)) + var errTest = errors.New("test") + ep1 := r.AddStep(&Step{ + "one", + func(ctx context.Context, req interface{}) (interface{}, error) { + value++ + return nil, nil + }, + func(ctx context.Context, req interface{}) (interface{}, error) { + value-- + return nil, nil + }, + }) + + ep2 := r.AddStep(&Step{ + + "two", + func(ctx context.Context, req interface{}) (interface{}, error) { + value++ + return nil, errors.New("") + }, + func(ctx context.Context, req interface{}) (interface{}, error) { + if attempt == 0 { + attempt++ + return nil, errTest + } + value-- + return nil, nil + + }, + }) + + var c, ctx = r.StartTX(context.Background()) + ep1(ctx, nil) + ep2(ctx, nil) + err := c.Rollback(ctx) + assert.NotNil(t, err) + assert.Len(t, err.(*multierror.Error).Errors, 1) + assert.Equal(t, 1, value) + + r.Recover(ctx) + assert.Equal(t, 0, value) +} + +func TestSaga_panic(t *testing.T) { + var attempt int + var value int + var store = &InProcessStore{} + var r = NewRegistry(store, WithTimeout(0)) + + ep1 := r.AddStep(&Step{ + "one", + func(ctx context.Context, req interface{}) (interface{}, error) { + value++ + return nil, nil + }, + func(ctx context.Context, req interface{}) (interface{}, error) { + value-- + return nil, nil + }, + }) + ep2 := r.AddStep(&Step{ + "two", + func(ctx context.Context, req interface{}) (interface{}, error) { + value++ + return nil, errors.New("") + }, + func(ctx context.Context, req interface{}) (interface{}, error) { + if attempt == 0 { + attempt++ + panic("err") + } + value-- + return nil, nil + }, + }) + + defer func(r *Registry) { + if rec := recover(); rec != nil { + r.Recover(context.Background()) + assert.Equal(t, 0, value) + } + }(r) + + var _, ctx = r.StartTX(context.Background()) + ep1(ctx, nil) + ep2(ctx, nil) +} + +func TestSaga_shortCircuit(t *testing.T) { + var value int + var store = &InProcessStore{} + var r = NewRegistry(store, WithTimeout(0)) + + ep1 := r.AddStep(&Step{ + "one", + func(ctx context.Context, req interface{}) (interface{}, error) { + value++ + return nil, nil + }, + func(ctx context.Context, req interface{}) (interface{}, error) { + value-- + return nil, nil + }, + }) + + r.AddStep(&Step{ + "two", + func(ctx context.Context, req interface{}) (interface{}, error) { + panic("should not reach") + }, + func(ctx context.Context, req interface{}) (interface{}, error) { + panic("should not reach") + }, + }) + + var c, ctx = r.StartTX(context.Background()) + ep1(ctx, nil) + c.Commit(ctx) + assert.Equal(t, 1, value) +} + +func TestSaga_emptyRecover(t *testing.T) { + var value int + var attempt int + var store = &InProcessStore{} + var r = NewRegistry(store, WithTimeout(0)) + + ep := r.AddStep(&Step{ + "two", + func(ctx context.Context, req interface{}) (interface{}, error) { + value++ + return nil, errors.New("foo") + }, + func(ctx context.Context, req interface{}) (interface{}, error) { + if attempt == 0 { + attempt++ + value-- + return nil, nil + } + panic("err") + }, + }) + tx, ctx := r.StartTX(context.Background()) + tx.Commit(ctx) + r.Recover(context.Background()) + assert.Equal(t, 0, value) + defer func() { + assert.NotNil(t, recover()) + }() + + ep(ctx, nil) + +} diff --git a/dtx/transport.go b/dtx/transport.go new file mode 100644 index 00000000..3410aadf --- /dev/null +++ b/dtx/transport.go @@ -0,0 +1,69 @@ +package dtx + +import ( + "context" + stdhttp "net/http" + + "github.com/go-kit/kit/transport/grpc" + "github.com/go-kit/kit/transport/http" + "google.golang.org/grpc/metadata" +) + +const ( + header string = "X-TX-CORRELATION-ID" + headerHTTP2 string = "x-tx-correlation-id" +) + +// HTTPToContext moves a CorrelationID from request header to context. Particularly +// useful for servers. +func HTTPToContext() http.RequestFunc { + return func(ctx context.Context, r *stdhttp.Request) context.Context { + token := r.Header.Get(header) + if token == "" { + return ctx + } + return context.WithValue(ctx, CorrelationID, token) + } +} + +// ContextToHTTP moves a CorrelationID from context to request header. Particularly +// useful for clients. +func ContextToHTTP() http.RequestFunc { + return func(ctx context.Context, r *stdhttp.Request) context.Context { + token, ok := ctx.Value(CorrelationID).(string) + if ok { + r.Header.Add(header, token) + } + return ctx + } +} + +// GRPCToContext moves a CorrelationID from grpc metadata to context. Particularly +// userful for servers. +func GRPCToContext() grpc.ServerRequestFunc { + return func(ctx context.Context, md metadata.MD) context.Context { + // capital "Key" is illegal in HTTP/2. + tokens, ok := md[headerHTTP2] + if !ok { + return ctx + } + if len(tokens) <= 0 { + return ctx + } + ctx = context.WithValue(ctx, CorrelationID, tokens[len(tokens)-1]) + return ctx + } +} + +// ContextToGRPC moves a CorrelationID from context to grpc metadata. Particularly +// useful for clients. +func ContextToGRPC() grpc.ClientRequestFunc { + return func(ctx context.Context, md *metadata.MD) context.Context { + token, ok := ctx.Value(CorrelationID).(string) + if ok { + // capital "Key" is illegal in HTTP/2. + (*md)[headerHTTP2] = []string{token} + } + return ctx + } +} diff --git a/dtx/transport_test.go b/dtx/transport_test.go new file mode 100644 index 00000000..d80f972b --- /dev/null +++ b/dtx/transport_test.go @@ -0,0 +1,107 @@ +package dtx + +import ( + "context" + "net/http" + "testing" + + "google.golang.org/grpc/metadata" +) + +func TestHTTPToContext(t *testing.T) { + reqFunc := HTTPToContext() + + // When the header doesn't exist + ctx := reqFunc(context.Background(), &http.Request{}) + + if ctx.Value(CorrelationID) != nil { + t.Error("Context shouldn't contain the CorrelationID") + } + + head := http.Header{} + // Authorization header is correct + head.Set(header, "foobar") + ctx = reqFunc(context.Background(), &http.Request{Header: head}) + + token := ctx.Value(CorrelationID).(string) + if token != "foobar" { + t.Errorf("Context doesn't contain the expected encoded token value; expected: %s, got: %s", "foobar", token) + } +} + +func TestContextToHTTP(t *testing.T) { + reqFunc := ContextToHTTP() + + // No JWT Token is passed in the context + ctx := context.Background() + r := http.Request{} + reqFunc(ctx, &r) + + token := r.Header.Get(header) + if token != "" { + t.Error("header key should not exist in metadata") + } + + // Correct JWT Token is passed in the context + ctx = context.WithValue(context.Background(), CorrelationID, "foobar") + r = http.Request{Header: http.Header{}} + reqFunc(ctx, &r) + + token = r.Header.Get(header) + expected := "foobar" + + if token != expected { + t.Errorf("Authorization header does not contain the expected JWT token; expected %s, got %s", expected, token) + } +} + +func TestGRPCToContext(t *testing.T) { + md := metadata.MD{} + reqFunc := GRPCToContext() + + // No Authorization header is passed + ctx := reqFunc(context.Background(), md) + token := ctx.Value(CorrelationID) + if token != nil { + t.Error("Context should not contain a correlation ID") + } + + md[headerHTTP2] = []string{"foobar"} + ctx = reqFunc(context.Background(), md) + token, ok := ctx.Value(CorrelationID).(string) + if !ok { + t.Fatal("Correlation ID not passed to context correctly") + } + + if token != "foobar" { + t.Errorf("Correlation ID did not match: expecting %s got %s", "foobar", token) + } +} + +func TestContextToGRPC(t *testing.T) { + reqFunc := ContextToGRPC() + + // No JWT Token is passed in the context + ctx := context.Background() + md := metadata.MD{} + reqFunc(ctx, &md) + + _, ok := md[headerHTTP2] + if ok { + t.Error("authorization key should not exist in metadata") + } + + // Correct JWT Token is passed in the context + ctx = context.WithValue(context.Background(), CorrelationID, "foobar") + md = metadata.MD{} + reqFunc(ctx, &md) + + token, ok := md[headerHTTP2] + if !ok { + t.Fatal("JWT Token not passed to metadata correctly") + } + + if token[0] != "foobar" { + t.Errorf("JWT tokens did not match: expecting %s got %s", "foobar", token[0]) + } +} diff --git a/go.mod b/go.mod index 3da8aa6a..fbf2f58b 100644 --- a/go.mod +++ b/go.mod @@ -28,6 +28,7 @@ require ( github.com/gorilla/mux v1.8.0 github.com/grpc-ecosystem/go-grpc-middleware v1.2.2 // indirect github.com/grpc-ecosystem/grpc-gateway v1.14.6 // indirect + github.com/hashicorp/go-multierror v1.1.0 github.com/heptiolabs/healthcheck v0.0.0-20180807145615-6ff867650f40 github.com/jonboulle/clockwork v0.2.2 // indirect github.com/knadh/koanf v0.15.0 @@ -42,11 +43,11 @@ require ( github.com/pkg/errors v0.9.1 github.com/prometheus/client_golang v1.9.0 github.com/robfig/cron/v3 v3.0.1 + github.com/rs/xid v1.2.1 github.com/segmentio/kafka-go v0.4.10 github.com/sirupsen/logrus v1.7.0 // indirect github.com/soheilhy/cmux v0.1.5-0.20210205191134-5ec6847320e5 // indirect github.com/spf13/cobra v1.1.3 - github.com/spf13/pflag v1.0.5 github.com/stretchr/testify v1.7.0 github.com/tmc/grpc-websocket-proxy v0.0.0-20200427203606-3cfed13b9966 // indirect github.com/uber/jaeger-client-go v2.25.0+incompatible diff --git a/go.sum b/go.sum index cf17c190..be4d57db 100644 --- a/go.sum +++ b/go.sum @@ -259,6 +259,7 @@ github.com/hashicorp/consul/api v1.1.0/go.mod h1:VmuI/Lkw1nC05EYQWNKwWGbkg+FbDBt github.com/hashicorp/consul/api v1.3.0/go.mod h1:MmDNSzIMUjNpY/mQ398R4bk2FnqQLoPndWW5VkKPlCE= github.com/hashicorp/consul/sdk v0.1.1/go.mod h1:VKf9jXwCTEY1QZP2MOLRhb5i/I/ssyNV1vwHyQBF0x8= github.com/hashicorp/consul/sdk v0.3.0/go.mod h1:VKf9jXwCTEY1QZP2MOLRhb5i/I/ssyNV1vwHyQBF0x8= +github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/go-cleanhttp v0.5.0/go.mod h1:JpRdi6/HCYpAwUzNwuwqhbovhLtngrth3wmdIIUrZ80= github.com/hashicorp/go-cleanhttp v0.5.1/go.mod h1:JpRdi6/HCYpAwUzNwuwqhbovhLtngrth3wmdIIUrZ80= @@ -267,6 +268,8 @@ github.com/hashicorp/go-hclog v0.8.0/go.mod h1:5CU+agLiy3J7N7QjHK5d05KxGsuXiQLrj github.com/hashicorp/go-immutable-radix v1.0.0/go.mod h1:0y9vanUI8NX6FsYoO3zeMjhV/C5i9g4Q3DwcSNZ4P60= github.com/hashicorp/go-msgpack v0.5.3/go.mod h1:ahLV/dePpqEmjfWmKiqvPkv/twdG7iPBM1vqhUKIvfM= github.com/hashicorp/go-multierror v1.0.0/go.mod h1:dHtQlpGsu+cZNNAkkCN/P3hoUDHhCYQXV3UM06sGGrk= +github.com/hashicorp/go-multierror v1.1.0 h1:B9UzwGQJehnUY1yNrnwREHc3fGbC2xefo8g4TbElacI= +github.com/hashicorp/go-multierror v1.1.0/go.mod h1:spPvp8C1qA32ftKqdAHm4hHTbPw+vmowP0z+KUhOZdA= github.com/hashicorp/go-plugin v1.0.1/go.mod h1:++UyYGoz3o5w9ZzAdZxtQKrWWP+iqPBn3cQptSMzBuY= github.com/hashicorp/go-retryablehttp v0.5.4/go.mod h1:9B5zBasrRhHXnJnui7y6sL7es7NDiJgTc6Er0maI1Xs= github.com/hashicorp/go-rootcerts v1.0.0/go.mod h1:K6zTfqpRlCUIjkwsN4Z+hiSfzSTQa6eBIzfwKfwNnHU= @@ -564,6 +567,7 @@ github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6L github.com/rogpeppe/go-internal v1.1.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.2.2/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= +github.com/rs/xid v1.2.1 h1:mhH9Nq+C1fY2l1XIpgxIiUOfNpRBYH1kKcr+qfKgjRc= github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ= github.com/rs/zerolog v1.13.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU= github.com/rs/zerolog v1.15.0/go.mod h1:xYTKnLHcpfU2225ny5qZjxnj9NvkumZYjJHlAThCjNc= @@ -705,7 +709,6 @@ golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8U golang.org/x/crypto v0.0.0-20190911031432-227b76d455e7/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200323165209-0ec3e9974c59/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 h1:psW17arqaxU48Z5kZ0CQnkZWQJsqcURM6tKiBApRjXI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20201002170205-7f63de1d35b0 h1:hb9wdF1z5waM+dSIICn1l0DkLVDT3hqhhQsDNUmHPRE= golang.org/x/crypto v0.0.0-20201002170205-7f63de1d35b0/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= @@ -722,7 +725,6 @@ golang.org/x/lint v0.0.0-20190301231843-5614ed5bae6f/go.mod h1:UVdnD1Gm6xHRNCYTk golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/lint v0.0.0-20190409202823-959b441ac422/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/lint v0.0.0-20190909230951-414d861bb4ac/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= -golang.org/x/lint v0.0.0-20190930215403-16217165b5de h1:5hukYrvBGR8/eNkX5mdUezrA6JiaEZDtJb9Ei+1LlBs= golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/lint v0.0.0-20200302205851-738671d3881b h1:Wh+f8QHJXR411sJR8/vRBTZ7YapZaRvUcLFFJhusH0k= golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= diff --git a/leader/doc.go b/leader/doc.go index 813433bf..ae1580a5 100644 --- a/leader/doc.go +++ b/leader/doc.go @@ -22,7 +22,7 @@ To use package leader with package core: var c *core.C = core.Default() c.Provide(otetcd.Providers) // to provide the underlying driver c.Provide(leader.Providers) - c.Invoke(function(status *leader.Status) { + c.Invoke(func(status *leader.Status) { if ! status.IsLeader { return }