diff --git a/pkg/solana/chainwriter/ccip_example_config.go b/pkg/solana/chainwriter/ccip_example_config.go index fc46794a8..79f781b70 100644 --- a/pkg/solana/chainwriter/ccip_example_config.go +++ b/pkg/solana/chainwriter/ccip_example_config.go @@ -1,9 +1,15 @@ package chainwriter import ( + "context" + "fmt" + "reflect" + "github.com/gagliardetto/solana-go" ) +const registryAddress = "4Nn9dsYBcSTzRbK9hg9kzCUdrCSkMZq1UR6Vw1Tkaf6A" + func TestConfig() { // Fake constant addresses for the purpose of this example. registryAddress := "4Nn9dsYBcSTzRbK9hg9kzCUdrCSkMZq1UR6Vw1Tkaf6A" @@ -331,3 +337,72 @@ func TestConfig() { } _ = chainWriterConfig } + +// This example doesn't contain the complete implementation of the function, since the +// types needed to transform can't be imported into this repository. However, in production, this +// function will be implemented in the CCIP plugin, which will have access to all the necessary types. +func CCIPArgsTransform(ctx context.Context, cw *SolanaChainWriterService, args any, accounts solana.AccountMetaSlice) (any, error) { + TokenPoolLookupTable := LookupTables{ + DerivedLookupTables: []DerivedLookupTable{ + { + Name: "RegistryTokenState", + Accounts: PDALookups{ + Name: "RegistryTokenState", + PublicKey: AccountConstant{ + Address: registryAddress, + IsSigner: false, + IsWritable: false, + }, + Seeds: []Seed{ + {Dynamic: AccountLookup{Location: "Message.TokenAmounts.DestTokenAddress"}}, + }, + IsSigner: false, + IsWritable: false, + InternalField: InternalField{ + Type: reflect.TypeOf(DataAccount{}), + Location: "LookupTable", + }, + }, + }, + }} + tableMap, _, err := cw.ResolveLookupTables(ctx, args, TokenPoolLookupTable) + if err != nil { + return nil, err + } + registryTables := tableMap["RegistryTokenState"] + tokenPoolAddresses := []solana.PublicKey{} + for _, table := range registryTables { + tokenPoolAddresses = append(tokenPoolAddresses, table[0].PublicKey) + } + + tokenIndexes := []uint8{} + for i, account := range accounts { + for _, address := range tokenPoolAddresses { + if account.PublicKey == address { + if i > 255 { + return nil, fmt.Errorf("index %d out of range for uint8", i) + } + tokenIndexes = append(tokenIndexes, uint8(i)) + } + } + } + + if len(tokenIndexes) != len(tokenPoolAddresses) { + return nil, fmt.Errorf("missing token pools in accounts") + } + + // Args should be of the following type: + // https://github.com/smartcontractkit/chainlink/blob/73f16fec1dcf13d3254d44b3e2df3e36303ce77f/core/capabilities/ccip/ocrimpls/contract_transmitter.go#L82-L90 + // + // struct { + // ReportContext [2][32]byte + // Report []byte + // Info ccipocr3.ExecuteReportInfo + // } + // + // Then we need to extend it to include TokenIndexes and return that altered struct. + // This is just an example - in the real plugin implementation, the types will be imported. + // The token indexes calculation above should be mostly accurate though. + + return args, nil +} diff --git a/pkg/solana/chainwriter/chain_writer.go b/pkg/solana/chainwriter/chain_writer.go index e02148d89..1cf0976b0 100644 --- a/pkg/solana/chainwriter/chain_writer.go +++ b/pkg/solana/chainwriter/chain_writer.go @@ -59,6 +59,7 @@ type MethodConfig struct { Accounts []Lookup // Location in the args where the debug ID is stored DebugIDLocation string + ArgsTransform func(ctx context.Context, cw *SolanaChainWriterService, args any, accounts solana.AccountMetaSlice) (any, error) } func NewSolanaChainWriterService(logger logger.Logger, reader client.Reader, txm txm.TxManager, ge fees.Estimator, config ChainWriterConfig) (*SolanaChainWriterService, error) { @@ -256,15 +257,6 @@ func (s *SolanaChainWriterService) SubmitTransaction(ctx context.Context, contra } } - encodedPayload, err := s.encoder.Encode(ctx, args, codec.WrapItemType(true, contractName, method, "")) - - if err != nil { - return errorWithDebugID(fmt.Errorf("error encoding transaction payload: %w", err), debugID) - } - - discriminator := GetDiscriminator(methodConfig.ChainSpecificName) - encodedPayload = append(discriminator[:], encodedPayload...) - // Fetch derived and static table maps derivedTableMap, staticTableMap, err := s.ResolveLookupTables(ctx, args, methodConfig.LookupTables) if err != nil { @@ -282,12 +274,17 @@ func (s *SolanaChainWriterService) SubmitTransaction(ctx context.Context, contra return errorWithDebugID(fmt.Errorf("error parsing fee payer address: %w", err), debugID) } - accounts = append([]*solana.AccountMeta{solana.Meta(feePayer).SIGNER().WRITE()}, accounts...) - accounts = append(accounts, solana.Meta(solana.SystemProgramID)) - // Filter the lookup table addresses based on which accounts are actually used filteredLookupTableMap := s.FilterLookupTableAddresses(accounts, derivedTableMap, staticTableMap) + // Transform args if necessary + if methodConfig.ArgsTransform != nil { + args, err = methodConfig.ArgsTransform(ctx, s, args, accounts) + if err != nil { + return errorWithDebugID(fmt.Errorf("error transforming args: %w", err), debugID) + } + } + // Fetch latest blockhash blockhash, err := s.reader.LatestBlockhash(ctx) if err != nil { @@ -300,6 +297,15 @@ func (s *SolanaChainWriterService) SubmitTransaction(ctx context.Context, contra return errorWithDebugID(fmt.Errorf("error parsing program ID: %w", err), debugID) } + encodedPayload, err := s.encoder.Encode(ctx, args, codec.WrapItemType(true, contractName, method, "")) + + if err != nil { + return errorWithDebugID(fmt.Errorf("error encoding transaction payload: %w", err), debugID) + } + + discriminator := GetDiscriminator(methodConfig.ChainSpecificName) + encodedPayload = append(discriminator[:], encodedPayload...) + tx, err := solana.NewTransaction( []solana.Instruction{ solana.NewInstruction(programID, accounts, encodedPayload), diff --git a/pkg/solana/chainwriter/chain_writer_test.go b/pkg/solana/chainwriter/chain_writer_test.go index 947674a2f..5c6a47485 100644 --- a/pkg/solana/chainwriter/chain_writer_test.go +++ b/pkg/solana/chainwriter/chain_writer_test.go @@ -2,6 +2,8 @@ package chainwriter_test import ( "bytes" + "context" + "encoding/json" "errors" "math/big" "os" @@ -454,6 +456,12 @@ func TestChainWriter_SubmitTransaction(t *testing.T) { StaticLookupTables: []solana.PublicKey{staticLookupTablePubkey}, }, Accounts: []chainwriter.Lookup{ + chainwriter.AccountConstant{ + Name: "feepayer", + Address: admin.String(), + IsSigner: false, + IsWritable: false, + }, chainwriter.AccountConstant{ Name: "Constant", Address: account1.String(), @@ -482,7 +490,14 @@ func TestChainWriter_SubmitTransaction(t *testing.T) { LookupTableName: "DerivedTable", IncludeIndexes: []int{0}, }, + chainwriter.AccountConstant{ + Name: "systemprogram", + Address: solana.SystemProgramID.String(), + IsSigner: false, + IsWritable: false, + }, }, + ArgsTransform: nil, }, }, IDL: testContractIDLJson, @@ -566,6 +581,100 @@ func TestChainWriter_SubmitTransaction(t *testing.T) { submitErr := cw.SubmitTransaction(ctx, "contract_reader_interface", "initializeLookupTable", args, txID, programID.String(), nil, nil) require.NoError(t, submitErr) }) + + t.Run("submits transaction successfully with ArgsTransform", func(t *testing.T) { + type ArgsPostTransform struct { + LookupTable solana.PublicKey + Seed1 []byte + Seed2 []byte + Seed3 string + } + cwConfigWithArgs := cwConfig + programConfig := cwConfigWithArgs.Programs["contract_reader_interface"] + rawIDL := cwConfigWithArgs.Programs["contract_reader_interface"].IDL + + var idlMap map[string]interface{} + err := json.Unmarshal([]byte(rawIDL), &idlMap) + require.NoError(t, err) + + instructions, ok := idlMap["instructions"].([]interface{}) + require.True(t, ok) + + // Add an additional field to the IDL that will be set in the ArgsTransform + // Since it's in the IDL, the codec will require it's present in the args + for _, instr := range instructions { + instrObj, ok := instr.(map[string]interface{}) + require.True(t, ok) + + if instrObj["name"] == "initializeLookupTable" { + argsArr, ok := instrObj["args"].([]interface{}) + require.True(t, ok) + + newArg := map[string]interface{}{ + "name": "seed3", + "type": "string", + } + argsArr = append(argsArr, newArg) + instrObj["args"] = argsArr + } + } + + modifiedIDL, err := json.Marshal(idlMap) + require.NoError(t, err) + + programConfig.IDL = string(modifiedIDL) + methodConfig := programConfig.Methods["initializeLookupTable"] + + methodConfig.ArgsTransform = func(ctx context.Context, cw *chainwriter.SolanaChainWriterService, args any, accounts solana.AccountMetaSlice) (any, error) { + argsPreTransform, ok := args.(Arguments) + require.True(t, ok) + + argsTransformed := ArgsPostTransform{ + LookupTable: argsPreTransform.LookupTable, + Seed1: argsPreTransform.Seed1, + Seed2: seed2, + Seed3: "seed3", + } + + return argsTransformed, nil + } + + programConfig.Methods["initializeLookupTable"] = methodConfig + cwConfigWithArgs.Programs["contract_reader_interface"] = programConfig + cw, err := chainwriter.NewSolanaChainWriterService(testutils.NewNullLogger(), rw, txm, ge, cwConfig) + require.NoError(t, err) + + recentBlockHash := solana.Hash{} + rw.On("LatestBlockhash", mock.Anything).Return(&rpc.GetLatestBlockhashResult{Value: &rpc.LatestBlockhashResult{Blockhash: recentBlockHash, LastValidBlockHeight: uint64(100)}}, nil).Once() + txID := uuid.NewString() + + // The TX being successfully sent means it was encoded properly, meaning the ArgsTransform worked + txm.On("Enqueue", mock.Anything, admin.String(), mock.MatchedBy(func(tx *solana.Transaction) bool { + // match transaction fields to ensure it was built as expected + require.Equal(t, recentBlockHash, tx.Message.RecentBlockhash) + require.Len(t, tx.Message.Instructions, 1) + require.Len(t, tx.Message.AccountKeys, 6) // fee payer + derived accounts + require.Equal(t, admin, tx.Message.AccountKeys[0]) // fee payer + require.Equal(t, account1, tx.Message.AccountKeys[1]) // account constant + require.Equal(t, account2, tx.Message.AccountKeys[2]) // account lookup + require.Equal(t, account3, tx.Message.AccountKeys[3]) // pda lookup + require.Equal(t, solana.SystemProgramID, tx.Message.AccountKeys[4]) // system program ID + require.Equal(t, programID, tx.Message.AccountKeys[5]) // instruction program ID + // instruction program ID + require.Len(t, tx.Message.AddressTableLookups, 1) // address table look contains entry + require.Equal(t, derivedLookupTablePubkey, tx.Message.AddressTableLookups[0].AccountKey) // address table + return true + }), &txID, mock.Anything).Return(nil).Once() + + args := Arguments{ + LookupTable: account2, + Seed1: seed1, + Seed2: seed2, + } + + submitErr := cw.SubmitTransaction(ctx, "contract_reader_interface", "initializeLookupTable", args, txID, programID.String(), nil, nil) + require.NoError(t, submitErr) + }) } func TestChainWriter_GetTransactionStatus(t *testing.T) {