Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable ChainReader to read PDA account state #1003

Merged
merged 17 commits into from
Jan 17, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pkg/solana/chainreader/account_read_binding.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ func (b *accountReadBinding) SetAddress(key solana.PublicKey) {
b.key = key
}

func (b *accountReadBinding) GetAddress() solana.PublicKey {
return b.key
func (b *accountReadBinding) GetAddress(params any) (solana.PublicKey, error) {
return b.key, nil
}

func (b *accountReadBinding) CreateType(forEncoding bool) (any, error) {
Expand Down
7 changes: 6 additions & 1 deletion pkg/solana/chainreader/batch.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package chainreader
import (
"context"
"errors"
"fmt"

"github.com/gagliardetto/solana-go"

Expand Down Expand Up @@ -38,7 +39,11 @@ func doMethodBatchCall(ctx context.Context, client MultipleAccountGetter, bindin
return nil, err
}

keys[idx] = binding.GetAddress()
key, err := binding.GetAddress(call.Params)
if err != nil {
return nil, fmt.Errorf("failed to get address for binding %v: %w", binding, err)
}
keys[idx] = key
}

// Fetch the account data
Expand Down
2 changes: 1 addition & 1 deletion pkg/solana/chainreader/bindings.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (

type readBinding interface {
SetAddress(solana.PublicKey)
GetAddress() solana.PublicKey
GetAddress(params any) (solana.PublicKey, error)
SetCodec(types.RemoteCodec)
CreateType(bool) (any, error)
Decode(context.Context, []byte, any) error
Expand Down
4 changes: 2 additions & 2 deletions pkg/solana/chainreader/bindings_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ func (_m *mockBinding) SetCodec(_ types.RemoteCodec) {}

func (_m *mockBinding) SetAddress(_ solana.PublicKey) {}

func (_m *mockBinding) GetAddress() solana.PublicKey {
return solana.PublicKey{}
func (_m *mockBinding) GetAddress(params any) (solana.PublicKey, error) {
return solana.PublicKey{}, nil
}

func (_m *mockBinding) CreateType(b bool) (any, error) {
Expand Down
20 changes: 12 additions & 8 deletions pkg/solana/chainreader/chain_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -285,10 +285,6 @@ func (s *SolanaChainReaderService) init(namespaces map[string]config.ChainContra

func (s *SolanaChainReaderService) addAccountRead(namespace string, genericName string, idl codec.IDL, idlType codec.IdlTypeDef, readDefinition config.ReadDefinition) error {
inputAccountIDLDef := codec.NilIdlTypeDefTy
// TODO:
// if hasPDA{
// inputAccountIDLDef = pdaType
// }
if err := s.addCodecDef(true, namespace, genericName, codec.ChainConfigTypeAccountDef, idl, inputAccountIDLDef, readDefinition.InputModifications); err != nil {
return err
}
Expand All @@ -299,10 +295,18 @@ func (s *SolanaChainReaderService) addAccountRead(namespace string, genericName

s.lookup.addReadNameForContract(namespace, genericName)

s.bindings.AddReadBinding(namespace, genericName, newAccountReadBinding(
namespace,
genericName,
))
var reader readBinding
// Create PDA read binding if Seeds config is non-nil
// Note: Empty seeds list is a valid configuration for a PDA so a length check is intentionally skipped
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would an empty seeds list just be a normal account read?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think FindProgramAddress still outputs a different programID even with empty seeds since it's still applying the bump seed. I think every PDA in the on-chain code uses some sort of seed though. So guess it wouldn't hurt to also validate empty seeds for our current use case but I wanted to keep this open in case something changes in the future.

if readDefinition.Seeds != nil {
if len(readDefinition.Seeds) > solana.MaxSeeds {
return fmt.Errorf("read definition contains more seeds (%d) than the max allowed (%d) for PDAs", len(readDefinition.Seeds), solana.MaxSeeds)
}
reader = newPdaReadBinding(namespace, genericName, readDefinition.Seeds)
} else {
reader = newAccountReadBinding(namespace, genericName)
}
s.bindings.AddReadBinding(namespace, genericName, reader)

return nil
}
Expand Down
89 changes: 81 additions & 8 deletions pkg/solana/chainreader/chain_reader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package chainreader_test

import (
"context"
go_binary "encoding/binary"
"encoding/json"
"fmt"
"math/big"
Expand All @@ -13,7 +14,6 @@ import (
"time"

"github.com/gagliardetto/solana-go"
ag_solana "github.com/gagliardetto/solana-go"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

Expand All @@ -36,8 +36,15 @@ import (
)

const (
Namespace = "NameSpace"
NamedMethod = "NamedMethod1"
Namespace = "NameSpace"
NamedMethod = "NamedMethod1"
PDAAccount = "PDAAccount1"
PDAStringSeed = "Seed"
PDANumSeed = uint64(5)
)

var (
PDAPublicKeySeed = solana.NewWallet().PublicKey()
)

func TestSolanaChainReaderService_ReaderInterface(t *testing.T) {
Expand Down Expand Up @@ -129,6 +136,54 @@ func TestSolanaChainReaderService_GetLatestValue(t *testing.T) {
assert.Equal(t, expected.DurationVal, result.DurationVal)
})

t.Run("PDA account read successful", func(t *testing.T) {
t.Parallel()

testCodec, conf := newTestConfAndCodec(t)
encoded, err := testCodec.Encode(ctx, expected, testutils.TestStructWithNestedStruct)

require.NoError(t, err)

client := new(mockedRPCClient)
svc, err := chainreader.NewChainReaderService(logger.Test(t), client, conf)

require.NoError(t, err)
require.NotNil(t, svc)
require.NoError(t, svc.Start(ctx))

t.Cleanup(func() {
require.NoError(t, svc.Close())
})

programID := solana.NewWallet().PublicKey()

var result modifiedStructWithNestedStruct

binding := types.BoundContract{
Name: Namespace,
Address: programID.String(), // Set the program ID used to calculate the PDA
}

pdaAccount, _, err := solana.FindProgramAddress([][]byte{
[]byte(PDAStringSeed),
PDAPublicKeySeed.Bytes(),
go_binary.LittleEndian.AppendUint64([]byte{}, PDANumSeed),
}, programID)
require.NoError(t, err)

client.SetForAddress(pdaAccount, encoded, nil, 0)

param := struct{ Parameter uint64 }{Parameter: PDANumSeed}

require.NoError(t, svc.Bind(ctx, []types.BoundContract{binding}))
require.NoError(t, svc.GetLatestValue(ctx, binding.ReadIdentifier(PDAAccount), primitives.Unconfirmed, param, &result))

assert.Equal(t, expected.InnerStruct, result.InnerStruct)
assert.Equal(t, expected.Value, result.V)
assert.Equal(t, expected.TimeVal, result.TimeVal)
assert.Equal(t, expected.DurationVal, result.DurationVal)
})

t.Run("Error Returned From Account Reader", func(t *testing.T) {
t.Parallel()

Expand Down Expand Up @@ -222,7 +277,7 @@ func TestSolanaChainReaderService_GetLatestValue(t *testing.T) {
require.NoError(t, svc.Close())
})

pk := ag_solana.NewWallet().PublicKey()
pk := solana.NewWallet().PublicKey()

require.NotNil(t, svc.Bind(ctx, []types.BoundContract{
{
Expand Down Expand Up @@ -303,6 +358,24 @@ func newTestConfAndCodec(t *testing.T) (types.RemoteCodec, config.ContractReader
&codeccommon.RenameModifierConfig{Fields: map[string]string{"Value": "V"}},
},
},
PDAAccount: {
ChainSpecificName: testutils.TestStructWithNestedStruct,
ReadType: config.Account,
Seeds: []config.Seed{
{
Value: PDAStringSeed,
},
{
Value: PDAPublicKeySeed,
},
{
Location: "Parameter",
},
},
OutputModifications: codeccommon.ModifiersConfig{
&codeccommon.RenameModifierConfig{Fields: map[string]string{"Value": "V"}},
},
},
},
},
},
Expand All @@ -320,7 +393,7 @@ type modifiedStructWithNestedStruct struct {
BasicVector []string
TimeVal int64
DurationVal time.Duration
PublicKey ag_solana.PublicKey
PublicKey solana.PublicKey
EnumVal uint8
}

Expand Down Expand Up @@ -365,7 +438,7 @@ func (_m *mockedRPCClient) SetNext(bts []byte, err error, delay time.Duration) {
})
}

func (_m *mockedRPCClient) SetForAddress(pk ag_solana.PublicKey, bts []byte, err error, delay time.Duration) {
func (_m *mockedRPCClient) SetForAddress(pk solana.PublicKey, bts []byte, err error, delay time.Duration) {
_m.mu.Lock()
defer _m.mu.Unlock()

Expand Down Expand Up @@ -409,7 +482,7 @@ func (r *chainReaderInterfaceTester) Name() string {
func (r *chainReaderInterfaceTester) Setup(t *testing.T) {
r.address = make([]string, 7)
for idx := range r.address {
r.address[idx] = ag_solana.NewWallet().PublicKey().String()
r.address[idx] = solana.NewWallet().PublicKey().String()
}

r.conf = config.ContractReader{
Expand Down Expand Up @@ -643,7 +716,7 @@ func (r *wrappedTestChainReader) GetLatestValue(ctx context.Context, readIdentif
}
}

r.client.SetForAddress(ag_solana.PublicKey(r.tester.GetAccountBytes(acct)), bts, nil, 0)
r.client.SetForAddress(solana.PublicKey(r.tester.GetAccountBytes(acct)), bts, nil, 0)

return r.service.GetLatestValue(ctx, readIdentifier, confidenceLevel, params, returnVal)
}
Expand Down
97 changes: 97 additions & 0 deletions pkg/solana/chainreader/pda_read_binding.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
package chainreader

import (
"context"
"fmt"

"github.com/gagliardetto/solana-go"

"github.com/smartcontractkit/chainlink-common/pkg/types"

"github.com/smartcontractkit/chainlink-solana/pkg/solana/codec"
"github.com/smartcontractkit/chainlink-solana/pkg/solana/config"
"github.com/smartcontractkit/chainlink-solana/pkg/solana/utils"
)

// pdaReadBinding provides calculating PDA addresses with the provided seeds and reading decoded PDA Account data using a defined codec
type pdaReadBinding struct {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you think we could merge this into account_read_binding.go? They essentially do the same thing, since we’re reading an account in both cases. The only difference is that we need to calculate the PDA address if parameters are provided.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm good point. I think I could. I'll give this a shot

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I merged the two in the latest commit

namespace string
genericName string
codec types.RemoteCodec
programID solana.PublicKey
seeds []config.Seed
}

func newPdaReadBinding(namespace, genericName string, seeds []config.Seed) *pdaReadBinding {
return &pdaReadBinding{
namespace: namespace,
genericName: genericName,
seeds: seeds,
}
}

var _ readBinding = &pdaReadBinding{}

func (b *pdaReadBinding) SetCodec(codec types.RemoteCodec) {
b.codec = codec
}

func (b *pdaReadBinding) SetAddress(programID solana.PublicKey) {
b.programID = programID
}

func (b *pdaReadBinding) GetAddress(params any) (solana.PublicKey, error) {
seedBytes, err := b.buildSeedsSlice(params)
if err != nil {
return solana.PublicKey{}, fmt.Errorf("failed build seeds list for PDA generation: %w", err)
}
key, _, err := solana.FindProgramAddress(seedBytes, b.programID)
if err != nil {
return solana.PublicKey{}, fmt.Errorf("failed find program address for PDA: %w", err)
}
return key, nil
}

func (b *pdaReadBinding) CreateType(forEncoding bool) (any, error) {
return b.codec.CreateType(codec.WrapItemType(forEncoding, b.namespace, b.genericName, codec.ChainConfigTypeAccountDef), forEncoding)
}

func (b *pdaReadBinding) Decode(ctx context.Context, bts []byte, outVal any) error {
return b.codec.Decode(ctx, bts, outVal, codec.WrapItemType(false, b.namespace, b.genericName, codec.ChainConfigTypeAccountDef))
}

func (b *pdaReadBinding) buildSeedsSlice(params any) ([][]byte, error) {
if b.seeds == nil {
return [][]byte{}, nil
}

seedByteArray := make([][]byte, 0, len(b.seeds))
for _, seed := range b.seeds {
silaslenihan marked this conversation as resolved.
Show resolved Hide resolved
if seed.Value != nil && len(seed.Location) > 0 {
return nil, fmt.Errorf("seed cannot have both Value (%v) and Location (%s) defined", seed.Value, seed.Location)
}
if seed.Value != nil {
byteArray := utils.ConvertAnyToPDASeed(seed.Value)
if byteArray == nil {
return nil, fmt.Errorf("failed to convert seed %v to byte array", seed.Value)
}
if len(byteArray) > solana.MaxSeedLength {
return nil, fmt.Errorf("seed length %d exceeds the max allowed length %d", len(byteArray), solana.MaxSeedLength)
}
seedByteArray = append(seedByteArray, utils.ConvertAnyToPDASeed(seed.Value))
continue
}
if len(seed.Location) > 0 {
byteArrays, err := utils.GetValuesAtLocation(params, seed.Location)
if err != nil {
return nil, fmt.Errorf("failed to find seed at location %s in params: %w", seed.Location, err)
}
if len(byteArrays) != 1 {
return nil, fmt.Errorf("expected 1 seed. found %d seeds at location %s", len(byteArrays), seed.Location)
}
seedByteArray = append(seedByteArray, byteArrays[0])
continue
}
}
return seedByteArray, nil
}
Loading
Loading