diff --git a/pkg/solana/logpoller/log_poller.go b/pkg/solana/logpoller/log_poller.go index c82e8eb47..f24c48351 100644 --- a/pkg/solana/logpoller/log_poller.go +++ b/pkg/solana/logpoller/log_poller.go @@ -10,6 +10,7 @@ import ( "github.com/smartcontractkit/chainlink-common/pkg/logger" "github.com/smartcontractkit/chainlink-common/pkg/services" + "github.com/smartcontractkit/chainlink-common/pkg/types/query" "github.com/smartcontractkit/chainlink-common/pkg/utils" "github.com/smartcontractkit/chainlink-solana/pkg/solana/client" @@ -29,6 +30,7 @@ type ORM interface { MarkFilterBackfilled(ctx context.Context, id int64) (err error) InsertLogs(context.Context, []Log) (err error) SelectSeqNums(ctx context.Context) (map[int64]int64, error) + FilteredLogs(ctx context.Context, queryFilter []query.Expression, limitAndSort query.LimitAndSort, queryName string) ([]Log, error) } type Service struct { @@ -257,3 +259,7 @@ func (lp *Service) startFilterBackfill(ctx context.Context, filter Filter, toBlo lp.lggr.Errorw("Failed to mark filter backfill", "filter", filter, "err", err) } } + +func (lp *Service) FilteredLogs(ctx context.Context, queryFilter []query.Expression, limitAndSort query.LimitAndSort, queryName string) ([]Log, error) { + return lp.orm.FilteredLogs(ctx, queryFilter, limitAndSort, queryName) +} diff --git a/pkg/solana/logpoller/mock_orm.go b/pkg/solana/logpoller/mock_orm.go index 1508ba4aa..604a0d5b6 100644 --- a/pkg/solana/logpoller/mock_orm.go +++ b/pkg/solana/logpoller/mock_orm.go @@ -5,6 +5,7 @@ package logpoller import ( context "context" + query "github.com/smartcontractkit/chainlink-common/pkg/types/query" mock "github.com/stretchr/testify/mock" ) @@ -113,6 +114,67 @@ func (_c *mockORM_DeleteFilters_Call) RunAndReturn(run func(context.Context, map return _c } +// FilteredLogs provides a mock function with given fields: ctx, queryFilter, limitAndSort, queryName +func (_m *mockORM) FilteredLogs(ctx context.Context, queryFilter []query.Expression, limitAndSort query.LimitAndSort, queryName string) ([]Log, error) { + ret := _m.Called(ctx, queryFilter, limitAndSort, queryName) + + if len(ret) == 0 { + panic("no return value specified for FilteredLogs") + } + + var r0 []Log + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, []query.Expression, query.LimitAndSort, string) ([]Log, error)); ok { + return rf(ctx, queryFilter, limitAndSort, queryName) + } + if rf, ok := ret.Get(0).(func(context.Context, []query.Expression, query.LimitAndSort, string) []Log); ok { + r0 = rf(ctx, queryFilter, limitAndSort, queryName) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]Log) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, []query.Expression, query.LimitAndSort, string) error); ok { + r1 = rf(ctx, queryFilter, limitAndSort, queryName) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// mockORM_FilteredLogs_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'FilteredLogs' +type mockORM_FilteredLogs_Call struct { + *mock.Call +} + +// FilteredLogs is a helper method to define mock.On call +// - ctx context.Context +// - queryFilter []query.Expression +// - limitAndSort query.LimitAndSort +// - queryName string +func (_e *mockORM_Expecter) FilteredLogs(ctx interface{}, queryFilter interface{}, limitAndSort interface{}, queryName interface{}) *mockORM_FilteredLogs_Call { + return &mockORM_FilteredLogs_Call{Call: _e.mock.On("FilteredLogs", ctx, queryFilter, limitAndSort, queryName)} +} + +func (_c *mockORM_FilteredLogs_Call) Run(run func(ctx context.Context, queryFilter []query.Expression, limitAndSort query.LimitAndSort, queryName string)) *mockORM_FilteredLogs_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].([]query.Expression), args[2].(query.LimitAndSort), args[3].(string)) + }) + return _c +} + +func (_c *mockORM_FilteredLogs_Call) Return(_a0 []Log, _a1 error) *mockORM_FilteredLogs_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *mockORM_FilteredLogs_Call) RunAndReturn(run func(context.Context, []query.Expression, query.LimitAndSort, string) ([]Log, error)) *mockORM_FilteredLogs_Call { + _c.Call.Return(run) + return _c +} + // InsertFilter provides a mock function with given fields: ctx, filter func (_m *mockORM) InsertFilter(ctx context.Context, filter Filter) (int64, error) { ret := _m.Called(ctx, filter) diff --git a/pkg/solana/logpoller/parser.go b/pkg/solana/logpoller/parser.go index fcb3a8da9..bc82704cf 100644 --- a/pkg/solana/logpoller/parser.go +++ b/pkg/solana/logpoller/parser.go @@ -22,6 +22,7 @@ const ( addressFieldName = "address" eventSigFieldName = "event_sig" defaultSort = "block_number ASC, log_index ASC" + subKeysFieldName = "subkey_values" ) var ( @@ -54,6 +55,80 @@ var _ primitives.Visitor = (*pgDSLParser)(nil) func (v *pgDSLParser) Comparator(_ primitives.Comparator) {} +type IndexedValueComparator struct { + Value IndexedValue + Operator primitives.ComparisonOperator +} + +type eventBySubkeyFilter struct { + SubkeyIndex uint64 + ValueComparers []IndexedValueComparator +} + +func (f *eventBySubkeyFilter) Accept(visitor primitives.Visitor) { + switch v := visitor.(type) { + case *pgDSLParser: + v.VisitEventSubkeysByValueFilter(f) + } +} + +func NewEventBySubkeyFilter(subkeyIndex uint64, valueComparers []primitives.ValueComparator) (query.Expression, error) { + var indexedValueComparators []IndexedValueComparator + for _, cmp := range valueComparers { + iVal, err := NewIndexedValue(cmp.Value) + if err != nil { + return query.Expression{}, err + } + iValCmp := IndexedValueComparator{ + Value: iVal, + Operator: cmp.Operator, + } + indexedValueComparators = append(indexedValueComparators, iValCmp) + } + return query.Expression{ + Primitive: &eventBySubkeyFilter{ + SubkeyIndex: subkeyIndex, + ValueComparers: indexedValueComparators, + }, + }, nil +} + +func (v *pgDSLParser) VisitEventSubkeysByValueFilter(p *eventBySubkeyFilter) { + if len(p.ValueComparers) > 0 { + if p.SubkeyIndex > 3 { // For now, maximum # of fields that can be indexed is 4--we can increase this if needed by adding more db indexes + v.err = fmt.Errorf("invalid subkey index: %d", p.SubkeyIndex) + return + } + + // Add 1 since postgresql arrays are 1-indexed. + subkeyIdx := v.args.withIndexedField("subkey_index", p.SubkeyIndex+1) + + comps := make([]string, len(p.ValueComparers)) + for idx, comp := range p.ValueComparers { + comps[idx], v.err = makeComp(comp, v.args, "subkey_value", subkeyIdx, "subkey_values[:%s] %s :%s") + if v.err != nil { + return + } + } + + v.expression = strings.Join(comps, " AND ") + } +} + +func makeComp(comp IndexedValueComparator, args *queryArgs, field, subfield, pattern string) (string, error) { + cmp, err := cmpOpToString(comp.Operator) + if err != nil { + return "", err + } + + return fmt.Sprintf( + pattern, + subfield, + cmp, + args.withIndexedField(field, comp.Value), + ), nil +} + func (v *pgDSLParser) Block(prim primitives.Block) { cmp, err := cmpOpToString(prim.Operator) if err != nil { diff --git a/pkg/solana/logpoller/parser_test.go b/pkg/solana/logpoller/parser_test.go index 96d2d8656..d176cd747 100644 --- a/pkg/solana/logpoller/parser_test.go +++ b/pkg/solana/logpoller/parser_test.go @@ -198,6 +198,47 @@ func TestDSLParser(t *testing.T) { }) }) + t.Run("query for event topic", func(t *testing.T) { + t.Parallel() + + subkeyFilter, err := NewEventBySubkeyFilter(2, []primitives.ValueComparator{ + {Value: 4, Operator: primitives.Gt}, + {Value: 7, Operator: primitives.Lt}, + }) + require.NoError(t, err) + + parser := &pgDSLParser{} + expressions := []query.Expression{subkeyFilter} + limiter := query.LimitAndSort{} + + result, args, err := parser.buildQuery(chainID, expressions, limiter + require.NoError(t, err) + expectedQuery := logsQuery( + " WHERE chain_id = :chain_id " + + "AND subkey_values[:subkey_index_0] > :subkey_value_0 AND subkey_values[:subkey_index_0] < :subkey_value_1 ORDER BY " + defaultSort) + + var iValLower, iValUpper IndexedValue + iValLower, err = NewIndexedValue(4) + require.NoError(t, err) + iValUpper, err = NewIndexedValue(7) + require.NoError(t, err) + + expectedArgs := map[string]any{ + "chain_id": chainID, + "subkey_index_0": uint64(3), + "subkey_value_0": iValLower, + "subkey_value_1": iValUpper, + } + + require.NoError(t, err) + assert.Equal(t, expectedQuery, result) + + var m map[string]any + m, err = args.toArgs() + require.NoError(t, err) + assert.Equal(t, expectedArgs, m) + }) + // nested query -> a & (b || c) t.Run("nested query", func(t *testing.T) { t.Parallel()