diff --git a/pkg/types/encoding/eip712.go b/pkg/types/encoding/eip712.go index 886317080..65b51b4d9 100644 --- a/pkg/types/encoding/eip712.go +++ b/pkg/types/encoding/eip712.go @@ -176,13 +176,16 @@ func (td *TypeDefinition) Resolve(v any, typeName string) (eipResolvedValue, err } var fields []*resolvedFieldValue - for _, field := range *td.Fields { - value, _ := data[field.Name] - v, err := field.resolve(value) + for i, field := range *td.Fields { + value, ok := data[field.Name] + if !ok { + continue + } + v, err := field.resolve(value, typeName) if err != nil { return nil, err } - fields = append(fields, v) + fields = append(fields, &resolvedFieldValue{*field, i, v}) } return &eipResolvedStruct{ @@ -193,7 +196,7 @@ func (td *TypeDefinition) Resolve(v any, typeName string) (eipResolvedValue, err type eipResolvedValue interface { Hash() ([]byte, error) - Types(map[string][]*TypeField) + Types(ret map[string][]*TypeField) string MarshalJSON() ([]byte, error) header(map[string]string) } @@ -205,40 +208,10 @@ type eipResolvedStruct struct { type resolvedFieldValue struct { TypeField - skip bool + index int value eipResolvedValue } -type eipResolvedArray []eipResolvedValue - -func (e eipResolvedArray) MarshalJSON() ([]byte, error) { - return json.Marshal([]eipResolvedValue(e)) -} - -type eipEmptyStruct struct { - typeName string - td *TypeDefinition -} - -func (e *eipEmptyStruct) MarshalJSON() ([]byte, error) { panic("invalid call") } - -func (e *eipEmptyStruct) header(ret map[string]string) { - headerFor(ret, e.typeName, func(fn func(*TypeField)) { - for _, field := range *e.td.Fields { - fn(field) - } - }) -} - -func (e *eipEmptyStruct) Hash() ([]byte, error) { - return make([]byte, 32), nil -} - -func (e *eipEmptyStruct) Types(ret map[string][]*TypeField) { - name, _ := stripSlice(e.typeName) - ret[name] = *e.td.Fields -} - type eipResolvedAtomic struct { ethType string value any @@ -248,10 +221,6 @@ type eipResolvedAtomic struct { func (e *eipResolvedStruct) MarshalJSON() ([]byte, error) { v := map[string]json.RawMessage{} for _, f := range e.fields { - if f.skip { - continue - } - var err error v[f.Name], err = f.value.MarshalJSON() if err != nil { @@ -261,7 +230,7 @@ func (e *eipResolvedStruct) MarshalJSON() ([]byte, error) { return json.Marshal(v) } -const debugHash = false +const debugHash = true func (e *eipResolvedStruct) Hash() ([]byte, error) { //the stripping shouldn't be necessary, but do it as a precaution @@ -323,49 +292,70 @@ func (e *eipResolvedStruct) Hash() ([]byte, error) { } func (e *eipResolvedStruct) header(ret map[string]string) { - headerFor(ret, e.typeName, func(fn func(*TypeField)) { - for _, field := range e.fields { - fn(&field.TypeField) - } - }) - for _, field := range e.fields { - field.value.header(ret) - } -} - -func headerFor(ret map[string]string, typeName string, fields func(func(*TypeField))) { //the stripping shouldn't be necessary, but do it as a precaution - strippedType, _ := stripSlice(typeName) + name, _ := stripSlice(e.typeName) + + //scan the fields + var mask int + var fields []string + for _, f := range e.fields { + mask |= 1 << f.index + fields = append(fields, f.value.Types(nil)+" "+f.Name) + } //define the type structure - var header strings.Builder - header.WriteString(strippedType + "(") - first := true - fields(func(field *TypeField) { - if first { - first = false - } else { - header.WriteString(",") - } - header.WriteString(field.Type + " " + field.Name) - }) - header.WriteString(")") - ret[strippedType] = header.String() + if name != "EIP712Domain" && name != "Transaction" { + name = fmt.Sprintf("%s{%x}", name, mask) + } + ret[name] = fmt.Sprintf("%s(%s)", name, strings.Join(fields, ",")) + + //find dependencies + for _, field := range e.fields { + field.value.header(ret) + } } -func (e *eipResolvedStruct) Types(ret map[string][]*TypeField) { +func (e *eipResolvedStruct) Types(ret map[string][]*TypeField) string { + var mask int var fields []*TypeField for _, f := range e.fields { - fields = append(fields, &TypeField{Name: f.Name, Type: f.Type}) - f.value.Types(ret) + mask |= 1 << f.index + fields = append(fields, &TypeField{ + Name: f.Name, + Type: f.value.Types(ret), + }) + } name, _ := stripSlice(e.typeName) - ret[name] = fields + if name != "EIP712Domain" && name != "Transaction" { + name = fmt.Sprintf("%s{%x}", name, mask) + } + if ret != nil { + ret[name] = fields + } + return name +} + +type eipResolvedArray struct { + typeName string + values []eipResolvedValue +} + +func (e *eipResolvedArray) MarshalJSON() ([]byte, error) { + v := map[string]json.RawMessage{} + for i, elem := range e.values { + var err error + v[fmt.Sprint(i)], err = elem.MarshalJSON() + if err != nil { + return nil, err + } + } + return json.Marshal(v) } -func (e eipResolvedArray) Hash() ([]byte, error) { +func (e *eipResolvedArray) Hash() ([]byte, error) { var buf bytes.Buffer - for _, v := range e { + for _, v := range e.values { hash, err := v.Hash() if err != nil { return nil, err @@ -375,21 +365,34 @@ func (e eipResolvedArray) Hash() ([]byte, error) { return keccak256(buf.Bytes()), nil } -func (e eipResolvedArray) header(ret map[string]string) { - for _, v := range e { +func (e *eipResolvedArray) header(ret map[string]string) { + var fields []string + for i, v := range e.values { + fields = append(fields, fmt.Sprintf("%s %d", v.Types(nil), i)) v.header(ret) } + if ret != nil { + ret[e.typeName] = fmt.Sprintf("%s(%s)", e.typeName, strings.Join(fields, ",")) + } } -func (e eipResolvedArray) Types(ret map[string][]*TypeField) { - for _, v := range e { - v.Types(ret) +func (e *eipResolvedArray) Types(ret map[string][]*TypeField) string { + var fields []*TypeField + for i, v := range e.values { + fields = append(fields, &TypeField{ + Name: fmt.Sprint(i), + Type: v.Types(ret), + }) + } + if ret != nil { + ret[e.typeName] = fields } + return e.typeName } -func (e *eipResolvedAtomic) Hash() ([]byte, error) { return e.hash() } -func (e *eipResolvedAtomic) header(map[string]string) {} -func (e *eipResolvedAtomic) Types(map[string][]*TypeField) {} +func (e *eipResolvedAtomic) Hash() ([]byte, error) { return e.hash() } +func (e *eipResolvedAtomic) header(map[string]string) {} +func (e *eipResolvedAtomic) Types(ret map[string][]*TypeField) string { return e.ethType } func (e *eipResolvedAtomic) MarshalJSON() ([]byte, error) { v := e.value @@ -435,34 +438,31 @@ func newAtomicEncoder[T any](ethType string, hasher func(T) ([]byte, error)) EIP return &eip712AtomicResolver[T]{ethType, hasher} } -func (f *TypeField) resolve(v any) (*resolvedFieldValue, error) { +func (f *TypeField) resolve(v any, parentType string) (eipResolvedValue, error) { strippedType, slices := stripSlice(f.Type) encoder, ok := eip712EncoderMap[strippedType] if ok { if slices { - // If v is nil, return an empty array - if v == nil { - return &resolvedFieldValue{*f, false, eipResolvedArray{}}, nil - } vv, ok := v.([]interface{}) if !ok { return nil, fmt.Errorf("eip712 field %s is not of an array of interfaces", f.Name) } var array eipResolvedArray + array.typeName = fmt.Sprintf("%s{%s}", parentType, f.Name) for _, vvv := range vv { r, err := encoder.Resolve(vvv, f.Type) if err != nil { return nil, err } - array = append(array, r) + array.values = append(array.values, r) } - return &resolvedFieldValue{*f, false, array}, nil + return &array, nil } r, err := encoder.Resolve(v, f.Type) if err != nil { return nil, err } - return &resolvedFieldValue{*f, false, r}, nil + return r, nil } //if we get here, we are expecting a struct @@ -471,11 +471,6 @@ func (f *TypeField) resolve(v any) (*resolvedFieldValue, error) { return nil, fmt.Errorf("eip712 field %s", f.Type) } - // If v is nil and the type is a struct, skip this value - if v == nil { - return &resolvedFieldValue{*f, true, &eipEmptyStruct{strippedType, fields}}, nil - } - //from here on down we are expecting a struct if slices { //we expect a slice @@ -484,27 +479,22 @@ func (f *TypeField) resolve(v any) (*resolvedFieldValue, error) { return nil, fmt.Errorf("eip712 field %s is not of an array of interfaces", f.Name) } var array eipResolvedArray + array.typeName = fmt.Sprintf("%s{%s}", parentType, f.Name) for _, vvv := range vv { - //now run the hasher for the type - // look for encoder, if we don't have one, call the types encoder - fields, ok := SchemaDictionary[strippedType] - if !ok { - return nil, fmt.Errorf("eip712 field %s", f.Type) - } r, err := fields.Resolve(vvv, f.Type) if err != nil { return nil, err } - array = append(array, r) + array.values = append(array.values, r) } - return &resolvedFieldValue{*f, false, array}, nil + return &array, nil } r, err := fields.Resolve(v, f.Type) if err != nil { return nil, err } - return &resolvedFieldValue{*f, false, r}, nil + return r, nil } func NewTypeField(n string, tp string) *TypeField { diff --git a/pkg/types/encoding/eip712_test.go b/pkg/types/encoding/eip712_test.go index 2ddcd7247..bc3e5a581 100644 --- a/pkg/types/encoding/eip712_test.go +++ b/pkg/types/encoding/eip712_test.go @@ -34,14 +34,14 @@ func TestEIP712Arrays(t *testing.T) { For("adi.acme", "book", "1"). UpdateKeyPage(). Add().Entry().Hash([32]byte{1, 2, 3}).FinishEntry().FinishOperation(). - Add().Entry().Owner("foo.bar").FinishEntry().FinishOperation(). - Done()), - must(build.Transaction(). - For("adi.acme", "book", "1"). - UpdateKeyPage(). - Add().Entry().Hash([32]byte{1, 2, 3}).FinishEntry().FinishOperation(). - SetThreshold(2). + // Add().Entry().Owner("foo.bar").FinishEntry().FinishOperation(). Done()), + // must(build.Transaction(). + // For("adi.acme", "book", "1"). + // UpdateKeyPage(). + // Add().Entry().Hash([32]byte{1, 2, 3}).FinishEntry().FinishOperation(). + // SetThreshold(2). + // Done()), } for i, txn := range cases { @@ -55,11 +55,7 @@ func TestEIP712Arrays(t *testing.T) { } txn.Header.Initiator = [32]byte(sig.Metadata().Hash()) - b, err := json.Marshal(txn) - require.NoError(t, err) - fmt.Printf("%s\n", b) - - b, err = protocol.MarshalEip712(txn, sig) + b, err := protocol.MarshalEip712(txn, sig) require.NoError(t, err) buf := new(bytes.Buffer) require.NoError(t, json.Indent(buf, b, "", " ")) diff --git a/protocol/signature_eip712.go b/protocol/signature_eip712.go index 8d94e71ec..12513e5de 100644 --- a/protocol/signature_eip712.go +++ b/protocol/signature_eip712.go @@ -63,7 +63,6 @@ func MarshalEip712(txn *Transaction, sig Signature) (ret []byte, err error) { Message json.RawMessage `json:"message"` } e := eip712{} - e.PrimaryType = "Transaction" e.Domain = encoding.Eip712Domain e.Message, err = r.MarshalJSON() @@ -72,7 +71,7 @@ func MarshalEip712(txn *Transaction, sig Signature) (ret []byte, err error) { } e.Types = map[string][]*encoding.TypeField{} - r.Types(e.Types) + e.PrimaryType = r.Types(e.Types) encoding.EIP712DomainValue.Types(e.Types) return json.Marshal(e)