diff --git a/pkg/solana/chainreader/account_read_binding.go b/pkg/solana/chainreader/account_read_binding.go index eacd45fad..b8854b38c 100644 --- a/pkg/solana/chainreader/account_read_binding.go +++ b/pkg/solana/chainreader/account_read_binding.go @@ -2,6 +2,7 @@ package chainreader import ( "context" + "fmt" "github.com/gagliardetto/solana-go" @@ -15,12 +16,16 @@ type accountReadBinding struct { namespace, genericName string codec types.RemoteCodec key solana.PublicKey + isPda bool // flag to signify whether or not the account read is for a PDA + prefix string // only used for PDA public key calculation } -func newAccountReadBinding(namespace, genericName string) *accountReadBinding { +func newAccountReadBinding(namespace, genericName, prefix string, isPda bool) *accountReadBinding { return &accountReadBinding{ namespace: namespace, genericName: genericName, + prefix: prefix, + isPda: isPda, } } @@ -34,8 +39,21 @@ func (b *accountReadBinding) SetAddress(key solana.PublicKey) { b.key = key } -func (b *accountReadBinding) GetAddress() solana.PublicKey { - return b.key +func (b *accountReadBinding) GetAddress(ctx context.Context, params any) (solana.PublicKey, error) { + // Return the bound key if normal account read + if !b.isPda { + return b.key, nil + } + // Calculate the public key if PDA account read + seedBytes, err := b.buildSeedsSlice(ctx, params) + if err != nil { + return solana.PublicKey{}, fmt.Errorf("failed build seeds list for PDA calculation: %w", err) + } + key, _, err := solana.FindProgramAddress(seedBytes, b.key) + if err != nil { + return solana.PublicKey{}, fmt.Errorf("failed find program address for PDA: %w", err) + } + return key, nil } func (b *accountReadBinding) CreateType(forEncoding bool) (any, error) { @@ -45,3 +63,39 @@ func (b *accountReadBinding) CreateType(forEncoding bool) (any, error) { func (b *accountReadBinding) Decode(ctx context.Context, bts []byte, outVal any) error { return b.codec.Decode(ctx, bts, outVal, codec.WrapItemType(false, b.namespace, b.genericName, codec.ChainConfigTypeAccountDef)) } + +// buildSeedsSlice encodes and builds the seedslist to calculate the PDA public key +func (b *accountReadBinding) buildSeedsSlice(ctx context.Context, params any) ([][]byte, error) { + flattenedSeeds := make([]byte, 0, solana.MaxSeeds*solana.MaxSeedLength) + // Append the static prefix string first + flattenedSeeds = append(flattenedSeeds, []byte(b.prefix)...) + // Encode the seeds provided in the params + encodedParamSeeds, err := b.codec.Encode(ctx, params, codec.WrapItemType(true, b.namespace, b.genericName, "")) + if err != nil { + return nil, fmt.Errorf("failed to encode params into bytes for PDA seeds: %w", err) + } + // Append the encoded seeds + flattenedSeeds = append(flattenedSeeds, encodedParamSeeds...) + + if len(flattenedSeeds) > solana.MaxSeeds*solana.MaxSeedLength { + return nil, fmt.Errorf("seeds exceed the maximum allowed length") + } + + // Splitting the seeds since they are expected to be provided separately to FindProgramAddress + // Arbitrarily separating the seeds at max seed length would still yield the same PDA since + // FindProgramAddress appends the seed bytes together under the hood + numSeeds := len(flattenedSeeds) / solana.MaxSeedLength + if len(flattenedSeeds)%solana.MaxSeedLength != 0 { + numSeeds++ + } + seedByteArray := make([][]byte, 0, numSeeds) + for i := 0; i < numSeeds; i++ { + startIdx := i * solana.MaxSeedLength + endIdx := startIdx + solana.MaxSeedLength + if endIdx > len(flattenedSeeds) { + endIdx = len(flattenedSeeds) + } + seedByteArray = append(seedByteArray, flattenedSeeds[startIdx:endIdx]) + } + return seedByteArray, nil +} diff --git a/pkg/solana/chainreader/batch.go b/pkg/solana/chainreader/batch.go index d5990601d..91995bb80 100644 --- a/pkg/solana/chainreader/batch.go +++ b/pkg/solana/chainreader/batch.go @@ -3,6 +3,7 @@ package chainreader import ( "context" "errors" + "fmt" "github.com/gagliardetto/solana-go" @@ -38,7 +39,11 @@ func doMethodBatchCall(ctx context.Context, client MultipleAccountGetter, bindin return nil, err } - keys[idx] = binding.GetAddress() + key, err := binding.GetAddress(ctx, call.Params) + if err != nil { + return nil, fmt.Errorf("failed to get address for %s account read: %w", call.ReadName, err) + } + keys[idx] = key } // Fetch the account data diff --git a/pkg/solana/chainreader/bindings.go b/pkg/solana/chainreader/bindings.go index 751a58fdd..1b927df85 100644 --- a/pkg/solana/chainreader/bindings.go +++ b/pkg/solana/chainreader/bindings.go @@ -11,7 +11,7 @@ import ( type readBinding interface { SetAddress(solana.PublicKey) - GetAddress() solana.PublicKey + GetAddress(context.Context, any) (solana.PublicKey, error) SetCodec(types.RemoteCodec) CreateType(bool) (any, error) Decode(context.Context, []byte, any) error diff --git a/pkg/solana/chainreader/bindings_test.go b/pkg/solana/chainreader/bindings_test.go index e8dbea89a..3dec21194 100644 --- a/pkg/solana/chainreader/bindings_test.go +++ b/pkg/solana/chainreader/bindings_test.go @@ -50,8 +50,8 @@ func (_m *mockBinding) SetCodec(_ types.RemoteCodec) {} func (_m *mockBinding) SetAddress(_ solana.PublicKey) {} -func (_m *mockBinding) GetAddress() solana.PublicKey { - return solana.PublicKey{} +func (_m *mockBinding) GetAddress(_ context.Context, _ any) (solana.PublicKey, error) { + return solana.PublicKey{}, nil } func (_m *mockBinding) CreateType(b bool) (any, error) { diff --git a/pkg/solana/chainreader/chain_reader.go b/pkg/solana/chainreader/chain_reader.go index 1edcb9b8e..41d70161b 100644 --- a/pkg/solana/chainreader/chain_reader.go +++ b/pkg/solana/chainreader/chain_reader.go @@ -284,25 +284,26 @@ func (s *SolanaChainReaderService) init(namespaces map[string]config.ChainContra } func (s *SolanaChainReaderService) addAccountRead(namespace string, genericName string, idl codec.IDL, idlType codec.IdlTypeDef, readDefinition config.ReadDefinition) error { - inputAccountIDLDef := codec.NilIdlTypeDefTy - // TODO: - // if hasPDA{ - // inputAccountIDLDef = pdaType - // } - if err := s.addCodecDef(true, namespace, genericName, codec.ChainConfigTypeAccountDef, idl, inputAccountIDLDef, readDefinition.InputModifications); err != nil { - return err - } - if err := s.addCodecDef(false, namespace, genericName, codec.ChainConfigTypeAccountDef, idl, idlType, readDefinition.OutputModifications); err != nil { return err } s.lookup.addReadNameForContract(namespace, genericName) - s.bindings.AddReadBinding(namespace, genericName, newAccountReadBinding( - namespace, - genericName, - )) + var reader readBinding + var inputAccountIDLDef interface{} + // Create PDA read binding if PDA prefix or seeds configs are populated + if len(readDefinition.PDADefiniton.Prefix) > 0 || len(readDefinition.PDADefiniton.Seeds) > 0 { + inputAccountIDLDef = readDefinition.PDADefiniton + reader = newAccountReadBinding(namespace, genericName, readDefinition.PDADefiniton.Prefix, true) + } else { + inputAccountIDLDef = codec.NilIdlTypeDefTy + reader = newAccountReadBinding(namespace, genericName, "", false) + } + if err := s.addCodecDef(true, namespace, genericName, codec.ChainConfigTypeAccountDef, idl, inputAccountIDLDef, readDefinition.InputModifications); err != nil { + return err + } + s.bindings.AddReadBinding(namespace, genericName, reader) return nil } diff --git a/pkg/solana/chainreader/chain_reader_test.go b/pkg/solana/chainreader/chain_reader_test.go index de37567b6..019ef3e09 100644 --- a/pkg/solana/chainreader/chain_reader_test.go +++ b/pkg/solana/chainreader/chain_reader_test.go @@ -2,6 +2,7 @@ package chainreader_test import ( "context" + go_binary "encoding/binary" "encoding/json" "fmt" "math/big" @@ -13,7 +14,6 @@ import ( "time" "github.com/gagliardetto/solana-go" - ag_solana "github.com/gagliardetto/solana-go" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -38,6 +38,7 @@ import ( const ( Namespace = "NameSpace" NamedMethod = "NamedMethod1" + PDAAccount = "PDAAccount1" ) func TestSolanaChainReaderService_ReaderInterface(t *testing.T) { @@ -222,7 +223,7 @@ func TestSolanaChainReaderService_GetLatestValue(t *testing.T) { require.NoError(t, svc.Close()) }) - pk := ag_solana.NewWallet().PublicKey() + pk := solana.NewWallet().PublicKey() require.NotNil(t, svc.Bind(ctx, []types.BoundContract{ { @@ -266,6 +267,204 @@ func TestSolanaChainReaderService_GetLatestValue(t *testing.T) { }, })) }) + + t.Run("PDA account read success", func(t *testing.T) { + t.Parallel() + + programID := solana.NewWallet().PublicKey() + pubKey := solana.NewWallet().PublicKey() + uint64Seed := uint64(5) + prefixString := "Prefix" + + readDef := config.ReadDefinition{ + ChainSpecificName: testutils.TestStructWithNestedStruct, + ReadType: config.Account, + OutputModifications: codeccommon.ModifiersConfig{ + &codeccommon.RenameModifierConfig{Fields: map[string]string{"Value": "V"}}, + }, + } + + testCases := []struct { + name string + pdaDefinition codec.PDATypeDef + inputModifier codeccommon.ModifiersConfig + expected solana.PublicKey + params map[string]any + }{ + { + name: "happy path", + pdaDefinition: codec.PDATypeDef{ + Prefix: prefixString, + Seeds: []codec.PDASeed{ + { + Name: "PubKey", + Type: codec.IdlTypePublicKey, + }, + { + Name: "Uint64Seed", + Type: codec.IdlTypeU64, + }, + }, + }, + expected: mustFindProgramAddress(t, programID, [][]byte{[]byte(prefixString), pubKey.Bytes(), go_binary.LittleEndian.AppendUint64([]byte{}, uint64Seed)}), + params: map[string]any{ + "PubKey": pubKey, + "Uint64Seed": uint64Seed, + }, + }, + { + name: "with modifier and random field", + pdaDefinition: codec.PDATypeDef{ + Prefix: prefixString, + Seeds: []codec.PDASeed{ + { + Name: "PubKey", + Type: codec.IdlTypePublicKey, + }, + { + Name: "Uint64Seed", + Type: codec.IdlTypeU64, + }, + }, + }, + inputModifier: codeccommon.ModifiersConfig{ + &codeccommon.RenameModifierConfig{Fields: map[string]string{"PubKey": "PublicKey"}}, + }, + expected: mustFindProgramAddress(t, programID, [][]byte{[]byte(prefixString), pubKey.Bytes(), go_binary.LittleEndian.AppendUint64([]byte{}, uint64Seed)}), + params: map[string]any{ + "PublicKey": pubKey, + "randomField": "randomValue", // unused field should be ignored by the codec + "Uint64Seed": uint64Seed, + }, + }, + { + name: "only prefix", + pdaDefinition: codec.PDATypeDef{ + Prefix: prefixString, + }, + expected: mustFindProgramAddress(t, programID, [][]byte{[]byte(prefixString)}), + params: nil, + }, + { + name: "no prefix", + pdaDefinition: codec.PDATypeDef{ + Prefix: "", + Seeds: []codec.PDASeed{ + { + Name: "PubKey", + Type: codec.IdlTypePublicKey, + }, + { + Name: "Uint64Seed", + Type: codec.IdlTypeU64, + }, + }, + }, + expected: mustFindProgramAddress(t, programID, [][]byte{pubKey.Bytes(), go_binary.LittleEndian.AppendUint64([]byte{}, uint64Seed)}), + params: map[string]any{ + "PubKey": pubKey, + "Uint64Seed": uint64Seed, + }, + }, + { + name: "public key seed provided as bytes", + pdaDefinition: codec.PDATypeDef{ + Prefix: prefixString, + Seeds: []codec.PDASeed{ + { + Name: "PubKey", + Type: codec.IdlTypePublicKey, + }, + }, + }, + expected: mustFindProgramAddress(t, programID, [][]byte{[]byte(prefixString), pubKey.Bytes()}), + params: map[string]any{ + "PubKey": pubKey.Bytes(), + }, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + testReadDef := readDef + testReadDef.PDADefiniton = testCase.pdaDefinition + testReadDef.InputModifications = testCase.inputModifier + testCodec, conf := newTestConfAndCodecWithInjectibleReadDef(t, PDAAccount, testReadDef) + encoded, err := testCodec.Encode(ctx, expected, testutils.TestStructWithNestedStruct) + require.NoError(t, err) + + client := new(mockedRPCClient) + svc, err := chainreader.NewChainReaderService(logger.Test(t), client, conf) + require.NoError(t, err) + require.NotNil(t, svc) + require.NoError(t, svc.Start(ctx)) + + t.Cleanup(func() { + require.NoError(t, svc.Close()) + }) + + binding := types.BoundContract{ + Name: Namespace, + Address: programID.String(), // Set the program ID used to calculate the PDA + } + + client.SetForAddress(testCase.expected, encoded, nil, 0) + + require.NoError(t, svc.Bind(ctx, []types.BoundContract{binding})) + + var result modifiedStructWithNestedStruct + require.NoError(t, svc.GetLatestValue(ctx, binding.ReadIdentifier(PDAAccount), primitives.Unconfirmed, testCase.params, &result)) + + assert.Equal(t, expected.InnerStruct, result.InnerStruct) + assert.Equal(t, expected.Value, result.V) + assert.Equal(t, expected.TimeVal, result.TimeVal) + assert.Equal(t, expected.DurationVal, result.DurationVal) + }) + } + }) + + t.Run("PDA account read errors if missing param", func(t *testing.T) { + prefixString := "Prefix" + readDef := config.ReadDefinition{ + ChainSpecificName: testutils.TestStructWithNestedStruct, + ReadType: config.Account, + PDADefiniton: codec.PDATypeDef{ + Prefix: prefixString, + Seeds: []codec.PDASeed{ + { + Name: "PubKey", + Type: codec.IdlTypePublicKey, + }, + }, + }, + OutputModifications: codeccommon.ModifiersConfig{ + &codeccommon.RenameModifierConfig{Fields: map[string]string{"Value": "V"}}, + }, + } + _, conf := newTestConfAndCodecWithInjectibleReadDef(t, PDAAccount, readDef) + + client := new(mockedRPCClient) + svc, err := chainreader.NewChainReaderService(logger.Test(t), client, conf) + require.NoError(t, err) + require.NotNil(t, svc) + require.NoError(t, svc.Start(ctx)) + + t.Cleanup(func() { + require.NoError(t, svc.Close()) + }) + + binding := types.BoundContract{ + Name: Namespace, + Address: solana.NewWallet().PublicKey().String(), // Set the program ID used to calculate the PDA + } + + require.NoError(t, svc.Bind(ctx, []types.BoundContract{binding})) + + var result modifiedStructWithNestedStruct + require.Error(t, svc.GetLatestValue(ctx, binding.ReadIdentifier(PDAAccount), primitives.Unconfirmed, map[string]any{ + "randomField": "randomValue", // unused field should be ignored by the codec + }, &result)) + }) } func newTestIDLAndCodec(t *testing.T) (string, codec.IDL, types.RemoteCodec) { @@ -311,6 +510,23 @@ func newTestConfAndCodec(t *testing.T) (types.RemoteCodec, config.ContractReader return testCodec, conf } +func newTestConfAndCodecWithInjectibleReadDef(t *testing.T, readDefName string, readDef config.ReadDefinition) (types.RemoteCodec, config.ContractReader) { + t.Helper() + rawIDL, _, testCodec := newTestIDLAndCodec(t) + conf := config.ContractReader{ + Namespaces: map[string]config.ChainContractReader{ + Namespace: { + IDL: mustUnmarshalIDL(t, rawIDL), + Reads: map[string]config.ReadDefinition{ + readDefName: readDef, + }, + }, + }, + } + + return testCodec, conf +} + type modifiedStructWithNestedStruct struct { V uint8 InnerStruct testutils.ObjectRef1 @@ -320,7 +536,7 @@ type modifiedStructWithNestedStruct struct { BasicVector []string TimeVal int64 DurationVal time.Duration - PublicKey ag_solana.PublicKey + PublicKey solana.PublicKey EnumVal uint8 } @@ -365,7 +581,7 @@ func (_m *mockedRPCClient) SetNext(bts []byte, err error, delay time.Duration) { }) } -func (_m *mockedRPCClient) SetForAddress(pk ag_solana.PublicKey, bts []byte, err error, delay time.Duration) { +func (_m *mockedRPCClient) SetForAddress(pk solana.PublicKey, bts []byte, err error, delay time.Duration) { _m.mu.Lock() defer _m.mu.Unlock() @@ -409,7 +625,7 @@ func (r *chainReaderInterfaceTester) Name() string { func (r *chainReaderInterfaceTester) Setup(t *testing.T) { r.address = make([]string, 7) for idx := range r.address { - r.address[idx] = ag_solana.NewWallet().PublicKey().String() + r.address[idx] = solana.NewWallet().PublicKey().String() } r.conf = config.ContractReader{ @@ -643,7 +859,7 @@ func (r *wrappedTestChainReader) GetLatestValue(ctx context.Context, readIdentif } } - r.client.SetForAddress(ag_solana.PublicKey(r.tester.GetAccountBytes(acct)), bts, nil, 0) + r.client.SetForAddress(solana.PublicKey(r.tester.GetAccountBytes(acct)), bts, nil, 0) return r.service.GetLatestValue(ctx, readIdentifier, confidenceLevel, params, returnVal) } @@ -925,3 +1141,9 @@ func mustUnmarshalIDL(t *testing.T, rawIDL string) codec.IDL { return idl } + +func mustFindProgramAddress(t *testing.T, programID solana.PublicKey, seeds [][]byte) solana.PublicKey { + key, _, err := solana.FindProgramAddress(seeds, programID) + require.NoError(t, err) + return key +} diff --git a/pkg/solana/codec/anchoridl.go b/pkg/solana/codec/anchoridl.go index 3fc296e97..0ea1322ad 100644 --- a/pkg/solana/codec/anchoridl.go +++ b/pkg/solana/codec/anchoridl.go @@ -140,6 +140,18 @@ type IdlField struct { Type IdlType `json:"type"` } +// PDA is a struct that does not correlate to an official IDL type +// It is needed to encode seeds to calculate the address for PDA account reads +type PDATypeDef struct { + Prefix string `json:"prefix,omitempty"` + Seeds []PDASeed `json:"seeds,omitempty"` +} + +type PDASeed struct { + Name string `json:"name"` + Type IdlTypeAsString `json:"type"` +} + type IdlTypeAsString string const ( @@ -255,6 +267,12 @@ type IdlType struct { asIdlTypeArray *IdlTypeArray } +func NewIdlStringType(asString IdlTypeAsString) IdlType { + return IdlType{ + asString: asString, + } +} + func (env *IdlType) IsString() bool { return env.asString != "" } diff --git a/pkg/solana/codec/codec_entry.go b/pkg/solana/codec/codec_entry.go index f22b05984..d3b459d57 100644 --- a/pkg/solana/codec/codec_entry.go +++ b/pkg/solana/codec/codec_entry.go @@ -53,6 +53,22 @@ func NewAccountEntry(offchainName string, idlTypes AccountIDLTypes, includeDiscr ), nil } +func NewPDAEntry(offchainName string, pdaTypeDef PDATypeDef, mod codec.Modifier, builder commonencodings.Builder) (Entry, error) { + // PDA seeds do not have any dependecies in the IDL so the type def slice can be left empty for refs + _, accCodec, err := asStruct(pdaSeedsToIdlField(pdaTypeDef.Seeds), createRefs(IdlTypeDefSlice{}, builder), offchainName, false, false) + if err != nil { + return nil, err + } + + return newEntry( + offchainName, + offchainName, // PDA seeds do not correlate to anything on-chain so reusing offchain name + accCodec, + false, + mod, + ), nil +} + type InstructionArgsIDLTypes struct { Instruction IdlInstruction Types IdlTypeDefSlice @@ -205,3 +221,14 @@ func eventFieldsToFields(evFields []IdlEventField) []IdlField { } return idlFields } + +func pdaSeedsToIdlField(seeds []PDASeed) []IdlField { + idlFields := make([]IdlField, 0, len(seeds)) + for _, seed := range seeds { + idlFields = append(idlFields, IdlField{ + Name: seed.Name, + Type: NewIdlStringType(seed.Type), + }) + } + return idlFields +} diff --git a/pkg/solana/codec/solana.go b/pkg/solana/codec/solana.go index 3a19a1683..08ff964a9 100644 --- a/pkg/solana/codec/solana.go +++ b/pkg/solana/codec/solana.go @@ -98,6 +98,8 @@ func CreateCodecEntry(idlDefinition interface{}, offChainName string, idl IDL, m entry, err = NewInstructionArgsEntry(offChainName, InstructionArgsIDLTypes{Instruction: v, Types: idl.Types}, mod, binary.LittleEndian()) case IdlEvent: entry, err = NewEventArgsEntry(offChainName, EventIDLTypes{Event: v, Types: idl.Types}, true, mod, binary.LittleEndian()) + case PDATypeDef: + entry, err = NewPDAEntry(offChainName, v, mod, binary.LittleEndian()) default: return nil, fmt.Errorf("unknown codec IDL definition: %T", idlDefinition) } diff --git a/pkg/solana/config/chain_reader.go b/pkg/solana/config/chain_reader.go index 57ccb9040..ab09e013a 100644 --- a/pkg/solana/config/chain_reader.go +++ b/pkg/solana/config/chain_reader.go @@ -26,6 +26,7 @@ type ReadDefinition struct { ReadType ReadType `json:"readType,omitempty"` InputModifications commoncodec.ModifiersConfig `json:"inputModifications,omitempty"` OutputModifications commoncodec.ModifiersConfig `json:"outputModifications,omitempty"` + PDADefiniton codec.PDATypeDef `json:"pdaDefinition,omitempty"` // Only used for PDA account reads } type ReadType int diff --git a/pkg/solana/config/chain_reader_test.go b/pkg/solana/config/chain_reader_test.go index 19bcedbe3..cb52c56f9 100644 --- a/pkg/solana/config/chain_reader_test.go +++ b/pkg/solana/config/chain_reader_test.go @@ -43,6 +43,12 @@ func TestChainReaderConfig(t *testing.T) { assert.Equal(t, validChainReaderConfig, result) }) + t.Run("valid unmarshal with PDA account", func(t *testing.T) { + var result config.ContractReader + require.NoError(t, json.Unmarshal([]byte(validJSONWithIDLAsString), &result)) + assert.Equal(t, validChainReaderConfig, result) + }) + t.Run("invalid unmarshal", func(t *testing.T) { t.Parallel()