Skip to content

Commit

Permalink
Enable ChainReader to read PDA account state (#1003)
Browse files Browse the repository at this point in the history
* Implemented PDA account reading in ChainReader

* Fixed linting

* Updated PDA account reads to use the codec to encode seeds in params

* Removed string seed encoder

* Reverted moving ChainWriter helper methods to utils

* Updated the PDA codec entry to use the existing builder methods

* Removed duplicated public key encoder

* Added builder method for IDL string types

* Updated PDA account read test to use input modifier

* Moved PDA seed IDL type builder out of the config

* Added new PDA account read unit tests

* Fixed linting

* Added back encoder codec def for normal account reads

* Merged the pda read binding with the existing account read binding
  • Loading branch information
amit-momin authored Jan 17, 2025
1 parent 81a855e commit 933f88f
Show file tree
Hide file tree
Showing 11 changed files with 362 additions and 26 deletions.
60 changes: 57 additions & 3 deletions pkg/solana/chainreader/account_read_binding.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package chainreader

import (
"context"
"fmt"

"github.com/gagliardetto/solana-go"

Expand All @@ -15,12 +16,16 @@ type accountReadBinding struct {
namespace, genericName string
codec types.RemoteCodec
key solana.PublicKey
isPda bool // flag to signify whether or not the account read is for a PDA
prefix string // only used for PDA public key calculation
}

func newAccountReadBinding(namespace, genericName string) *accountReadBinding {
func newAccountReadBinding(namespace, genericName, prefix string, isPda bool) *accountReadBinding {
return &accountReadBinding{
namespace: namespace,
genericName: genericName,
prefix: prefix,
isPda: isPda,
}
}

Expand All @@ -34,8 +39,21 @@ func (b *accountReadBinding) SetAddress(key solana.PublicKey) {
b.key = key
}

func (b *accountReadBinding) GetAddress() solana.PublicKey {
return b.key
func (b *accountReadBinding) GetAddress(ctx context.Context, params any) (solana.PublicKey, error) {
// Return the bound key if normal account read
if !b.isPda {
return b.key, nil
}
// Calculate the public key if PDA account read
seedBytes, err := b.buildSeedsSlice(ctx, params)
if err != nil {
return solana.PublicKey{}, fmt.Errorf("failed build seeds list for PDA calculation: %w", err)
}
key, _, err := solana.FindProgramAddress(seedBytes, b.key)
if err != nil {
return solana.PublicKey{}, fmt.Errorf("failed find program address for PDA: %w", err)
}
return key, nil
}

func (b *accountReadBinding) CreateType(forEncoding bool) (any, error) {
Expand All @@ -45,3 +63,39 @@ func (b *accountReadBinding) CreateType(forEncoding bool) (any, error) {
func (b *accountReadBinding) Decode(ctx context.Context, bts []byte, outVal any) error {
return b.codec.Decode(ctx, bts, outVal, codec.WrapItemType(false, b.namespace, b.genericName, codec.ChainConfigTypeAccountDef))
}

// buildSeedsSlice encodes and builds the seedslist to calculate the PDA public key
func (b *accountReadBinding) buildSeedsSlice(ctx context.Context, params any) ([][]byte, error) {
flattenedSeeds := make([]byte, 0, solana.MaxSeeds*solana.MaxSeedLength)
// Append the static prefix string first
flattenedSeeds = append(flattenedSeeds, []byte(b.prefix)...)
// Encode the seeds provided in the params
encodedParamSeeds, err := b.codec.Encode(ctx, params, codec.WrapItemType(true, b.namespace, b.genericName, ""))
if err != nil {
return nil, fmt.Errorf("failed to encode params into bytes for PDA seeds: %w", err)
}
// Append the encoded seeds
flattenedSeeds = append(flattenedSeeds, encodedParamSeeds...)

if len(flattenedSeeds) > solana.MaxSeeds*solana.MaxSeedLength {
return nil, fmt.Errorf("seeds exceed the maximum allowed length")
}

// Splitting the seeds since they are expected to be provided separately to FindProgramAddress
// Arbitrarily separating the seeds at max seed length would still yield the same PDA since
// FindProgramAddress appends the seed bytes together under the hood
numSeeds := len(flattenedSeeds) / solana.MaxSeedLength
if len(flattenedSeeds)%solana.MaxSeedLength != 0 {
numSeeds++
}
seedByteArray := make([][]byte, 0, numSeeds)
for i := 0; i < numSeeds; i++ {
startIdx := i * solana.MaxSeedLength
endIdx := startIdx + solana.MaxSeedLength
if endIdx > len(flattenedSeeds) {
endIdx = len(flattenedSeeds)
}
seedByteArray = append(seedByteArray, flattenedSeeds[startIdx:endIdx])
}
return seedByteArray, nil
}
7 changes: 6 additions & 1 deletion pkg/solana/chainreader/batch.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package chainreader
import (
"context"
"errors"
"fmt"

"github.com/gagliardetto/solana-go"

Expand Down Expand Up @@ -38,7 +39,11 @@ func doMethodBatchCall(ctx context.Context, client MultipleAccountGetter, bindin
return nil, err
}

keys[idx] = binding.GetAddress()
key, err := binding.GetAddress(ctx, call.Params)
if err != nil {
return nil, fmt.Errorf("failed to get address for %s account read: %w", call.ReadName, err)
}
keys[idx] = key
}

// Fetch the account data
Expand Down
2 changes: 1 addition & 1 deletion pkg/solana/chainreader/bindings.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (

type readBinding interface {
SetAddress(solana.PublicKey)
GetAddress() solana.PublicKey
GetAddress(context.Context, any) (solana.PublicKey, error)
SetCodec(types.RemoteCodec)
CreateType(bool) (any, error)
Decode(context.Context, []byte, any) error
Expand Down
4 changes: 2 additions & 2 deletions pkg/solana/chainreader/bindings_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ func (_m *mockBinding) SetCodec(_ types.RemoteCodec) {}

func (_m *mockBinding) SetAddress(_ solana.PublicKey) {}

func (_m *mockBinding) GetAddress() solana.PublicKey {
return solana.PublicKey{}
func (_m *mockBinding) GetAddress(_ context.Context, _ any) (solana.PublicKey, error) {
return solana.PublicKey{}, nil
}

func (_m *mockBinding) CreateType(b bool) (any, error) {
Expand Down
27 changes: 14 additions & 13 deletions pkg/solana/chainreader/chain_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -284,25 +284,26 @@ func (s *SolanaChainReaderService) init(namespaces map[string]config.ChainContra
}

func (s *SolanaChainReaderService) addAccountRead(namespace string, genericName string, idl codec.IDL, idlType codec.IdlTypeDef, readDefinition config.ReadDefinition) error {
inputAccountIDLDef := codec.NilIdlTypeDefTy
// TODO:
// if hasPDA{
// inputAccountIDLDef = pdaType
// }
if err := s.addCodecDef(true, namespace, genericName, codec.ChainConfigTypeAccountDef, idl, inputAccountIDLDef, readDefinition.InputModifications); err != nil {
return err
}

if err := s.addCodecDef(false, namespace, genericName, codec.ChainConfigTypeAccountDef, idl, idlType, readDefinition.OutputModifications); err != nil {
return err
}

s.lookup.addReadNameForContract(namespace, genericName)

s.bindings.AddReadBinding(namespace, genericName, newAccountReadBinding(
namespace,
genericName,
))
var reader readBinding
var inputAccountIDLDef interface{}
// Create PDA read binding if PDA prefix or seeds configs are populated
if len(readDefinition.PDADefiniton.Prefix) > 0 || len(readDefinition.PDADefiniton.Seeds) > 0 {
inputAccountIDLDef = readDefinition.PDADefiniton
reader = newAccountReadBinding(namespace, genericName, readDefinition.PDADefiniton.Prefix, true)
} else {
inputAccountIDLDef = codec.NilIdlTypeDefTy
reader = newAccountReadBinding(namespace, genericName, "", false)
}
if err := s.addCodecDef(true, namespace, genericName, codec.ChainConfigTypeAccountDef, idl, inputAccountIDLDef, readDefinition.InputModifications); err != nil {
return err
}
s.bindings.AddReadBinding(namespace, genericName, reader)

return nil
}
Expand Down
Loading

0 comments on commit 933f88f

Please sign in to comment.