Skip to content

Commit

Permalink
Progress
Browse files Browse the repository at this point in the history
  • Loading branch information
firelizzard18 committed Jul 12, 2024
1 parent 0ab36fc commit c92861d
Show file tree
Hide file tree
Showing 4 changed files with 222 additions and 127 deletions.
243 changes: 178 additions & 65 deletions pkg/types/encoding/eip712.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,16 +51,16 @@ var eip712EncoderMap map[string]EIP712Resolver

func init() {
eip712EncoderMap = make(map[string]EIP712Resolver)
eip712EncoderMap["bool"] = newAtomicEncoder(FromboolToBytes)
eip712EncoderMap["bytes"] = newAtomicEncoder(FrombytesToBytes)
eip712EncoderMap["bytes32"] = newAtomicEncoder(Frombytes32ToBytes)
eip712EncoderMap["int64"] = newAtomicEncoder(Fromint64ToBytes)
eip712EncoderMap["uint64"] = newAtomicEncoder(Fromuint64ToBytes)
eip712EncoderMap["string"] = newAtomicEncoder(FromstringToBytes)
eip712EncoderMap["address"] = newAtomicEncoder(FromaddressToBytes)
eip712EncoderMap["uint256"] = newAtomicEncoder(Fromuint256ToBytes)
eip712EncoderMap["float64"] = newAtomicEncoder(FromfloatToBytes)
eip712EncoderMap["float"] = newAtomicEncoder(FromfloatToBytes) //Note = Float is not a valid type in EIP-712, so it is converted to a string
eip712EncoderMap["bool"] = newAtomicEncoder("bool", FromboolToBytes)
eip712EncoderMap["bytes"] = newAtomicEncoder("bytes", FrombytesToBytes)
eip712EncoderMap["bytes32"] = newAtomicEncoder("bytes32", Frombytes32ToBytes)
eip712EncoderMap["int64"] = newAtomicEncoder("int64", Fromint64ToBytes)
eip712EncoderMap["uint64"] = newAtomicEncoder("uint64", Fromuint64ToBytes)
eip712EncoderMap["string"] = newAtomicEncoder("string", FromstringToBytes)
eip712EncoderMap["address"] = newAtomicEncoder("address", FromaddressToBytes)
eip712EncoderMap["uint256"] = newAtomicEncoder("uint256", Fromuint256ToBytes)
eip712EncoderMap["float64"] = newAtomicEncoder("float64", FromfloatToBytes)
eip712EncoderMap["float"] = newAtomicEncoder("float", FromfloatToBytes) //Note = Float is not a valid type in EIP-712, so it is converted to a string

// Handle EIP712 domain initialization
var jdomain map[string]interface{}
Expand Down Expand Up @@ -175,55 +175,97 @@ func (td *TypeDefinition) Resolve(v any, typeName string) (eipResolvedValue, err
return nil, fmt.Errorf("cannot hash type definition with invalid interface %T", v)
}

var fields []*eipResolvedField
var fields []*resolvedFieldValue
for _, field := range *td.Fields {
value, ok := data[field.Name]
if !ok {
continue
}

r, err := field.resolve(value)
value, _ := data[field.Name]
v, err := field.resolve(value)
if err != nil {
return nil, err
}

fields = append(fields, &eipResolvedField{
Name: field.Name,
Type: field.Type,
Value: r,
})
fields = append(fields, v)
}

return &eipResolvedStruct{
Type: typeName,
Fields: fields,
typeName: typeName,
fields: fields,
}, nil
}

type eipResolvedValue interface {
Hash() ([]byte, error)
Types(map[string][]*TypeField)
MarshalJSON() ([]byte, error)
header(map[string]string)
}

type eipResolvedStruct struct {
Type string
Fields []*eipResolvedField
typeName string
fields []*resolvedFieldValue
}

type eipResolvedField struct {
Name string
Type string
Value eipResolvedValue
type resolvedFieldValue struct {
TypeField
skip bool
value eipResolvedValue
}

type eipResolvedArray []eipResolvedValue

type eipResolvedAtomic func() ([]byte, error)
func (e eipResolvedArray) MarshalJSON() ([]byte, error) {
return json.Marshal([]eipResolvedValue(e))
}

type eipEmptyStruct struct {
typeName string
td *TypeDefinition
}

func (e *eipEmptyStruct) MarshalJSON() ([]byte, error) { panic("invalid call") }

func (e *eipEmptyStruct) header(ret map[string]string) {
headerFor(ret, e.typeName, func(fn func(*TypeField)) {
for _, field := range *e.td.Fields {
fn(field)
}
})
}

func (e *eipEmptyStruct) Hash() ([]byte, error) {
return make([]byte, 32), nil
}

func (e *eipEmptyStruct) Types(ret map[string][]*TypeField) {
name, _ := stripSlice(e.typeName)
ret[name] = *e.td.Fields
}

type eipResolvedAtomic struct {
ethType string
value any
hash func() ([]byte, error)
}

func (e *eipResolvedStruct) MarshalJSON() ([]byte, error) {
v := map[string]json.RawMessage{}
for _, f := range e.fields {
if f.skip {
continue
}

var err error
v[f.Name], err = f.value.MarshalJSON()
if err != nil {
return nil, err
}
}
return json.Marshal(v)
}

const debugHash = false

func (e *eipResolvedStruct) Hash() ([]byte, error) {
//the stripping shouldn't be necessary, but do it as a precaution
strippedType, _ := stripSlice(e.Type)
strippedType, _ := stripSlice(e.typeName)

deps := map[string]string{}
e.header(deps)
Expand All @@ -242,46 +284,82 @@ func (e *eipResolvedStruct) Hash() ([]byte, error) {
}

var buf bytes.Buffer
buf.Write(keccak256(header.Bytes()))
hash := keccak256(header.Bytes())
buf.Write(hash)

var parts [][]byte
if debugHash {
parts = append(parts, hash)
}

//now loop through the fields and either encode the value or recursively dive into more types
for _, field := range e.Fields {
for _, field := range e.fields {
//now run the hasher
encodedValue, err := field.Value.Hash()
encodedValue, err := field.value.Hash()
if err != nil {
return nil, err
}
buf.Write(encodedValue)

if debugHash {
parts = append(parts, encodedValue)
}
}

return keccak256(buf.Bytes()), nil
hash = keccak256(buf.Bytes())
if debugHash {
fmt.Println(header.String())
for i, v := range parts {
if i == 0 {
fmt.Printf(" %x\n", v)
} else {
fmt.Printf(" + %x\n", v)
}
}
fmt.Printf(" = %x\n", hash)
}

return hash, nil
}

func (e *eipResolvedStruct) header(ret map[string]string) {
headerFor(ret, e.typeName, func(fn func(*TypeField)) {
for _, field := range e.fields {
fn(&field.TypeField)
}
})
for _, field := range e.fields {
field.value.header(ret)
}
}

func headerFor(ret map[string]string, typeName string, fields func(func(*TypeField))) {
//the stripping shouldn't be necessary, but do it as a precaution
strippedType, _ := stripSlice(e.Type)
strippedType, _ := stripSlice(typeName)

//define the type structure
var header strings.Builder
header.WriteString(strippedType + "(")
for i, field := range e.Fields {
if i > 0 {
first := true
fields(func(field *TypeField) {
if first {
first = false
} else {
header.WriteString(",")
}
header.WriteString(field.Type + " " + field.Name)
field.Value.header(ret)
}
})
header.WriteString(")")
ret[strippedType] = header.String()
}

func (e *eipResolvedStruct) Types(ret map[string][]*TypeField) {
var fields []*TypeField
for _, f := range e.Fields {
for _, f := range e.fields {
fields = append(fields, &TypeField{Name: f.Name, Type: f.Type})
f.Value.Types(ret)
f.value.Types(ret)
}
name, _ := stripSlice(e.Type)
name, _ := stripSlice(e.typeName)
ret[name] = fields
}

Expand Down Expand Up @@ -309,13 +387,31 @@ func (e eipResolvedArray) Types(ret map[string][]*TypeField) {
}
}

func (e eipResolvedAtomic) Hash() ([]byte, error) { return e() }
func (e eipResolvedAtomic) header(map[string]string) {}
func (e eipResolvedAtomic) Types(map[string][]*TypeField) {}
func (e *eipResolvedAtomic) Hash() ([]byte, error) { return e.hash() }
func (e *eipResolvedAtomic) header(map[string]string) {}
func (e *eipResolvedAtomic) Types(map[string][]*TypeField) {}

type eip712AtomicResolver[V any] func(V) ([]byte, error)
func (e *eipResolvedAtomic) MarshalJSON() ([]byte, error) {
v := e.value
switch e.ethType {
case "bytes", "bytes32", "address":
v = fmt.Sprintf("0x%v", v)
}
return json.Marshal(v)
}

type eip712AtomicResolver[V any] struct {
ethType string
hash func(V) ([]byte, error)
}

func (r *eip712AtomicResolver[T]) Resolve(v any, _ string) (eipResolvedValue, error) {
// If v is nil, use T's zero value instead
if v == nil {
var z T
v = z
}

func (r eip712AtomicResolver[T]) Resolve(v any, _ string) (eipResolvedValue, error) {
// JSON always decodes numbers as floats
if u, ok := v.(float64); ok {
var z T
Expand All @@ -332,18 +428,22 @@ func (r eip712AtomicResolver[T]) Resolve(v any, _ string) (eipResolvedValue, err
return nil, fmt.Errorf("eip712 value of type %T does not match type field", v)
}

return eipResolvedAtomic(func() ([]byte, error) { return r(t) }), nil
return &eipResolvedAtomic{r.ethType, v, func() ([]byte, error) { return r.hash(t) }}, nil
}

func newAtomicEncoder[T any](hasher func(T) ([]byte, error)) EIP712Resolver {
return eip712AtomicResolver[T](hasher)
func newAtomicEncoder[T any](ethType string, hasher func(T) ([]byte, error)) EIP712Resolver {
return &eip712AtomicResolver[T]{ethType, hasher}
}

func (f *TypeField) resolve(v any) (eipResolvedValue, error) {
func (f *TypeField) resolve(v any) (*resolvedFieldValue, error) {
strippedType, slices := stripSlice(f.Type)
encoder, ok := eip712EncoderMap[strippedType]
if ok {
if slices {
// If v is nil, return an empty array
if v == nil {
return &resolvedFieldValue{*f, false, eipResolvedArray{}}, nil
}
vv, ok := v.([]interface{})
if !ok {
return nil, fmt.Errorf("eip712 field %s is not of an array of interfaces", f.Name)
Expand All @@ -356,9 +456,24 @@ func (f *TypeField) resolve(v any) (eipResolvedValue, error) {
}
array = append(array, r)
}
return array, nil
return &resolvedFieldValue{*f, false, array}, nil
}
return encoder.Resolve(v, f.Type)
r, err := encoder.Resolve(v, f.Type)
if err != nil {
return nil, err
}
return &resolvedFieldValue{*f, false, r}, nil
}

//if we get here, we are expecting a struct
fields, ok := SchemaDictionary[strippedType]
if !ok {
return nil, fmt.Errorf("eip712 field %s", f.Type)
}

// If v is nil and the type is a struct, skip this value
if v == nil {
return &resolvedFieldValue{*f, true, &eipEmptyStruct{strippedType, fields}}, nil
}

//from here on down we are expecting a struct
Expand All @@ -382,16 +497,14 @@ func (f *TypeField) resolve(v any) (eipResolvedValue, error) {
}
array = append(array, r)
}
return array, nil
return &resolvedFieldValue{*f, false, array}, nil
}

//if we get here, we are expecting a struct
fields, ok := SchemaDictionary[strippedType]
if !ok {
return nil, fmt.Errorf("eip712 field %s", f.Type)
r, err := fields.Resolve(v, f.Type)
if err != nil {
return nil, err
}

return fields.Resolve(v, f.Type)
return &resolvedFieldValue{*f, false, r}, nil
}

func NewTypeField(n string, tp string) *TypeField {
Expand All @@ -405,7 +518,7 @@ func stripSlice(input string) (string, bool) {
return s, len(s) < len(input)
}

var SchemaDictionary map[string]EIP712Resolver
var SchemaDictionary map[string]*TypeDefinition

func (td *TypeDefinition) sort() {
//all types need to be sorted, so just make sure they are...
Expand All @@ -419,7 +532,7 @@ func RegisterTypeDefinition(tf *[]*TypeField, aliases ...string) {
td.sort()

if SchemaDictionary == nil {
SchemaDictionary = make(map[string]EIP712Resolver)
SchemaDictionary = make(map[string]*TypeDefinition)
}

for _, alias := range aliases {
Expand Down
Loading

0 comments on commit c92861d

Please sign in to comment.