From 50b6d5e6031e1b6fa4a25c41a1d40cd0d78a205d Mon Sep 17 00:00:00 2001 From: Ethan Reesor Date: Thu, 11 Jul 2024 22:34:23 -0500 Subject: [PATCH] Cleanup --- pkg/types/encoding/eip712.go | 68 ++++++++++++++++-------------------- 1 file changed, 30 insertions(+), 38 deletions(-) diff --git a/pkg/types/encoding/eip712.go b/pkg/types/encoding/eip712.go index 11424ae24..f70eb237d 100644 --- a/pkg/types/encoding/eip712.go +++ b/pkg/types/encoding/eip712.go @@ -67,7 +67,7 @@ var Eip712Domain = EIP712Domain{ } type EIP712Resolver interface { - Resolve(any, string) (eipResolvedValue, error) + Resolve(any) (eipResolvedValue, error) } type eipResolvedValue interface { @@ -104,7 +104,7 @@ func init() { }, eipDomainKey) td := schemaDictionary[eipDomainKey] - EIP712DomainValue = must2(td.Resolve(jdomain, eipDomainKey)) + EIP712DomainValue = must2(td.Resolve(jdomain)) EIP712DomainHash = must2(EIP712DomainValue.Hash(map[string][]*TypeField{ eipDomainKey: *td.Fields, })) @@ -135,13 +135,6 @@ func RegisterUnion[T any, R any](op Func[T, R]) { if !ok { panic(fmt.Errorf("%T is not an enumeration type", *enumType)) } - var a *R - tp := reflect.TypeOf(a).Elem().String() - //strip package name if present - idx := strings.LastIndex(tp, ".") - if idx != -1 { - tp = tp[idx+1:] - } //build a map of types for maxType := uint64(0); ; maxType++ { if !enumValue.SetEnumValue(maxType) { @@ -157,52 +150,51 @@ func RegisterUnion[T any, R any](op Func[T, R]) { t = t.Elem() // Safely obtaining the element type } - name := t.Name() - idx := strings.LastIndex(name, ".") - if idx > 0 { - name = name[idx+1:] - } - key := enumValue.String() - typesMap[key] = name + typesMap[key] = t.Name() } - eip712EncoderMap[tp] = eip712EnumResolver(typesMap) + + tp := reflect.TypeFor[R]().Name() + eip712EncoderMap[tp] = &eip712UnionResolver{tp, typesMap} } -type eip712EnumResolver map[string]string +type eip712UnionResolver struct { + typeName string + members map[string]string +} -func (r eip712EnumResolver) Resolve(v any, typeName string) (eipResolvedValue, error) { +func (r *eip712UnionResolver) Resolve(v any) (eipResolvedValue, error) { //this is a complex structure, so upcast it to an interface map vv, ok := v.(map[string]interface{}) if !ok { return nil, fmt.Errorf("invalid data entry type: %T", v) } - t, ok := vv["type"].(string) + enum, ok := vv["type"].(string) if !ok { return nil, fmt.Errorf("invalid data entry type: %T", vv["type"]) } - a, ok := r[t] + member, ok := r.members[enum] if !ok { - return nil, fmt.Errorf("type alias does not exist in map: %T, %v", v, t) + return nil, fmt.Errorf("type alias does not exist in map: %T, %v", v, enum) } - d, ok := schemaDictionary[a] + typ, ok := schemaDictionary[member] if !ok { return nil, fmt.Errorf("invalid data entry type: %T", vv["type"]) } delete(vv, "type") - rv, err := d.Resolve(vv, typeName) + rv, err := typ.Resolve(vv) if err != nil { return nil, err } return &eipResolvedUnionValue{ - union: typeName, - enum: t, - member: a, + union: r.typeName, + enum: enum, + member: member, value: rv, }, nil } @@ -245,7 +237,7 @@ func (e *eipResolvedUnionValue) MarshalJSON() ([]byte, error) { return json.Marshal(map[string]json.RawMessage{e.enum: b}) } -func (td *TypeDefinition) Resolve(v any, typeName string) (eipResolvedValue, error) { +func (td *TypeDefinition) Resolve(v any) (eipResolvedValue, error) { data, ok := v.(map[string]any) if !ok { return nil, fmt.Errorf("cannot hash type definition with invalid interface %T", v) @@ -451,7 +443,7 @@ type eip712AtomicResolver[V any] struct { hash func(V) ([]byte, error) } -func (r *eip712AtomicResolver[T]) Resolve(v any, _ string) (eipResolvedValue, error) { +func (r *eip712AtomicResolver[T]) Resolve(v any) (eipResolvedValue, error) { // If v is nil, use T's zero value instead if v == nil { var z T @@ -482,8 +474,8 @@ func newAtomicEncoder[T any](ethType string, hasher func(T) ([]byte, error)) EIP } func (f *TypeField) resolve(v any) (*resolvedFieldValue, error) { - strippedType, slices := stripSlice(f.Type) - encoder, ok := eip712EncoderMap[strippedType] + typeName, slices := stripSlice(f.Type) + encoder, ok := eip712EncoderMap[typeName] if ok { if slices { // If v is nil, return an empty array @@ -496,7 +488,7 @@ func (f *TypeField) resolve(v any) (*resolvedFieldValue, error) { } var array eipResolvedArray for _, vvv := range vv { - r, err := encoder.Resolve(vvv, strippedType) + r, err := encoder.Resolve(vvv) if err != nil { return nil, err } @@ -504,7 +496,7 @@ func (f *TypeField) resolve(v any) (*resolvedFieldValue, error) { } return &resolvedFieldValue{*f, false, array}, nil } - r, err := encoder.Resolve(v, f.Type) + r, err := encoder.Resolve(v) if err != nil { return nil, err } @@ -512,7 +504,7 @@ func (f *TypeField) resolve(v any) (*resolvedFieldValue, error) { } //if we get here, we are expecting a struct - fields, ok := schemaDictionary[strippedType] + fields, ok := schemaDictionary[typeName] if !ok { return nil, fmt.Errorf("eip712 field %s", f.Type) } @@ -533,11 +525,11 @@ func (f *TypeField) resolve(v any) (*resolvedFieldValue, error) { for _, vvv := range vv { //now run the hasher for the type // look for encoder, if we don't have one, call the types encoder - fields, ok := schemaDictionary[strippedType] + fields, ok := schemaDictionary[typeName] if !ok { return nil, fmt.Errorf("eip712 field %s", f.Type) } - r, err := fields.Resolve(vvv, strippedType) + r, err := fields.Resolve(vvv) if err != nil { return nil, err } @@ -546,7 +538,7 @@ func (f *TypeField) resolve(v any) (*resolvedFieldValue, error) { return &resolvedFieldValue{*f, false, array}, nil } - r, err := fields.Resolve(v, f.Type) + r, err := fields.Resolve(v) if err != nil { return nil, err } @@ -590,7 +582,7 @@ type EIP712Call struct { } func NewEIP712Call(value any, typ EIP712Resolver) (*EIP712Call, error) { - r, err := typ.Resolve(value, "") + r, err := typ.Resolve(value) if err != nil { return nil, err }