Skip to content

Commit 944e37b

Browse files
committed
pkg/capabilities: support replacing registered capabilities after shutdown
1 parent 8f2c438 commit 944e37b

File tree

5 files changed

+258
-33
lines changed

5 files changed

+258
-33
lines changed

pkg/capabilities/registry/base.go

Lines changed: 239 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,11 @@ import (
66
"fmt"
77
"strings"
88
"sync"
9+
"sync/atomic"
910

1011
"github.com/Masterminds/semver/v3"
12+
"google.golang.org/grpc"
13+
"google.golang.org/grpc/connectivity"
1114

1215
"github.com/smartcontractkit/chainlink-common/pkg/capabilities"
1316
"github.com/smartcontractkit/chainlink-common/pkg/logger"
@@ -18,8 +21,22 @@ var (
1821
ErrCapabilityAlreadyExists = errors.New("capability already exists")
1922
)
2023

24+
// atomicBaseCapability extends [capabilities.BaseCapability] to support atomic updates and forward client state checks.
25+
type atomicBaseCapability interface {
26+
capabilities.BaseCapability
27+
Update(capabilities.BaseCapability) error
28+
StateGetter
29+
}
30+
31+
var _ StateGetter = (*grpc.ClientConn)(nil)
32+
33+
// StateGetter is implemented by GRPC client connections.
34+
type StateGetter interface {
35+
GetState() connectivity.State
36+
}
37+
2138
type baseRegistry struct {
22-
m map[string]capabilities.BaseCapability
39+
m map[string]atomicBaseCapability
2340
lggr logger.Logger
2441
mu sync.RWMutex
2542
}
@@ -28,7 +45,7 @@ var _ core.CapabilitiesRegistryBase = (*baseRegistry)(nil)
2845

2946
func NewBaseRegistry(lggr logger.Logger) core.CapabilitiesRegistryBase {
3047
return &baseRegistry{
31-
m: map[string]capabilities.BaseCapability{},
48+
m: map[string]atomicBaseCapability{},
3249
lggr: logger.Named(lggr, "registries.basic"),
3350
}
3451
}
@@ -142,46 +159,240 @@ func (r *baseRegistry) Add(ctx context.Context, c capabilities.BaseCapability) e
142159
return err
143160
}
144161

145-
switch info.CapabilityType {
146-
case capabilities.CapabilityTypeTrigger:
147-
_, ok := c.(capabilities.TriggerCapability)
148-
if !ok {
149-
return errors.New("trigger capability does not satisfy TriggerCapability interface")
162+
id := info.ID
163+
bc, ok := r.m[id]
164+
if ok {
165+
if bc.GetState() != connectivity.Shutdown {
166+
return fmt.Errorf("%w: id %s found in registry", ErrCapabilityAlreadyExists, id)
150167
}
151-
case capabilities.CapabilityTypeAction, capabilities.CapabilityTypeConsensus, capabilities.CapabilityTypeTarget:
152-
_, ok := c.(capabilities.ExecutableCapability)
153-
if !ok {
154-
return errors.New("action does not satisfy ExecutableCapability interface")
168+
if err := bc.Update(c); err != nil {
169+
return fmt.Errorf("failed to update capability %s: %w", id, err)
155170
}
156-
case capabilities.CapabilityTypeCombined:
157-
_, ok := c.(capabilities.ExecutableAndTriggerCapability)
158-
if !ok {
159-
return errors.New("target capability does not satisfy ExecutableAndTriggerCapability interface")
171+
} else {
172+
var ac atomicBaseCapability
173+
switch info.CapabilityType {
174+
case capabilities.CapabilityTypeTrigger:
175+
ac = &atomicTriggerCapability{}
176+
case capabilities.CapabilityTypeAction, capabilities.CapabilityTypeConsensus, capabilities.CapabilityTypeTarget:
177+
ac = &atomicExecuteCapability{}
178+
case capabilities.CapabilityTypeCombined:
179+
ac = &atomicExecuteAndTriggerCapability{}
180+
default:
181+
return fmt.Errorf("unknown capability type: %s", info.CapabilityType)
160182
}
161-
default:
162-
return fmt.Errorf("unknown capability type: %s", info.CapabilityType)
163-
}
164-
165-
id := info.ID
166-
_, ok := r.m[id]
167-
if ok {
168-
return fmt.Errorf("%w: id %s found in registry", ErrCapabilityAlreadyExists, id)
183+
if err := ac.Update(c); err != nil {
184+
return err
185+
}
186+
r.m[id] = ac
169187
}
170-
171-
r.m[id] = c
172188
r.lggr.Infow("capability added", "id", id, "type", info.CapabilityType, "description", info.Description, "version", info.Version())
173189
return nil
174190
}
175191

176192
func (r *baseRegistry) Remove(_ context.Context, id string) error {
177193
r.mu.Lock()
178194
defer r.mu.Unlock()
179-
_, ok := r.m[id]
195+
ac, ok := r.m[id]
180196
if !ok {
181197
return fmt.Errorf("unable to remove, capability not found: %s", id)
182198
}
183-
184-
delete(r.m, id)
199+
if err := ac.Update(nil); err != nil {
200+
return fmt.Errorf("failed to remove capability %s: %w", id, err)
201+
}
185202
r.lggr.Infow("capability removed", "id", id)
186203
return nil
187204
}
205+
206+
var _ capabilities.TriggerCapability = &atomicTriggerCapability{}
207+
208+
type atomicTriggerCapability struct {
209+
atomic.Pointer[capabilities.TriggerCapability]
210+
}
211+
212+
func (a *atomicTriggerCapability) Update(c capabilities.BaseCapability) error {
213+
if c == nil {
214+
a.Store(nil)
215+
return nil
216+
}
217+
tc, ok := c.(capabilities.TriggerCapability)
218+
if !ok {
219+
return errors.New("trigger capability does not satisfy TriggerCapability interface")
220+
}
221+
a.Store(&tc)
222+
return nil
223+
}
224+
225+
func (a *atomicTriggerCapability) Info(ctx context.Context) (capabilities.CapabilityInfo, error) {
226+
c := a.Load()
227+
if c == nil {
228+
return capabilities.CapabilityInfo{}, errors.New("capability unavailable")
229+
}
230+
return (*c).Info(ctx)
231+
}
232+
233+
func (a *atomicTriggerCapability) GetState() connectivity.State {
234+
c := a.Load()
235+
if c == nil {
236+
return connectivity.Shutdown
237+
}
238+
if sg, ok := (*c).(StateGetter); ok {
239+
return sg.GetState()
240+
}
241+
return connectivity.State(-1) // unknown
242+
}
243+
244+
func (a *atomicTriggerCapability) RegisterTrigger(ctx context.Context, request capabilities.TriggerRegistrationRequest) (<-chan capabilities.TriggerResponse, error) {
245+
c := a.Load()
246+
if c == nil {
247+
return nil, errors.New("capability unavailable")
248+
}
249+
return (*c).RegisterTrigger(ctx, request)
250+
}
251+
252+
func (a *atomicTriggerCapability) UnregisterTrigger(ctx context.Context, request capabilities.TriggerRegistrationRequest) error {
253+
c := a.Load()
254+
if c == nil {
255+
return errors.New("capability unavailable")
256+
}
257+
return (*c).UnregisterTrigger(ctx, request)
258+
}
259+
260+
var _ capabilities.ExecutableCapability = &atomicExecuteCapability{}
261+
262+
type atomicExecuteCapability struct {
263+
atomic.Pointer[capabilities.ExecutableCapability]
264+
}
265+
266+
func (a *atomicExecuteCapability) Update(c capabilities.BaseCapability) error {
267+
if c == nil {
268+
a.Store(nil)
269+
return nil
270+
}
271+
tc, ok := c.(capabilities.ExecutableCapability)
272+
if !ok {
273+
return errors.New("action does not satisfy ExecutableCapability interface")
274+
}
275+
a.Store(&tc)
276+
return nil
277+
}
278+
279+
func (a *atomicExecuteCapability) Info(ctx context.Context) (capabilities.CapabilityInfo, error) {
280+
c := a.Load()
281+
if c == nil {
282+
return capabilities.CapabilityInfo{}, errors.New("capability unavailable")
283+
}
284+
return (*c).Info(ctx)
285+
}
286+
287+
func (a *atomicExecuteCapability) GetState() connectivity.State {
288+
c := a.Load()
289+
if c == nil {
290+
return connectivity.Shutdown
291+
}
292+
if sg, ok := (*c).(StateGetter); ok {
293+
return sg.GetState()
294+
}
295+
return connectivity.State(-1) // unknown
296+
}
297+
298+
func (a *atomicExecuteCapability) RegisterToWorkflow(ctx context.Context, request capabilities.RegisterToWorkflowRequest) error {
299+
c := a.Load()
300+
if c == nil {
301+
return errors.New("capability unavailable")
302+
}
303+
return (*c).RegisterToWorkflow(ctx, request)
304+
}
305+
306+
func (a *atomicExecuteCapability) UnregisterFromWorkflow(ctx context.Context, request capabilities.UnregisterFromWorkflowRequest) error {
307+
c := a.Load()
308+
if c == nil {
309+
return errors.New("capability unavailable")
310+
}
311+
return (*c).UnregisterFromWorkflow(ctx, request)
312+
}
313+
314+
func (a *atomicExecuteCapability) Execute(ctx context.Context, request capabilities.CapabilityRequest) (capabilities.CapabilityResponse, error) {
315+
c := a.Load()
316+
if c == nil {
317+
return capabilities.CapabilityResponse{}, errors.New("capability unavailable")
318+
}
319+
return (*c).Execute(ctx, request)
320+
}
321+
322+
var _ capabilities.ExecutableAndTriggerCapability = &atomicExecuteAndTriggerCapability{}
323+
324+
type atomicExecuteAndTriggerCapability struct {
325+
atomic.Pointer[capabilities.ExecutableAndTriggerCapability]
326+
}
327+
328+
func (a *atomicExecuteAndTriggerCapability) Update(c capabilities.BaseCapability) error {
329+
if c == nil {
330+
a.Store(nil)
331+
return nil
332+
}
333+
tc, ok := c.(capabilities.ExecutableAndTriggerCapability)
334+
if !ok {
335+
return errors.New("target capability does not satisfy ExecutableAndTriggerCapability interface")
336+
}
337+
a.Store(&tc)
338+
return nil
339+
}
340+
341+
func (a *atomicExecuteAndTriggerCapability) Info(ctx context.Context) (capabilities.CapabilityInfo, error) {
342+
c := a.Load()
343+
if c == nil {
344+
return capabilities.CapabilityInfo{}, errors.New("capability unavailable")
345+
}
346+
return (*c).Info(ctx)
347+
}
348+
349+
func (a *atomicExecuteAndTriggerCapability) GetState() connectivity.State {
350+
c := a.Load()
351+
if c == nil {
352+
return connectivity.Shutdown
353+
}
354+
if sg, ok := (*c).(StateGetter); ok {
355+
return sg.GetState()
356+
}
357+
return connectivity.State(-1) // unknown
358+
}
359+
360+
func (a *atomicExecuteAndTriggerCapability) RegisterTrigger(ctx context.Context, request capabilities.TriggerRegistrationRequest) (<-chan capabilities.TriggerResponse, error) {
361+
c := a.Load()
362+
if c == nil {
363+
return nil, errors.New("capability unavailable")
364+
}
365+
return (*c).RegisterTrigger(ctx, request)
366+
}
367+
368+
func (a *atomicExecuteAndTriggerCapability) UnregisterTrigger(ctx context.Context, request capabilities.TriggerRegistrationRequest) error {
369+
c := a.Load()
370+
if c == nil {
371+
return errors.New("capability unavailable")
372+
}
373+
return (*c).UnregisterTrigger(ctx, request)
374+
}
375+
376+
func (a *atomicExecuteAndTriggerCapability) RegisterToWorkflow(ctx context.Context, request capabilities.RegisterToWorkflowRequest) error {
377+
c := a.Load()
378+
if c == nil {
379+
return errors.New("capability unavailable")
380+
}
381+
return (*c).RegisterToWorkflow(ctx, request)
382+
}
383+
384+
func (a *atomicExecuteAndTriggerCapability) UnregisterFromWorkflow(ctx context.Context, request capabilities.UnregisterFromWorkflowRequest) error {
385+
c := a.Load()
386+
if c == nil {
387+
return errors.New("capability unavailable")
388+
}
389+
return (*c).UnregisterFromWorkflow(ctx, request)
390+
}
391+
392+
func (a *atomicExecuteAndTriggerCapability) Execute(ctx context.Context, request capabilities.CapabilityRequest) (capabilities.CapabilityResponse, error) {
393+
c := a.Load()
394+
if c == nil {
395+
return capabilities.CapabilityResponse{}, errors.New("capability unavailable")
396+
}
397+
return (*c).Execute(ctx, request)
398+
}

pkg/capabilities/registry/base_test.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,13 +50,17 @@ func TestRegistry(t *testing.T) {
5050

5151
gc, err := r.Get(ctx, id)
5252
require.NoError(t, err)
53+
info, err := gc.Info(t.Context())
54+
require.NoError(t, err)
5355

54-
assert.Equal(t, c, gc)
56+
assert.Equal(t, c.CapabilityInfo, info)
5557

5658
cs, err := r.List(ctx)
5759
require.NoError(t, err)
5860
assert.Len(t, cs, 1)
59-
assert.Equal(t, c, cs[0])
61+
info, err = cs[0].Info(t.Context())
62+
require.NoError(t, err)
63+
assert.Equal(t, c.CapabilityInfo, info)
6064
}
6165

6266
func TestRegistryCompatibleVersions(t *testing.T) {

pkg/loop/internal/core/services/capability/capabilities.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"sync"
99

1010
"google.golang.org/grpc"
11+
"google.golang.org/grpc/connectivity"
1112
"google.golang.org/protobuf/types/known/emptypb"
1213

1314
"github.com/smartcontractkit/chainlink-common/pkg/capabilities"
@@ -135,14 +136,19 @@ func InfoToReply(info capabilities.CapabilityInfo) *capabilitiespb.CapabilityInf
135136
}
136137

137138
type baseCapabilityClient struct {
139+
c *grpc.ClientConn
138140
grpc capabilitiespb.BaseCapabilityClient
139141
*net.BrokerExt
140142
}
141143

142144
var _ capabilities.BaseCapability = (*baseCapabilityClient)(nil)
143145

144146
func newBaseCapabilityClient(brokerExt *net.BrokerExt, conn *grpc.ClientConn) *baseCapabilityClient {
145-
return &baseCapabilityClient{grpc: capabilitiespb.NewBaseCapabilityClient(conn), BrokerExt: brokerExt}
147+
return &baseCapabilityClient{c: conn, grpc: capabilitiespb.NewBaseCapabilityClient(conn), BrokerExt: brokerExt}
148+
}
149+
150+
func (c *baseCapabilityClient) GetState() connectivity.State {
151+
return c.c.GetState()
146152
}
147153

148154
func (c *baseCapabilityClient) Info(ctx context.Context) (capabilities.CapabilityInfo, error) {

pkg/loop/internal/core/services/capability/capabilities_registry.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010

1111
"github.com/smartcontractkit/chainlink-common/pkg/capabilities"
1212
capabilitiespb "github.com/smartcontractkit/chainlink-common/pkg/capabilities/pb"
13+
"github.com/smartcontractkit/chainlink-common/pkg/capabilities/registry"
1314
"github.com/smartcontractkit/chainlink-common/pkg/loop/internal/net"
1415
"github.com/smartcontractkit/chainlink-common/pkg/loop/internal/pb"
1516
"github.com/smartcontractkit/chainlink-common/pkg/types/core"
@@ -601,6 +602,10 @@ func (c *capabilitiesRegistryServer) List(ctx context.Context, _ *emptypb.Empty)
601602
return reply, nil
602603
}
603604

605+
var _ registry.StateGetter = (*TriggerCapabilityClient)(nil)
606+
var _ registry.StateGetter = (*ExecutableCapabilityClient)(nil)
607+
var _ registry.StateGetter = (*CombinedCapabilityClient)(nil)
608+
604609
func (c *capabilitiesRegistryServer) Add(ctx context.Context, request *pb.AddRequest) (*emptypb.Empty, error) {
605610
conn, err := c.Dial(request.CapabilityID)
606611
if err != nil {

pkg/loop/internal/core/services/capability/capabilities_registry_test.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,12 @@ import (
77
"testing"
88

99
"github.com/hashicorp/go-plugin"
10+
p2ptypes "github.com/smartcontractkit/libocr/ragep2p/types"
1011
"github.com/stretchr/testify/assert"
1112
"github.com/stretchr/testify/mock"
1213
"github.com/stretchr/testify/require"
1314
"google.golang.org/grpc"
1415

15-
p2ptypes "github.com/smartcontractkit/libocr/ragep2p/types"
16-
1716
"github.com/smartcontractkit/chainlink-protos/cre/go/values"
1817

1918
"github.com/smartcontractkit/chainlink-common/pkg/capabilities"

0 commit comments

Comments
 (0)