Skip to content

Commit

Permalink
Added ArgsTransform to ChainWriter
Browse files Browse the repository at this point in the history
  • Loading branch information
silaslenihan committed Jan 16, 2025
1 parent 74b7a72 commit c4004ed
Show file tree
Hide file tree
Showing 3 changed files with 200 additions and 12 deletions.
75 changes: 75 additions & 0 deletions pkg/solana/chainwriter/ccip_example_config.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
package chainwriter

import (
"context"
"fmt"
"reflect"

"github.com/gagliardetto/solana-go"
)

const registryAddress = "4Nn9dsYBcSTzRbK9hg9kzCUdrCSkMZq1UR6Vw1Tkaf6A"

func TestConfig() {
// Fake constant addresses for the purpose of this example.
registryAddress := "4Nn9dsYBcSTzRbK9hg9kzCUdrCSkMZq1UR6Vw1Tkaf6A"
Expand Down Expand Up @@ -331,3 +337,72 @@ func TestConfig() {
}
_ = chainWriterConfig
}

// This example doesn't contain the complete implementation of the function, since the
// types needed to transform can't be imported into this repository. However, in production, this
// function will be implemented in the CCIP plugin, which will have access to all the necessary types.
func CCIPArgsTransform(ctx context.Context, cw *SolanaChainWriterService, args any, accounts solana.AccountMetaSlice) (any, error) {
TokenPoolLookupTable := LookupTables{
DerivedLookupTables: []DerivedLookupTable{
{
Name: "RegistryTokenState",
Accounts: PDALookups{
Name: "RegistryTokenState",
PublicKey: AccountConstant{
Address: registryAddress,
IsSigner: false,
IsWritable: false,
},
Seeds: []Seed{
{Dynamic: AccountLookup{Location: "Message.TokenAmounts.DestTokenAddress"}},
},
IsSigner: false,
IsWritable: false,
InternalField: InternalField{
Type: reflect.TypeOf(DataAccount{}),
Location: "LookupTable",
},
},
},
}}
tableMap, _, err := cw.ResolveLookupTables(ctx, args, TokenPoolLookupTable)
if err != nil {
return nil, err
}
registryTables := tableMap["RegistryTokenState"]
tokenPoolAddresses := []solana.PublicKey{}
for _, table := range registryTables {
tokenPoolAddresses = append(tokenPoolAddresses, table[0].PublicKey)
}

tokenIndexes := []uint8{}
for i, account := range accounts {
for _, address := range tokenPoolAddresses {
if account.PublicKey == address {
if i > 255 {
return nil, fmt.Errorf("index %d out of range for uint8", i)
}
tokenIndexes = append(tokenIndexes, uint8(i))
}
}
}

if len(tokenIndexes) != len(tokenPoolAddresses) {
return nil, fmt.Errorf("missing token pools in accounts")
}

// Args should be of the following type:
// https://github.com/smartcontractkit/chainlink/blob/73f16fec1dcf13d3254d44b3e2df3e36303ce77f/core/capabilities/ccip/ocrimpls/contract_transmitter.go#L82-L90
//
// struct {
// ReportContext [2][32]byte
// Report []byte
// Info ccipocr3.ExecuteReportInfo
// }
//
// Then we need to extend it to include TokenIndexes and return that altered struct.
// This is just an example - in the real plugin implementation, the types will be imported.
// The token indexes calculation above should be mostly accurate though.

return args, nil
}
30 changes: 18 additions & 12 deletions pkg/solana/chainwriter/chain_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ type MethodConfig struct {
Accounts []Lookup
// Location in the args where the debug ID is stored
DebugIDLocation string
ArgsTransform func(ctx context.Context, cw *SolanaChainWriterService, args any, accounts solana.AccountMetaSlice) (any, error)
}

func NewSolanaChainWriterService(logger logger.Logger, reader client.Reader, txm txm.TxManager, ge fees.Estimator, config ChainWriterConfig) (*SolanaChainWriterService, error) {
Expand Down Expand Up @@ -256,15 +257,6 @@ func (s *SolanaChainWriterService) SubmitTransaction(ctx context.Context, contra
}
}

encodedPayload, err := s.encoder.Encode(ctx, args, codec.WrapItemType(true, contractName, method, ""))

if err != nil {
return errorWithDebugID(fmt.Errorf("error encoding transaction payload: %w", err), debugID)
}

discriminator := GetDiscriminator(methodConfig.ChainSpecificName)
encodedPayload = append(discriminator[:], encodedPayload...)

// Fetch derived and static table maps
derivedTableMap, staticTableMap, err := s.ResolveLookupTables(ctx, args, methodConfig.LookupTables)
if err != nil {
Expand All @@ -282,12 +274,17 @@ func (s *SolanaChainWriterService) SubmitTransaction(ctx context.Context, contra
return errorWithDebugID(fmt.Errorf("error parsing fee payer address: %w", err), debugID)
}

accounts = append([]*solana.AccountMeta{solana.Meta(feePayer).SIGNER().WRITE()}, accounts...)
accounts = append(accounts, solana.Meta(solana.SystemProgramID))

// Filter the lookup table addresses based on which accounts are actually used
filteredLookupTableMap := s.FilterLookupTableAddresses(accounts, derivedTableMap, staticTableMap)

// Transform args if necessary
if methodConfig.ArgsTransform != nil {
args, err = methodConfig.ArgsTransform(ctx, s, args, accounts)
if err != nil {
return errorWithDebugID(fmt.Errorf("error transforming args: %w", err), debugID)
}
}

// Fetch latest blockhash
blockhash, err := s.reader.LatestBlockhash(ctx)
if err != nil {
Expand All @@ -300,6 +297,15 @@ func (s *SolanaChainWriterService) SubmitTransaction(ctx context.Context, contra
return errorWithDebugID(fmt.Errorf("error parsing program ID: %w", err), debugID)
}

encodedPayload, err := s.encoder.Encode(ctx, args, codec.WrapItemType(true, contractName, method, ""))

if err != nil {
return errorWithDebugID(fmt.Errorf("error encoding transaction payload: %w", err), debugID)
}

discriminator := GetDiscriminator(methodConfig.ChainSpecificName)
encodedPayload = append(discriminator[:], encodedPayload...)

tx, err := solana.NewTransaction(
[]solana.Instruction{
solana.NewInstruction(programID, accounts, encodedPayload),
Expand Down
107 changes: 107 additions & 0 deletions pkg/solana/chainwriter/chain_writer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package chainwriter_test

import (
"bytes"
"context"
"encoding/json"
"errors"
"math/big"
"os"
Expand Down Expand Up @@ -454,6 +456,12 @@ func TestChainWriter_SubmitTransaction(t *testing.T) {
StaticLookupTables: []solana.PublicKey{staticLookupTablePubkey},
},
Accounts: []chainwriter.Lookup{
chainwriter.AccountConstant{
Name: "feepayer",
Address: admin.String(),
IsSigner: false,
IsWritable: false,
},
chainwriter.AccountConstant{
Name: "Constant",
Address: account1.String(),
Expand Down Expand Up @@ -482,7 +490,14 @@ func TestChainWriter_SubmitTransaction(t *testing.T) {
LookupTableName: "DerivedTable",
IncludeIndexes: []int{0},
},
chainwriter.AccountConstant{
Name: "systemprogram",
Address: solana.SystemProgramID.String(),
IsSigner: false,
IsWritable: false,
},
},
ArgsTransform: nil,
},
},
IDL: testContractIDLJson,
Expand Down Expand Up @@ -566,6 +581,98 @@ func TestChainWriter_SubmitTransaction(t *testing.T) {
submitErr := cw.SubmitTransaction(ctx, "contract_reader_interface", "initializeLookupTable", args, txID, programID.String(), nil, nil)
require.NoError(t, submitErr)
})

t.Run("submits transaction successfully with ArgsTransform", func(t *testing.T) {
type ArgsPostTransform struct {
LookupTable solana.PublicKey
Seed1 []byte
Seed2 []byte
Seed3 string
}
cwConfigWithArgs := cwConfig
programConfig := cwConfigWithArgs.Programs["contract_reader_interface"]
rawIDL := cwConfigWithArgs.Programs["contract_reader_interface"].IDL

var idlMap map[string]interface{}
err := json.Unmarshal([]byte(rawIDL), &idlMap)
require.NoError(t, err)

instructions, ok := idlMap["instructions"].([]interface{})
require.True(t, ok)

for _, instr := range instructions {
instrObj, ok := instr.(map[string]interface{})
require.True(t, ok)

if instrObj["name"] == "initializeLookupTable" {
argsArr, ok := instrObj["args"].([]interface{})
require.True(t, ok)

newArg := map[string]interface{}{
"name": "seed3",
"type": "string",
}
argsArr = append(argsArr, newArg)
instrObj["args"] = argsArr
}
}

modifiedIDL, err := json.Marshal(idlMap)
require.NoError(t, err)

programConfig.IDL = string(modifiedIDL)
methodConfig := programConfig.Methods["initializeLookupTable"]

methodConfig.ArgsTransform = func(ctx context.Context, cw *chainwriter.SolanaChainWriterService, args any, accounts solana.AccountMetaSlice) (any, error) {
argsPreTransform, ok := args.(Arguments)
require.True(t, ok)

argsTransformed := ArgsPostTransform{
LookupTable: argsPreTransform.LookupTable,
Seed1: argsPreTransform.Seed1,
Seed2: seed2,
Seed3: "seed3",
}

return argsTransformed, nil
}

programConfig.Methods["initializeLookupTable"] = methodConfig
cwConfigWithArgs.Programs["contract_reader_interface"] = programConfig
cw, err := chainwriter.NewSolanaChainWriterService(testutils.NewNullLogger(), rw, txm, ge, cwConfig)
require.NoError(t, err)

recentBlockHash := solana.Hash{}
rw.On("LatestBlockhash", mock.Anything).Return(&rpc.GetLatestBlockhashResult{Value: &rpc.LatestBlockhashResult{Blockhash: recentBlockHash, LastValidBlockHeight: uint64(100)}}, nil).Once()
txID := uuid.NewString()

// The TX being successfully sent means it was encoded properly, meaning the ArgsTransform worked
txm.On("Enqueue", mock.Anything, admin.String(), mock.MatchedBy(func(tx *solana.Transaction) bool {
// match transaction fields to ensure it was built as expected
require.Equal(t, recentBlockHash, tx.Message.RecentBlockhash)
require.Len(t, tx.Message.Instructions, 1)
require.Len(t, tx.Message.AccountKeys, 6) // fee payer + derived accounts
require.Equal(t, admin, tx.Message.AccountKeys[0]) // fee payer
require.Equal(t, account1, tx.Message.AccountKeys[1]) // account constant
require.Equal(t, account2, tx.Message.AccountKeys[2]) // account lookup
require.Equal(t, account3, tx.Message.AccountKeys[3]) // pda lookup
require.Equal(t, solana.SystemProgramID, tx.Message.AccountKeys[4]) // system program ID
require.Equal(t, programID, tx.Message.AccountKeys[5]) // instruction program ID
// instruction program ID
require.Len(t, tx.Message.AddressTableLookups, 1) // address table look contains entry
require.Equal(t, derivedLookupTablePubkey, tx.Message.AddressTableLookups[0].AccountKey) // address table
return true
}), &txID, mock.Anything).Return(nil).Once()

args := Arguments{
LookupTable: account2,
Seed1: seed1,
Seed2: seed2,
}

submitErr := cw.SubmitTransaction(ctx, "contract_reader_interface", "initializeLookupTable", args, txID, programID.String(), nil, nil)
require.NoError(t, submitErr)
})
}

func TestChainWriter_GetTransactionStatus(t *testing.T) {
Expand Down

0 comments on commit c4004ed

Please sign in to comment.