From 743e40f11dfd9c8805e0bfd0c419f0e59d67b64e Mon Sep 17 00:00:00 2001 From: Awbrey Hughlett Date: Thu, 16 Jan 2025 10:55:09 -0600 Subject: [PATCH] test cleanup --- pkg/solana/chainreader/bindings_test.go | 8 +++++++ pkg/solana/chainreader/chain_reader_test.go | 23 +++++++++--------- pkg/solana/logpoller/parser.go | 26 ++++++++++----------- pkg/solana/logpoller/parser_test.go | 23 ++++++++++++------ 4 files changed, 48 insertions(+), 32 deletions(-) diff --git a/pkg/solana/chainreader/bindings_test.go b/pkg/solana/chainreader/bindings_test.go index e8dbea89a..28d8b4c7b 100644 --- a/pkg/solana/chainreader/bindings_test.go +++ b/pkg/solana/chainreader/bindings_test.go @@ -9,7 +9,9 @@ import ( "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" + "github.com/smartcontractkit/chainlink-common/pkg/codec" "github.com/smartcontractkit/chainlink-common/pkg/types" + "github.com/smartcontractkit/chainlink-common/pkg/types/query" ) func TestBindings_CreateType(t *testing.T) { @@ -48,6 +50,8 @@ type mockBinding struct { func (_m *mockBinding) SetCodec(_ types.RemoteCodec) {} +func (_m *mockBinding) SetModifier(_ codec.Modifier) {} + func (_m *mockBinding) SetAddress(_ solana.PublicKey) {} func (_m *mockBinding) GetAddress() solana.PublicKey { @@ -63,3 +67,7 @@ func (_m *mockBinding) CreateType(b bool) (any, error) { func (_m *mockBinding) Decode(_ context.Context, _ []byte, _ any) error { return nil } + +func (_m *mockBinding) QueryKey(context.Context, query.KeyFilter, query.LimitAndSort, any) ([]types.Sequence, error) { + return nil, nil +} diff --git a/pkg/solana/chainreader/chain_reader_test.go b/pkg/solana/chainreader/chain_reader_test.go index de37567b6..fdf77ddd2 100644 --- a/pkg/solana/chainreader/chain_reader_test.go +++ b/pkg/solana/chainreader/chain_reader_test.go @@ -63,7 +63,7 @@ func TestSolanaChainReaderService_ServiceCtx(t *testing.T) { t.Parallel() ctx := tests.Context(t) - svc, err := chainreader.NewChainReaderService(logger.Test(t), new(mockedRPCClient), config.ContractReader{}) + svc, err := chainreader.NewContractReaderService(logger.Test(t), new(mockedRPCClient), config.ContractReader{}, nil) require.NoError(t, err) require.NotNil(t, svc) @@ -99,7 +99,7 @@ func TestSolanaChainReaderService_GetLatestValue(t *testing.T) { require.NoError(t, err) client := new(mockedRPCClient) - svc, err := chainreader.NewChainReaderService(logger.Test(t), client, conf) + svc, err := chainreader.NewContractReaderService(logger.Test(t), client, conf, nil) require.NoError(t, err) require.NotNil(t, svc) @@ -136,7 +136,7 @@ func TestSolanaChainReaderService_GetLatestValue(t *testing.T) { client := new(mockedRPCClient) expectedErr := fmt.Errorf("expected error") - svc, err := chainreader.NewChainReaderService(logger.Test(t), client, conf) + svc, err := chainreader.NewContractReaderService(logger.Test(t), client, conf, nil) require.NoError(t, err) require.NotNil(t, svc) @@ -170,7 +170,7 @@ func TestSolanaChainReaderService_GetLatestValue(t *testing.T) { _, conf := newTestConfAndCodec(t) client := new(mockedRPCClient) - svc, err := chainreader.NewChainReaderService(logger.Test(t), client, conf) + svc, err := chainreader.NewContractReaderService(logger.Test(t), client, conf, nil) require.NoError(t, err) require.NotNil(t, svc) @@ -191,7 +191,7 @@ func TestSolanaChainReaderService_GetLatestValue(t *testing.T) { _, conf := newTestConfAndCodec(t) client := new(mockedRPCClient) - svc, err := chainreader.NewChainReaderService(logger.Test(t), client, conf) + svc, err := chainreader.NewContractReaderService(logger.Test(t), client, conf, nil) require.NoError(t, err) require.NotNil(t, svc) @@ -212,7 +212,7 @@ func TestSolanaChainReaderService_GetLatestValue(t *testing.T) { _, conf := newTestConfAndCodec(t) client := new(mockedRPCClient) - svc, err := chainreader.NewChainReaderService(logger.Test(t), client, conf) + svc, err := chainreader.NewContractReaderService(logger.Test(t), client, conf, nil) require.NoError(t, err) require.NotNil(t, svc) @@ -382,9 +382,10 @@ func (_m *mockedRPCClient) SetForAddress(pk ag_solana.PublicKey, bts []byte, err type chainReaderInterfaceTester struct { TestSelectionSupport - conf config.ContractReader - address []string - reader *wrappedTestChainReader + conf config.ContractReader + address []string + reader *wrappedTestChainReader + eventSource chainreader.EventSourcer } func (r *chainReaderInterfaceTester) GetAccountBytes(i int) []byte { @@ -465,7 +466,7 @@ func (r *chainReaderInterfaceTester) Setup(t *testing.T) { func (r *chainReaderInterfaceTester) GetContractReader(t *testing.T) types.ContractReader { client := new(mockedRPCClient) - svc, err := chainreader.NewChainReaderService(logger.Test(t), client, r.conf) + svc, err := chainreader.NewContractReaderService(logger.Test(t), client, r.conf, r.eventSource) if err != nil { t.Logf("chain reader service was not able to start: %s", err.Error()) t.FailNow() @@ -491,7 +492,7 @@ type wrappedTestChainReader struct { types.UnimplementedContractReader test *testing.T - service *chainreader.SolanaChainReaderService + service *chainreader.ContractReaderService client *mockedRPCClient tester ChainComponentsInterfaceTester[*testing.T] testStructQueue []*TestStruct diff --git a/pkg/solana/logpoller/parser.go b/pkg/solana/logpoller/parser.go index 9c62743e5..08e7e56c8 100644 --- a/pkg/solana/logpoller/parser.go +++ b/pkg/solana/logpoller/parser.go @@ -15,13 +15,13 @@ import ( ) const ( - blockFieldName = "block_number" - chainIDFieldName = "chain_id" - timestampFieldName = "block_timestamp" - txHashFieldName = "tx_hash" - addressFieldName = "address" - eventSigFieldName = "event_sig" - defaultSort = "block_number ASC, log_index ASC" + blockFieldName = "block_number" + chainIDFieldName = "chain_id" + timestampFieldName = "block_timestamp" + txHashFieldName = "tx_hash" + addressFieldName = "address" + eventSubkeyFieldName = "event_subkey" + defaultSort = "block_number ASC, log_index ASC" ) var ( @@ -134,13 +134,11 @@ func (v *pgDSLParser) VisitEventSubkeyFilter(p *eventSubkeyFilter) { // TODO: the value type will be the off-chain field type that a raw IDL codec would decode into // this value will need to be wrapped in a special type that will encode the value properly for // direct comparison. - /* - v.expression = fmt.Sprintf( - "%s = :%s", - eventSigFieldName, - v.args.withIndexedField(eventSigFieldName, p.eventSig), - ) - */ + v.expression = fmt.Sprintf( + "%s = :%s", + eventSubkeyFieldName, + v.args.withIndexedField(eventSubkeyFieldName, p.Subkey), + ) } func (v *pgDSLParser) buildQuery( diff --git a/pkg/solana/logpoller/parser_test.go b/pkg/solana/logpoller/parser_test.go index 96d2d8656..3c531e6fe 100644 --- a/pkg/solana/logpoller/parser_test.go +++ b/pkg/solana/logpoller/parser_test.go @@ -53,7 +53,10 @@ func TestDSLParser(t *testing.T) { parser := &pgDSLParser{} expressions := []query.Expression{ NewAddressFilter(pk), - NewEventSigFilter([]byte("test")), + NewEventSubkeyFilter([]string{"test"}, []primitives.ValueComparator{ + {Value: 42, Operator: primitives.Gte}, + {Value: "test_value", Operator: primitives.Eq}, + }), query.Confidence(primitives.Unconfirmed), } limiter := query.NewLimitAndSort(query.CursorLimit(fmt.Sprintf("10-5-%s", txHash), query.CursorFollowing, 20)) @@ -61,7 +64,7 @@ func TestDSLParser(t *testing.T) { result, args, err := parser.buildQuery(chainID, expressions, limiter) expected := logsQuery( " WHERE chain_id = :chain_id " + - "AND (address = :address_0 AND event_sig = :event_sig_0) " + + "AND (address = :address_0 AND event_subkey = :event_subkey_0) " + "AND (block_number > :cursor_block_number OR (block_number = :cursor_block_number " + "AND log_index > :cursor_log_index)) " + "ORDER BY block_number ASC, log_index ASC, tx_hash ASC LIMIT 20") @@ -82,14 +85,17 @@ func TestDSLParser(t *testing.T) { parser := &pgDSLParser{} expressions := []query.Expression{ NewAddressFilter(pk), - NewEventSigFilter([]byte("test")), + NewEventSubkeyFilter([]string{"test"}, []primitives.ValueComparator{ + {Value: 42, Operator: primitives.Gte}, + {Value: "test_value", Operator: primitives.Eq}, + }), } limiter := query.NewLimitAndSort(query.CountLimit(20)) result, args, err := parser.buildQuery(chainID, expressions, limiter) expected := logsQuery( " WHERE chain_id = :chain_id " + - "AND (address = :address_0 AND event_sig = :event_sig_0) " + + "AND (address = :address_0 AND event_subkey = :event_subkey_0) " + "ORDER BY " + defaultSort + " " + "LIMIT 20") @@ -237,9 +243,13 @@ func TestDSLParser(t *testing.T) { t.Run("nested query deep", func(t *testing.T) { t.Parallel() - sigFilter := NewEventSigFilter([]byte("test")) parser := &pgDSLParser{} + sigFilter := NewEventSubkeyFilter([]string{"test"}, []primitives.ValueComparator{ + {Value: 42, Operator: primitives.Gte}, + {Value: "test_value", Operator: primitives.Eq}, + }) + limiter := query.LimitAndSort{} expressions := []query.Expression{ {BoolExpression: query.BoolExpression{ Expressions: []query.Expression{ @@ -261,13 +271,12 @@ func TestDSLParser(t *testing.T) { BoolOperator: query.AND, }}, } - limiter := query.LimitAndSort{} result, args, err := parser.buildQuery(chainID, expressions, limiter) expected := logsQuery( " WHERE chain_id = :chain_id " + "AND (block_timestamp = :block_timestamp_0 " + - "AND (tx_hash = :tx_hash_0 OR event_sig = :event_sig_0)) " + + "AND (tx_hash = :tx_hash_0 OR event_subkey = :event_subkey_0)) " + "ORDER BY " + defaultSort) require.NoError(t, err)