Skip to content

Commit

Permalink
Make events discriminator testable in codec interface tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ilija42 committed Jan 14, 2025
1 parent 9f498d8 commit 0c844c5
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 16 deletions.
4 changes: 2 additions & 2 deletions pkg/solana/codec/codec_entry.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ func NewAccountEntry(offchainName string, idlTypes AccountIDLTypes, includeDiscr

var discriminator *Discriminator
if includeDiscriminator {
discriminator = NewDiscriminator(offchainName, true)
discriminator = NewDiscriminator(idlTypes.Account.Name, true)
}

return newEntry(
Expand Down Expand Up @@ -92,7 +92,7 @@ func NewEventArgsEntry(offChainName string, idlTypes EventIDLTypes, includeDiscr

var discriminator *Discriminator
if includeDiscriminator {
discriminator = NewDiscriminator(offChainName, false)
discriminator = NewDiscriminator(idlTypes.Event.Name, false)
}

return newEntry(
Expand Down
12 changes: 7 additions & 5 deletions pkg/solana/codec/codec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,17 +76,19 @@ func (it *codecInterfaceTester) GetAccountString(i int) string {
}

func (it *codecInterfaceTester) EncodeFields(t *testing.T, request *EncodeRequest) []byte {
if request.TestOn == TestItemType || request.TestOn == testutils.TestEventItem {
return encodeFieldsOnItem(t, request)
if request.TestOn == TestItemType {
return encodeFieldsOnItem(t, request, true)
} else if request.TestOn == testutils.TestEventItem {
return encodeFieldsOnItem(t, request, false)
}

return encodeFieldsOnSliceOrArray(t, request)
}

func encodeFieldsOnItem(t *testing.T, request *EncodeRequest) ocr2types.Report {
func encodeFieldsOnItem(t *testing.T, request *EncodeRequest, isAccount bool) ocr2types.Report {
buf := new(bytes.Buffer)
// The underlying TestItemAsAccount adds a discriminator by default while being Borsh encoded.
if err := testutils.EncodeRequestToTestItemAsAccount(request.TestStructs[0]).MarshalWithEncoder(bin.NewBorshEncoder(buf)); err != nil {
// The underlying TestItem adds a discriminator by default while being Borsh encoded.
if err := testutils.EncodeRequestToTestItemAsAccount(request.TestStructs[0], isAccount).MarshalWithEncoder(bin.NewBorshEncoder(buf)); err != nil {
require.NoError(t, err)
}
return buf.Bytes()
Expand Down
33 changes: 24 additions & 9 deletions pkg/solana/codec/testutils/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,8 @@ var CodecDefs = map[string]CodecDef{
},
}

type TestItemAsAccount struct {
type TestItem struct {
IsAccount bool
Field int32
OracleID uint8
OracleIDs [32]uint8
Expand All @@ -170,14 +171,20 @@ type TestItemAsAccount struct {
NestedStaticStruct NestedStatic
}

var TestItemDiscriminator = [8]byte{148, 105, 105, 155, 26, 167, 212, 149}
var TestItemAsAccountDiscriminator = [8]byte{148, 105, 105, 155, 26, 167, 212, 149}

func (obj TestItemAsAccount) MarshalWithEncoder(encoder *agbinary.Encoder) (err error) {
// Write account discriminator:
err = encoder.WriteBytes(TestItemDiscriminator[:], false)
var TestItemAsEventDiscriminator = [8]byte{119, 183, 160, 247, 84, 104, 222, 251}

func (obj TestItem) MarshalWithEncoder(encoder *agbinary.Encoder) (err error) {
if obj.IsAccount {
err = encoder.WriteBytes(TestItemAsAccountDiscriminator[:], false)
} else {
err = encoder.WriteBytes(TestItemAsEventDiscriminator[:], false)
}
if err != nil {
return err
}

// Serialize `Field` param:
err = encoder.Encode(obj.Field)
if err != nil {
Expand Down Expand Up @@ -226,19 +233,26 @@ func (obj TestItemAsAccount) MarshalWithEncoder(encoder *agbinary.Encoder) (err
return nil
}

func (obj *TestItemAsAccount) UnmarshalWithDecoder(decoder *agbinary.Decoder) error {
func (obj *TestItem) UnmarshalWithDecoder(decoder *agbinary.Decoder) error {
// Read and check account discriminator:
{
discriminator, err := decoder.ReadTypeID()
if err != nil {
return err
}
if !discriminator.Equal(TestItemDiscriminator[:]) {
if obj.IsAccount && !discriminator.Equal(TestItemAsAccountDiscriminator[:]) {
return fmt.Errorf(
"wrong discriminator: wanted %s, got %s",
"[148 105 105 155 26 167 212 149]",
fmt.Sprint(discriminator[:]))
}

if !obj.IsAccount && !discriminator.Equal(TestItemAsEventDiscriminator[:]) {
return fmt.Errorf(
"wrong discriminator: wanted %s, got %s",
"[119, 183, 160, 247, 84, 104, 222, 251]",
fmt.Sprint(discriminator[:]))
}
}
// Deserialize `Field`:
err := decoder.Decode(&obj.Field)
Expand Down Expand Up @@ -563,8 +577,9 @@ func (obj *NestedStatic) UnmarshalWithDecoder(decoder *agbinary.Decoder) (err er
return nil
}

func EncodeRequestToTestItemAsAccount(testStruct interfacetests.TestStruct) TestItemAsAccount {
return TestItemAsAccount{
func EncodeRequestToTestItemAsAccount(testStruct interfacetests.TestStruct, isAccount bool) TestItem {
return TestItem{
IsAccount: isAccount,
Field: *testStruct.Field,
OracleID: uint8(testStruct.OracleID),
OracleIDs: getOracleIDs(testStruct),
Expand Down

0 comments on commit 0c844c5

Please sign in to comment.