Skip to content

Commit

Permalink
Use masks to differentiate between versions of a struct
Browse files Browse the repository at this point in the history
  • Loading branch information
firelizzard18 committed Jul 12, 2024
1 parent c92861d commit 27a5030
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 116 deletions.
194 changes: 92 additions & 102 deletions pkg/types/encoding/eip712.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,13 +176,16 @@ func (td *TypeDefinition) Resolve(v any, typeName string) (eipResolvedValue, err
}

var fields []*resolvedFieldValue
for _, field := range *td.Fields {
value, _ := data[field.Name]
v, err := field.resolve(value)
for i, field := range *td.Fields {
value, ok := data[field.Name]
if !ok {
continue
}
v, err := field.resolve(value, typeName)
if err != nil {
return nil, err
}
fields = append(fields, v)
fields = append(fields, &resolvedFieldValue{*field, i, v})
}

return &eipResolvedStruct{
Expand All @@ -193,7 +196,7 @@ func (td *TypeDefinition) Resolve(v any, typeName string) (eipResolvedValue, err

type eipResolvedValue interface {
Hash() ([]byte, error)
Types(map[string][]*TypeField)
Types(ret map[string][]*TypeField) string
MarshalJSON() ([]byte, error)
header(map[string]string)
}
Expand All @@ -205,40 +208,10 @@ type eipResolvedStruct struct {

type resolvedFieldValue struct {
TypeField
skip bool
index int
value eipResolvedValue
}

type eipResolvedArray []eipResolvedValue

func (e eipResolvedArray) MarshalJSON() ([]byte, error) {
return json.Marshal([]eipResolvedValue(e))
}

type eipEmptyStruct struct {
typeName string
td *TypeDefinition
}

func (e *eipEmptyStruct) MarshalJSON() ([]byte, error) { panic("invalid call") }

func (e *eipEmptyStruct) header(ret map[string]string) {
headerFor(ret, e.typeName, func(fn func(*TypeField)) {
for _, field := range *e.td.Fields {
fn(field)
}
})
}

func (e *eipEmptyStruct) Hash() ([]byte, error) {
return make([]byte, 32), nil
}

func (e *eipEmptyStruct) Types(ret map[string][]*TypeField) {
name, _ := stripSlice(e.typeName)
ret[name] = *e.td.Fields
}

type eipResolvedAtomic struct {
ethType string
value any
Expand All @@ -248,10 +221,6 @@ type eipResolvedAtomic struct {
func (e *eipResolvedStruct) MarshalJSON() ([]byte, error) {
v := map[string]json.RawMessage{}
for _, f := range e.fields {
if f.skip {
continue
}

var err error
v[f.Name], err = f.value.MarshalJSON()
if err != nil {
Expand All @@ -261,7 +230,7 @@ func (e *eipResolvedStruct) MarshalJSON() ([]byte, error) {
return json.Marshal(v)
}

const debugHash = false
const debugHash = true

func (e *eipResolvedStruct) Hash() ([]byte, error) {
//the stripping shouldn't be necessary, but do it as a precaution
Expand Down Expand Up @@ -323,49 +292,70 @@ func (e *eipResolvedStruct) Hash() ([]byte, error) {
}

func (e *eipResolvedStruct) header(ret map[string]string) {
headerFor(ret, e.typeName, func(fn func(*TypeField)) {
for _, field := range e.fields {
fn(&field.TypeField)
}
})
for _, field := range e.fields {
field.value.header(ret)
}
}

func headerFor(ret map[string]string, typeName string, fields func(func(*TypeField))) {
//the stripping shouldn't be necessary, but do it as a precaution
strippedType, _ := stripSlice(typeName)
name, _ := stripSlice(e.typeName)

//scan the fields
var mask int
var fields []string
for _, f := range e.fields {
mask |= 1 << f.index
fields = append(fields, f.value.Types(nil)+" "+f.Name)
}

//define the type structure
var header strings.Builder
header.WriteString(strippedType + "(")
first := true
fields(func(field *TypeField) {
if first {
first = false
} else {
header.WriteString(",")
}
header.WriteString(field.Type + " " + field.Name)
})
header.WriteString(")")
ret[strippedType] = header.String()
if name != "EIP712Domain" && name != "Transaction" {
name = fmt.Sprintf("%s{%x}", name, mask)
}
ret[name] = fmt.Sprintf("%s(%s)", name, strings.Join(fields, ","))

//find dependencies
for _, field := range e.fields {
field.value.header(ret)
}
}

func (e *eipResolvedStruct) Types(ret map[string][]*TypeField) {
func (e *eipResolvedStruct) Types(ret map[string][]*TypeField) string {
var mask int
var fields []*TypeField
for _, f := range e.fields {
fields = append(fields, &TypeField{Name: f.Name, Type: f.Type})
f.value.Types(ret)
mask |= 1 << f.index
fields = append(fields, &TypeField{
Name: f.Name,
Type: f.value.Types(ret),
})

}
name, _ := stripSlice(e.typeName)
ret[name] = fields
if name != "EIP712Domain" && name != "Transaction" {
name = fmt.Sprintf("%s{%x}", name, mask)
}
if ret != nil {
ret[name] = fields
}
return name
}

type eipResolvedArray struct {
typeName string
values []eipResolvedValue
}

func (e *eipResolvedArray) MarshalJSON() ([]byte, error) {
v := map[string]json.RawMessage{}
for i, elem := range e.values {
var err error
v[fmt.Sprint(i)], err = elem.MarshalJSON()
if err != nil {
return nil, err
}
}
return json.Marshal(v)
}

func (e eipResolvedArray) Hash() ([]byte, error) {
func (e *eipResolvedArray) Hash() ([]byte, error) {
var buf bytes.Buffer
for _, v := range e {
for _, v := range e.values {
hash, err := v.Hash()
if err != nil {
return nil, err
Expand All @@ -375,21 +365,34 @@ func (e eipResolvedArray) Hash() ([]byte, error) {
return keccak256(buf.Bytes()), nil
}

func (e eipResolvedArray) header(ret map[string]string) {
for _, v := range e {
func (e *eipResolvedArray) header(ret map[string]string) {
var fields []string
for i, v := range e.values {
fields = append(fields, fmt.Sprintf("%s %d", v.Types(nil), i))
v.header(ret)
}
if ret != nil {
ret[e.typeName] = fmt.Sprintf("%s(%s)", e.typeName, strings.Join(fields, ","))
}
}

func (e eipResolvedArray) Types(ret map[string][]*TypeField) {
for _, v := range e {
v.Types(ret)
func (e *eipResolvedArray) Types(ret map[string][]*TypeField) string {
var fields []*TypeField
for i, v := range e.values {
fields = append(fields, &TypeField{
Name: fmt.Sprint(i),
Type: v.Types(ret),
})
}
if ret != nil {
ret[e.typeName] = fields
}
return e.typeName
}

func (e *eipResolvedAtomic) Hash() ([]byte, error) { return e.hash() }
func (e *eipResolvedAtomic) header(map[string]string) {}
func (e *eipResolvedAtomic) Types(map[string][]*TypeField) {}
func (e *eipResolvedAtomic) Hash() ([]byte, error) { return e.hash() }
func (e *eipResolvedAtomic) header(map[string]string) {}
func (e *eipResolvedAtomic) Types(ret map[string][]*TypeField) string { return e.ethType }

func (e *eipResolvedAtomic) MarshalJSON() ([]byte, error) {
v := e.value
Expand Down Expand Up @@ -435,34 +438,31 @@ func newAtomicEncoder[T any](ethType string, hasher func(T) ([]byte, error)) EIP
return &eip712AtomicResolver[T]{ethType, hasher}
}

func (f *TypeField) resolve(v any) (*resolvedFieldValue, error) {
func (f *TypeField) resolve(v any, parentType string) (eipResolvedValue, error) {
strippedType, slices := stripSlice(f.Type)
encoder, ok := eip712EncoderMap[strippedType]
if ok {
if slices {
// If v is nil, return an empty array
if v == nil {
return &resolvedFieldValue{*f, false, eipResolvedArray{}}, nil
}
vv, ok := v.([]interface{})
if !ok {
return nil, fmt.Errorf("eip712 field %s is not of an array of interfaces", f.Name)
}
var array eipResolvedArray
array.typeName = fmt.Sprintf("%s{%s}", parentType, f.Name)
for _, vvv := range vv {
r, err := encoder.Resolve(vvv, f.Type)
if err != nil {
return nil, err
}
array = append(array, r)
array.values = append(array.values, r)
}
return &resolvedFieldValue{*f, false, array}, nil
return &array, nil
}
r, err := encoder.Resolve(v, f.Type)
if err != nil {
return nil, err
}
return &resolvedFieldValue{*f, false, r}, nil
return r, nil
}

//if we get here, we are expecting a struct
Expand All @@ -471,11 +471,6 @@ func (f *TypeField) resolve(v any) (*resolvedFieldValue, error) {
return nil, fmt.Errorf("eip712 field %s", f.Type)
}

// If v is nil and the type is a struct, skip this value
if v == nil {
return &resolvedFieldValue{*f, true, &eipEmptyStruct{strippedType, fields}}, nil
}

//from here on down we are expecting a struct
if slices {
//we expect a slice
Expand All @@ -484,27 +479,22 @@ func (f *TypeField) resolve(v any) (*resolvedFieldValue, error) {
return nil, fmt.Errorf("eip712 field %s is not of an array of interfaces", f.Name)
}
var array eipResolvedArray
array.typeName = fmt.Sprintf("%s{%s}", parentType, f.Name)
for _, vvv := range vv {
//now run the hasher for the type
// look for encoder, if we don't have one, call the types encoder
fields, ok := SchemaDictionary[strippedType]
if !ok {
return nil, fmt.Errorf("eip712 field %s", f.Type)
}
r, err := fields.Resolve(vvv, f.Type)
if err != nil {
return nil, err
}
array = append(array, r)
array.values = append(array.values, r)
}
return &resolvedFieldValue{*f, false, array}, nil
return &array, nil
}

r, err := fields.Resolve(v, f.Type)
if err != nil {
return nil, err
}
return &resolvedFieldValue{*f, false, r}, nil
return r, nil
}

func NewTypeField(n string, tp string) *TypeField {
Expand Down
20 changes: 8 additions & 12 deletions pkg/types/encoding/eip712_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,14 @@ func TestEIP712Arrays(t *testing.T) {
For("adi.acme", "book", "1").
UpdateKeyPage().
Add().Entry().Hash([32]byte{1, 2, 3}).FinishEntry().FinishOperation().
Add().Entry().Owner("foo.bar").FinishEntry().FinishOperation().
Done()),
must(build.Transaction().
For("adi.acme", "book", "1").
UpdateKeyPage().
Add().Entry().Hash([32]byte{1, 2, 3}).FinishEntry().FinishOperation().
SetThreshold(2).
// Add().Entry().Owner("foo.bar").FinishEntry().FinishOperation().
Done()),
// must(build.Transaction().
// For("adi.acme", "book", "1").
// UpdateKeyPage().
// Add().Entry().Hash([32]byte{1, 2, 3}).FinishEntry().FinishOperation().
// SetThreshold(2).
// Done()),
}

for i, txn := range cases {
Expand All @@ -55,11 +55,7 @@ func TestEIP712Arrays(t *testing.T) {
}
txn.Header.Initiator = [32]byte(sig.Metadata().Hash())

b, err := json.Marshal(txn)
require.NoError(t, err)
fmt.Printf("%s\n", b)

b, err = protocol.MarshalEip712(txn, sig)
b, err := protocol.MarshalEip712(txn, sig)
require.NoError(t, err)
buf := new(bytes.Buffer)
require.NoError(t, json.Indent(buf, b, "", " "))
Expand Down
3 changes: 1 addition & 2 deletions protocol/signature_eip712.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ func MarshalEip712(txn *Transaction, sig Signature) (ret []byte, err error) {
Message json.RawMessage `json:"message"`
}
e := eip712{}
e.PrimaryType = "Transaction"
e.Domain = encoding.Eip712Domain

e.Message, err = r.MarshalJSON()
Expand All @@ -72,7 +71,7 @@ func MarshalEip712(txn *Transaction, sig Signature) (ret []byte, err error) {
}

e.Types = map[string][]*encoding.TypeField{}
r.Types(e.Types)
e.PrimaryType = r.Types(e.Types)
encoding.EIP712DomainValue.Types(e.Types)

return json.Marshal(e)
Expand Down

0 comments on commit 27a5030

Please sign in to comment.