diff --git a/go.mod b/go.mod index 195c00ee8..18cd9d59e 100644 --- a/go.mod +++ b/go.mod @@ -380,7 +380,7 @@ require ( gitlab.com/bosi/decorder v0.4.1 // indirect go.etcd.io/bbolt v1.3.6 go.uber.org/multierr v1.11.0 // indirect - go.uber.org/zap v1.27.0 + go.uber.org/zap v1.27.0 // indirect golang.org/x/crypto v0.23.0 golang.org/x/mod v0.17.0 // indirect golang.org/x/net v0.25.0 // indirect diff --git a/pkg/types/encoding/eip712.go b/pkg/types/encoding/eip712.go index 5051ebb01..2d696f6c8 100644 --- a/pkg/types/encoding/eip712.go +++ b/pkg/types/encoding/eip712.go @@ -17,14 +17,12 @@ import ( "sort" "strings" - "go.uber.org/zap/buffer" "golang.org/x/crypto/sha3" ) type TypeField struct { - Name string `json:"name"` - Type string `json:"type"` - encoder Eip712Encoder + Name string `json:"name"` + Type string `json:"type"` } type TypeDefinition struct { @@ -37,7 +35,7 @@ type EIP712Domain struct { ChainId *big.Int `json:"chainId,omitempty" form:"chainId" query:"chainId" validate:"required"` } -var EIP712DomainMap map[string]interface{} +var EIP712DomainValue eipResolvedValue var EIP712DomainHash []byte var Eip712Domain = EIP712Domain{ Name: "Accumulate", @@ -45,38 +43,29 @@ var Eip712Domain = EIP712Domain{ ChainId: big.NewInt(281), } -type Eip712Encoder struct { - hasher func(v interface{}) ([]byte, error) - resolver func(any, string) (eipResolvedValue, error) - types func(ret map[string][]*TypeField, v interface{}, fieldType string) error +type EIP712Resolver interface { + Resolve(any, string) (eipResolvedValue, error) } -var eip712EncoderMap map[string]Eip712Encoder +var eip712EncoderMap map[string]EIP712Resolver func init() { - eip712EncoderMap = make(map[string]Eip712Encoder) - eip712EncoderMap["bool"] = NewEncoder(FromboolToBytes, nil) - eip712EncoderMap["bytes"] = NewEncoder(FrombytesToBytes, nil) - eip712EncoderMap["bytes32"] = NewEncoder(Frombytes32ToBytes, nil) - eip712EncoderMap["int64"] = NewEncoder(Fromint64ToBytes, nil) - eip712EncoderMap["uint64"] = NewEncoder(Fromuint64ToBytes, nil) - eip712EncoderMap["string"] = NewEncoder(FromstringToBytes, nil) - eip712EncoderMap["address"] = NewEncoder(FromaddressToBytes, nil) - eip712EncoderMap["uint256"] = NewEncoder(Fromuint256ToBytes, nil) - eip712EncoderMap["float64"] = NewEncoder(FromfloatToBytes, nil) - eip712EncoderMap["float"] = NewEncoder(FromfloatToBytes, nil) //Note = Float is not a valid type in EIP-712, so it is converted to a string + eip712EncoderMap = make(map[string]EIP712Resolver) + eip712EncoderMap["bool"] = newAtomicEncoder(FromboolToBytes) + eip712EncoderMap["bytes"] = newAtomicEncoder(FrombytesToBytes) + eip712EncoderMap["bytes32"] = newAtomicEncoder(Frombytes32ToBytes) + eip712EncoderMap["int64"] = newAtomicEncoder(Fromint64ToBytes) + eip712EncoderMap["uint64"] = newAtomicEncoder(Fromuint64ToBytes) + eip712EncoderMap["string"] = newAtomicEncoder(FromstringToBytes) + eip712EncoderMap["address"] = newAtomicEncoder(FromaddressToBytes) + eip712EncoderMap["uint256"] = newAtomicEncoder(Fromuint256ToBytes) + eip712EncoderMap["float64"] = newAtomicEncoder(FromfloatToBytes) + eip712EncoderMap["float"] = newAtomicEncoder(FromfloatToBytes) //Note = Float is not a valid type in EIP-712, so it is converted to a string // Handle EIP712 domain initialization - j, err := Eip712Domain.MarshalJSON() - if err != nil { - //should never get here - panic(err) - } - err = json.Unmarshal(j, &EIP712DomainMap) - if err != nil { - //should never get here - panic(err) - } + var jdomain map[string]interface{} + j := must2(Eip712Domain.MarshalJSON()) + must(json.Unmarshal(j, &jdomain)) RegisterTypeDefinition(&[]*TypeField{ NewTypeField("name", "string"), @@ -85,13 +74,21 @@ func init() { }, "EIP712Domain") td := SchemaDictionary["EIP712Domain"] - EIP712DomainHash, err = td.hash(EIP712DomainMap, "EIP712Domain") + EIP712DomainValue = must2(td.Resolve(jdomain, "EIP712Domain")) + EIP712DomainHash = must2(EIP712DomainValue.Hash()) +} + +func must(err error) { if err != nil { - //shouldn't fail, but if it does, catch it panic(err) } } +func must2[V any](v V, err error) V { + must(err) + return v +} + type Func[T any, R any] func(T) (R, error) type Enum interface { SetEnumValue(id uint64) bool @@ -142,14 +139,12 @@ func mapEnumTypes[T any, R any](f Func[T, R]) (string, map[string]string) { func RegisterEnumeratedTypeInterface[T any, R any](op Func[T, R]) { tp, typesMap := mapEnumTypes(op) - eip712EncoderMap[tp] = NewEncoder(func(v interface{}) ([]byte, error) { - return FromTypedInterfaceToBytes(v, typesMap) - }, func(ret map[string][]*TypeField, v interface{}, typeField string) error { - return FromTypedInterfaceToTypes(ret, v, typesMap) - }) + eip712EncoderMap[tp] = eip712EnumResolver(typesMap) } -func FromTypedInterfaceToBytes(v interface{}, typesAliasMap map[string]string) ([]byte, error) { +type eip712EnumResolver map[string]string + +func (r eip712EnumResolver) Resolve(v any, typeName string) (eipResolvedValue, error) { //this is a complex structure, so upcast it to an interface map vv, ok := v.(map[string]interface{}) if !ok { @@ -161,7 +156,7 @@ func FromTypedInterfaceToBytes(v interface{}, typesAliasMap map[string]string) ( return nil, fmt.Errorf("invalid data entry type: %T", vv["type"]) } - a, ok := typesAliasMap[t] + a, ok := r[t] if !ok { return nil, fmt.Errorf("type alias does not exist in map: %T, %v", v, t) } @@ -171,40 +166,7 @@ func FromTypedInterfaceToBytes(v interface{}, typesAliasMap map[string]string) ( return nil, fmt.Errorf("invalid data entry type: %T", vv["type"]) } - b, err := d.hash(vv, t) - if err != nil { - return nil, err - } - return keccak256(b), nil -} - -func FromTypedInterfaceToTypes(ret map[string][]*TypeField, v interface{}, typesAliasMap map[string]string) error { - //this is a complex structure, so upcast it to an interface map - vv, ok := v.(map[string]interface{}) - if !ok { - return fmt.Errorf("invalid data entry type: %T", v) - } - - t, ok := vv["type"].(string) - if !ok { - return fmt.Errorf("invalid data entry type: %T", vv["type"]) - } - - a, ok := typesAliasMap[t] - if !ok { - return fmt.Errorf("type alias does not exist in map: %T, %v", v, t) - } - - d, ok := SchemaDictionary[a] - if !ok { - return fmt.Errorf("invalid data entry type: %T", vv["type"]) - } - - err := d.types(ret, v, t) - if err != nil { - return err - } - return nil + return d.Resolve(v, typeName) } type TypedData struct { @@ -213,40 +175,7 @@ type TypedData struct { Types []TypeField } -func (td *TypeDefinition) types(ret map[string][]*TypeField, d interface{}, typeName string) error { - var err error - //define the type structure - if ret == nil { - ret = make(map[string][]*TypeField) - } - - data := d.(map[string]interface{}) - - //now loop through the fields and either encode the value or recursively dive into more types - - //the stripping shouldn't be necessary, but do it as a precaution - strippedType, _ := stripSlice(typeName) - - for i, field := range *td.Fields { - value, ok := data[field.Name] - if !ok { - continue - } - - //append the fields - ret[strippedType] = append(ret[strippedType], (*td.Fields)[i]) - - //breakdown field further if required - err = field.types(ret, value, field.Type) - if err != nil { - return err - } - } - - return nil -} - -func (td *TypeDefinition) Resolve(v any, typeName string) (*eipResolvedStruct, error) { +func (td *TypeDefinition) Resolve(v any, typeName string) (eipResolvedValue, error) { data, ok := v.(map[string]any) if !ok { return nil, fmt.Errorf("cannot hash type definition with invalid interface %T", v) @@ -259,7 +188,7 @@ func (td *TypeDefinition) Resolve(v any, typeName string) (*eipResolvedStruct, e continue } - r, err := field.encoder.resolver(value, field.Type) + r, err := field.resolve(value) if err != nil { return nil, err } @@ -287,8 +216,8 @@ func (td *TypeDefinition) hash(v interface{}, typeName string) ([]byte, error) { type eipResolvedValue interface { Hash() ([]byte, error) + Types(map[string][]*TypeField) header(map[string]string) - types(map[string][]*TypeField) } type eipResolvedStruct struct { @@ -363,17 +292,11 @@ func (e *eipResolvedStruct) header(ret map[string]string) { ret[strippedType] = header.String() } -func (e *eipResolvedStruct) Types() map[string][]*TypeField { - ret := map[string][]*TypeField{} - e.types(ret) - return ret -} - -func (e *eipResolvedStruct) types(ret map[string][]*TypeField) { +func (e *eipResolvedStruct) Types(ret map[string][]*TypeField) { var fields []*TypeField for _, f := range e.Fields { fields = append(fields, &TypeField{Name: f.Name, Type: f.Type}) - f.Value.types(ret) + f.Value.Types(ret) } name, _ := stripSlice(e.Type) ret[name] = fields @@ -397,9 +320,9 @@ func (e eipResolvedArray) header(ret map[string]string) { } } -func (e eipResolvedArray) types(ret map[string][]*TypeField) { +func (e eipResolvedArray) Types(ret map[string][]*TypeField) { for _, v := range e { - v.types(ret) + v.Types(ret) } } @@ -408,230 +331,106 @@ func (e *eipResolvedAtomic) Hash() ([]byte, error) { } func (e *eipResolvedAtomic) header(map[string]string) {} -func (e *eipResolvedAtomic) types(map[string][]*TypeField) {} +func (e *eipResolvedAtomic) Types(map[string][]*TypeField) {} + +type eip712AtomicResolver[V any] func(V) ([]byte, error) + +func (r eip712AtomicResolver[T]) Resolve(v any, _ string) (eipResolvedValue, error) { + // JSON always decodes numbers as floats + if u, ok := v.(float64); ok { + var z T + switch any(z).(type) { + case int64: + v = int64(u) + case uint64: + v = uint64(u) + } + } -func (t *TypeField) types(ret map[string][]*TypeField, v interface{}, fieldType string) error { - if t.encoder.types != nil { - //process more complex type - return t.encoder.types(ret, v, fieldType) + t, ok := v.(T) + if !ok { + return nil, fmt.Errorf("eip712 value of type %T does not match type field", v) } - return nil -} -func NewEncoder[T any](hasher func(T) ([]byte, error), types func(ret map[string][]*TypeField, v interface{}, typeField string) error) Eip712Encoder { - hasher2 := func(v interface{}) ([]byte, error) { - // JSON always decodes numbers as floats - if u, ok := v.(float64); ok { - var z T - switch any(z).(type) { - case int64: - v = int64(u) - case uint64: - v = uint64(u) - } - } + return &eipResolvedAtomic{ + Value: v, + hasher: func(any) ([]byte, error) { + return r(t) + }, + }, nil +} - t, ok := v.(T) - if !ok { - return nil, fmt.Errorf("eip712 value of type %T does not match type field", v) - } - return hasher(t) - } - resolver := func(v any, _ string) (eipResolvedValue, error) { - return &eipResolvedAtomic{v, hasher2}, nil - } - return Eip712Encoder{hasher2, resolver, types} +func newAtomicEncoder[T any](hasher func(T) ([]byte, error)) EIP712Resolver { + return eip712AtomicResolver[T](hasher) } -func NewTypeField(n string, tp string) *TypeField { - return &TypeField{n, tp, - Eip712Encoder{func(v interface{}) ([]byte, error) { - strippedType, slices := stripSlice(tp) - encoder, ok := eip712EncoderMap[strippedType] - if ok { - if slices > 0 { - vv, ok := v.([]interface{}) - if !ok { - return nil, fmt.Errorf("eip712 field %s is not of an array of interfaces", n) - } - var buff buffer.Buffer - for _, vvv := range vv { - b, err := encoder.hasher(vvv) - if err != nil { - return nil, err - } - _, err = buff.Write(b) - if err != nil { - return nil, err - } - } - return keccak256(buff.Bytes()), nil - } - return encoder.hasher(v) +func (f *TypeField) resolve(v any) (eipResolvedValue, error) { + strippedType, slices := stripSlice(f.Type) + encoder, ok := eip712EncoderMap[strippedType] + if ok { + if slices { + vv, ok := v.([]interface{}) + if !ok { + return nil, fmt.Errorf("eip712 field %s is not of an array of interfaces", f.Name) } - - //from here on down we are expecting a struct - if slices > 0 { - //we expect a slice - vv, ok := v.([]interface{}) - if !ok { - return nil, fmt.Errorf("eip712 field %s is not of an array of interfaces", n) + var array eipResolvedArray + for _, vvv := range vv { + r, err := encoder.Resolve(vvv, f.Type) + if err != nil { + return nil, err } - var buff buffer.Buffer - //iterate through the interfaces - for _, vvv := range vv { - //now run the hasher for the type - // look for encoder, if we don't have one, call the types encoder - fields, ok := SchemaDictionary[strippedType] - if !ok { - return nil, fmt.Errorf("eip712 field %s", tp) - } - b, err := fields.hash(vvv, tp) - if err != nil { - return nil, err - } - - _, err = buff.Write(b) - if err != nil { - return nil, err - } - } - return keccak256(buff.Bytes()), nil + array = append(array, r) } + return array, nil + } + return encoder.Resolve(v, f.Type) + } - //if we get here, we are expecting a struct + //from here on down we are expecting a struct + if slices { + //we expect a slice + vv, ok := v.([]interface{}) + if !ok { + return nil, fmt.Errorf("eip712 field %s is not of an array of interfaces", f.Name) + } + var array eipResolvedArray + for _, vvv := range vv { + //now run the hasher for the type + // look for encoder, if we don't have one, call the types encoder fields, ok := SchemaDictionary[strippedType] if !ok { - return nil, fmt.Errorf("eip712 field %s", tp) + return nil, fmt.Errorf("eip712 field %s", f.Type) } - - b, err := fields.hash(v, tp) + r, err := fields.Resolve(vvv, f.Type) if err != nil { return nil, err } - return b, nil - }, func(v any, typeName string) (eipResolvedValue, error) { - strippedType, slices := stripSlice(tp) - encoder, ok := eip712EncoderMap[strippedType] - if ok { - if slices > 0 { - vv, ok := v.([]interface{}) - if !ok { - return nil, fmt.Errorf("eip712 field %s is not of an array of interfaces", n) - } - var array eipResolvedArray - for _, vvv := range vv { - r, err := encoder.resolver(vvv, tp) - if err != nil { - return nil, err - } - array = append(array, r) - } - return array, nil - } - return encoder.resolver(v, tp) - } - - //from here on down we are expecting a struct - if slices > 0 { - //we expect a slice - vv, ok := v.([]interface{}) - if !ok { - return nil, fmt.Errorf("eip712 field %s is not of an array of interfaces", n) - } - var array eipResolvedArray - for _, vvv := range vv { - //now run the hasher for the type - // look for encoder, if we don't have one, call the types encoder - fields, ok := SchemaDictionary[strippedType] - if !ok { - return nil, fmt.Errorf("eip712 field %s", tp) - } - r, err := fields.Resolve(vvv, tp) - if err != nil { - return nil, err - } - array = append(array, r) - } - return array, nil - } - - //if we get here, we are expecting a struct - fields, ok := SchemaDictionary[strippedType] - if !ok { - return nil, fmt.Errorf("eip712 field %s", tp) - } + array = append(array, r) + } + return array, nil + } - return fields.Resolve(v, tp) - }, func(ret map[string][]*TypeField, v interface{}, fieldType string) error { - strippedType, slices := stripSlice(tp) - encoder, ok := eip712EncoderMap[strippedType] - if ok { - if encoder.types == nil { - return nil - } - if slices > 0 { - vv, ok := v.([]interface{}) - if !ok { - return fmt.Errorf("eip712 field %s is not of an array of interfaces", n) - } - for _, vvv := range vv { - err := encoder.types(ret, vvv, fieldType) - if err != nil { - return err - } - } - return nil - } - return encoder.types(ret, v, fieldType) - } + //if we get here, we are expecting a struct + fields, ok := SchemaDictionary[strippedType] + if !ok { + return nil, fmt.Errorf("eip712 field %s", f.Type) + } - fields, ok := SchemaDictionary[strippedType] - if !ok { - return fmt.Errorf("eip712 field %s", tp) - } - if slices > 0 { - vv, ok := v.([]interface{}) - if !ok { - return fmt.Errorf("eip712 field %s is not of an array of interfaces", n) - } - for _, vvv := range vv { - err := fields.types(ret, vvv, fieldType) - if err != nil { - return err - } - } - return nil - } + return fields.Resolve(v, f.Type) +} - return fields.types(ret, v, fieldType) - }}, - } +func NewTypeField(n string, tp string) *TypeField { + return &TypeField{n, tp} } // stripSlice removes all array indicators from the input string // and returns the cleaned string along with the count of stripped indicators. -func stripSlice(input string) (string, int) { - count := 0 - indicator := "[]" - - for strings.Contains(input, indicator) { - input = strings.Replace(input, indicator, "", 1) - count++ - } - return input, count +func stripSlice(input string) (string, bool) { + s := strings.TrimSuffix(input, "[]") + return s, len(s) < len(input) } -var SchemaDictionary map[string]*TypeDefinition - -var resolvers map[string]func() error - -func RegisterTypeDefinitionResolver(name string, deferFunc func() error) { - if resolvers == nil { - resolvers = make(map[string]func() error) - } - resolvers[name] = deferFunc -} +var SchemaDictionary map[string]EIP712Resolver func (td *TypeDefinition) sort() { //all types need to be sorted, so just make sure they are... @@ -645,7 +444,7 @@ func RegisterTypeDefinition(tf *[]*TypeField, aliases ...string) { td.sort() if SchemaDictionary == nil { - SchemaDictionary = make(map[string]*TypeDefinition) + SchemaDictionary = make(map[string]EIP712Resolver) } for _, alias := range aliases { @@ -671,10 +470,6 @@ func Eip712Hash(v map[string]interface{}, typeName string, td *TypeDefinition) ( return keccak256(buf.Bytes()), nil } -func Eip712DomainType() *TypeDefinition { - return SchemaDictionary["EIP712Domain"] -} - func FromstringToBytes(s string) ([]byte, error) { return keccak256([]byte(s)), nil } diff --git a/protocol/signature_eip712.go b/protocol/signature_eip712.go index 239600b95..b78dd7842 100644 --- a/protocol/signature_eip712.go +++ b/protocol/signature_eip712.go @@ -67,14 +67,18 @@ func MarshalEip712(txn *Transaction, sig Signature) (ret []byte, err error) { e.PrimaryType = "Transaction" e.Domain = encoding.Eip712Domain e.Message = jtx - e.Types = r.Types() - e.Types["EIP712Domain"] = *encoding.Eip712DomainType().Fields + e.Types = map[string][]*encoding.TypeField{} + r.Types(e.Types) + encoding.EIP712DomainValue.Types(e.Types) + + // Reformat the message JSON to be compatible with Ethereum + formatEIP712Message(jtx, e.Types, e.Types[e.PrimaryType]) return json.Marshal(e) } -func formatEIP712Message(v map[string]any, td *encoding.TypeDefinition) { - for _, field := range *td.Fields { +func formatEIP712Message(v map[string]any, types map[string][]*encoding.TypeField, fields []*encoding.TypeField) { + for _, field := range fields { fv, ok := v[field.Name] if !ok { continue @@ -86,9 +90,9 @@ func formatEIP712Message(v map[string]any, td *encoding.TypeDefinition) { continue } - sch, ok := encoding.SchemaDictionary[strings.TrimPrefix(field.Type, "[]")] + fields, ok := types[strings.TrimPrefix(field.Type, "[]")] if ok { - formatEIP712Message(fv.(map[string]any), sch) + formatEIP712Message(fv.(map[string]any), types, fields) } } } @@ -133,8 +137,6 @@ func makeEIP712Message(txn *Transaction, sig Signature) (map[string]any, error) } jtx["signature"] = jsig - // Reformat the message JSON to be compatible with Ethereum - formatEIP712Message(jtx, NewEip712TransactionDefinition(txn)) return jtx, nil }