Skip to content

Commit

Permalink
Implement eth_signTypedData_v4
Browse files Browse the repository at this point in the history
  • Loading branch information
firelizzard18 committed Jul 10, 2024
1 parent 5cd7331 commit 56e7894
Show file tree
Hide file tree
Showing 10 changed files with 621 additions and 71 deletions.
244 changes: 213 additions & 31 deletions pkg/types/encoding/eip712.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import (
"encoding/binary"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"math/big"
"reflect"
Expand Down Expand Up @@ -50,8 +49,9 @@ var Eip712Domain = EIP712Domain{
}

type Eip712Encoder struct {
hasher func(v interface{}) ([]byte, error)
types func(ret map[string][]*TypeField, v interface{}, fieldType string) error
hasher func(v interface{}) ([]byte, error)
resolver func(any, string) (eipResolvedValue, error)
types func(ret map[string][]*TypeField, v interface{}, fieldType string) error
}

var eip712EncoderMap map[string]Eip712Encoder
Expand Down Expand Up @@ -250,49 +250,170 @@ func (td *TypeDefinition) types(ret map[string][]*TypeField, d interface{}, type
return nil
}

func (td *TypeDefinition) hash(v interface{}, typeName string) ([]byte, error) {
data, ok := v.(map[string]interface{})
func (td *TypeDefinition) Resolve(v any, typeName string) (*eipResolvedStruct, error) {
data, ok := v.(map[string]any)
if !ok {
return nil, fmt.Errorf("cannot hash type definition with invalid interface %T", v)
}

//define the type structure
var header bytes.Buffer
var body bytes.Buffer

//now loop through the fields and either encode the value or recursively dive into more types
first := true
//the stripping shouldn't be necessary, but do it as a precaution
strippedType, _ := stripSlice(typeName)
header.WriteString(strippedType + "(")
var fields []*eipResolvedField
for _, field := range *td.Fields {
value, ok := data[field.Name]
if !ok {
continue
}
delete(data, field.Name)

r, err := field.encoder.resolver(value, field.Type)
if err != nil {
return nil, err
}

fields = append(fields, &eipResolvedField{
Name: field.Name,
Type: field.Type,
Value: r,
})
}

return &eipResolvedStruct{
Type: typeName,
Fields: fields,
}, nil
}

func (td *TypeDefinition) hash(v interface{}, typeName string) ([]byte, error) {
e, err := td.Resolve(v, typeName)
if err != nil {
return nil, err
}
return e.Hash()
}

type eipResolvedValue interface {
Hash() ([]byte, error)
header(map[string]string)
types(map[string][]*TypeField)
}

type eipResolvedStruct struct {
Type string
Fields []*eipResolvedField
}

type eipResolvedField struct {
Name string
Type string
Value eipResolvedValue
}

type eipResolvedArray []eipResolvedValue

type eipResolvedAtomic struct {
Value any
hasher func(any) ([]byte, error)
}

func (e *eipResolvedStruct) Hash() ([]byte, error) {
//the stripping shouldn't be necessary, but do it as a precaution
strippedType, _ := stripSlice(e.Type)

deps := map[string]string{}
e.header(deps)

var header bytes.Buffer
header.WriteString(deps[strippedType])
delete(deps, strippedType)

var depNames []string
for name := range deps {
depNames = append(depNames, name)
}
sort.Strings(depNames)
for _, name := range depNames {
header.WriteString(deps[name])
}

var buf bytes.Buffer
buf.Write(keccak256(header.Bytes()))

//now loop through the fields and either encode the value or recursively dive into more types
for _, field := range e.Fields {
//now run the hasher
encodedValue, err := field.encoder.hasher(value)
encodedValue, err := field.Value.Hash()
if err != nil {
return nil, err
}
if !first {
buf.Write(encodedValue)
}

return keccak256(buf.Bytes()), nil
}

func (e *eipResolvedStruct) header(ret map[string]string) {
//the stripping shouldn't be necessary, but do it as a precaution
strippedType, _ := stripSlice(e.Type)

//define the type structure
var header strings.Builder
header.WriteString(strippedType + "(")
for i, field := range e.Fields {
if i > 0 {
header.WriteString(",")
}
header.WriteString(field.Type + " " + field.Name)
body.Write(encodedValue)
first = false
field.Value.header(ret)
}
header.WriteString(")")
ret[strippedType] = header.String()
}

func (e *eipResolvedStruct) Types() map[string][]*TypeField {
ret := map[string][]*TypeField{}
e.types(ret)
return ret
}

func (e *eipResolvedStruct) types(ret map[string][]*TypeField) {
var fields []*TypeField
for _, f := range e.Fields {
fields = append(fields, &TypeField{Name: f.Name, Type: f.Type})
f.Value.types(ret)
}
name, _ := stripSlice(e.Type)
ret[name] = fields
}

func (e eipResolvedArray) Hash() ([]byte, error) {
var buf bytes.Buffer
for _, v := range e {
hash, err := v.Hash()
if err != nil {
return nil, err
}
_, _ = buf.Write(hash)
}
return keccak256(buf.Bytes()), nil
}

if len(data) > 0 {
return nil, errors.New("eip712 payload contains unknown fields")
func (e eipResolvedArray) header(ret map[string]string) {
for _, v := range e {
v.header(ret)
}
}

func (e eipResolvedArray) types(ret map[string][]*TypeField) {
for _, v := range e {
v.types(ret)
}
}

return keccak256(append(keccak256(header.Bytes()), body.Bytes()...)), nil
func (e *eipResolvedAtomic) Hash() ([]byte, error) {
return e.hasher(e.Value)
}

func (e *eipResolvedAtomic) header(map[string]string) {}
func (e *eipResolvedAtomic) types(map[string][]*TypeField) {}

func (t *TypeField) types(ret map[string][]*TypeField, v interface{}, fieldType string) error {
if t.encoder.types != nil {
//process more complex type
Expand All @@ -302,7 +423,7 @@ func (t *TypeField) types(ret map[string][]*TypeField, v interface{}, fieldType
}

func NewEncoder[T any](hasher func(T) ([]byte, error), types func(ret map[string][]*TypeField, v interface{}, typeField string) error) Eip712Encoder {
return Eip712Encoder{func(v interface{}) ([]byte, error) {
hasher2 := func(v interface{}) ([]byte, error) {
// JSON always decodes numbers as floats
if u, ok := v.(float64); ok {
var z T
Expand All @@ -319,7 +440,11 @@ func NewEncoder[T any](hasher func(T) ([]byte, error), types func(ret map[string
return nil, fmt.Errorf("eip712 value of type %T does not match type field", v)
}
return hasher(t)
}, types}
}
resolver := func(v any, _ string) (eipResolvedValue, error) {
return &eipResolvedAtomic{v, hasher2}, nil
}
return Eip712Encoder{hasher2, resolver, types}
}

func NewTypeField(n string, tp string) *TypeField {
Expand Down Expand Up @@ -389,6 +514,59 @@ func NewTypeField(n string, tp string) *TypeField {
return nil, err
}
return b, nil
}, func(v any, typeName string) (eipResolvedValue, error) {
strippedType, slices := stripSlice(tp)
encoder, ok := eip712EncoderMap[strippedType]
if ok {
if slices > 0 {
vv, ok := v.([]interface{})
if !ok {
return nil, fmt.Errorf("eip712 field %s is not of an array of interfaces", n)
}
var array eipResolvedArray
for _, vvv := range vv {
r, err := encoder.resolver(vvv, tp)
if err != nil {
return nil, err
}
array = append(array, r)
}
return array, nil
}
return encoder.resolver(v, tp)
}

//from here on down we are expecting a struct
if slices > 0 {
//we expect a slice
vv, ok := v.([]interface{})
if !ok {
return nil, fmt.Errorf("eip712 field %s is not of an array of interfaces", n)
}
var array eipResolvedArray
for _, vvv := range vv {
//now run the hasher for the type
// look for encoder, if we don't have one, call the types encoder
fields, ok := SchemaDictionary[strippedType]
if !ok {
return nil, fmt.Errorf("eip712 field %s", tp)
}
r, err := fields.Resolve(vvv, tp)
if err != nil {
return nil, err
}
array = append(array, r)
}
return array, nil
}

//if we get here, we are expecting a struct
fields, ok := SchemaDictionary[strippedType]
if !ok {
return nil, fmt.Errorf("eip712 field %s", tp)
}

return fields.Resolve(v, tp)
}, func(ret map[string][]*TypeField, v interface{}, fieldType string) error {
strippedType, slices := stripSlice(tp)
encoder, ok := eip712EncoderMap[strippedType]
Expand Down Expand Up @@ -480,17 +658,21 @@ func RegisterTypeDefinition(tf *[]*TypeField, aliases ...string) {
}

func Eip712Hash(v map[string]interface{}, typeName string, td *TypeDefinition) ([]byte, error) {
messageHash, err := td.hash(v, typeName)
r, err := td.Resolve(v, typeName)
if err != nil {
return nil, err
}
messageHash, err := r.Hash()
if err != nil {
return nil, err
}
return keccak256(append(EIP712DomainHash, messageHash...)), nil
}

func Eip712Types(v map[string]any, typeName string, td *TypeDefinition) (map[string][]*TypeField, error) {
ret := map[string][]*TypeField{}
err := td.types(ret, v, typeName)
return ret, err
var buf bytes.Buffer
buf.WriteByte(0x19)
buf.WriteByte(0x01)
buf.Write(EIP712DomainHash)
buf.Write(messageHash)
return keccak256(buf.Bytes()), nil
}

func Eip712DomainType() *TypeDefinition {
Expand Down
20 changes: 0 additions & 20 deletions protocol/signature.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,26 +126,6 @@ func PublicKeyHash(key []byte, typ SignatureType) ([]byte, error) {
}
}

// generates privatekey and compressed public key
func SECP256K1Keypair() (privKey []byte, pubKey []byte) {
priv, _ := btc.NewPrivateKey(btc.S256())

privKey = priv.Serialize()
_, pub := btc.PrivKeyFromBytes(btc.S256(), privKey)
pubKey = pub.SerializeCompressed()
return privKey, pubKey
}

// generates privatekey and Un-compressed public key
func SECP256K1UncompressedKeypair() (privKey []byte, pubKey []byte) {
priv, _ := btc.NewPrivateKey(btc.S256())

privKey = priv.Serialize()
_, pub := btc.PrivKeyFromBytes(btc.S256(), privKey)
pubKey = pub.SerializeUncompressed()
return privKey, pubKey
}

func BTCHash(pubKey []byte) []byte {
hasher := ripemd160.New()
hash := sha256.Sum256(pubKey[:])
Expand Down
16 changes: 7 additions & 9 deletions protocol/signature_eip712.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ func MarshalEip712(txn *Transaction, sig Signature) (ret []byte, err error) {
if err != nil {
return nil, err
}
r, err := NewEip712TransactionDefinition(txn).Resolve(jtx, "Transaction")
if err != nil {
return nil, err
}

// Construct the wallet RPC call
type eip712 struct {
Expand All @@ -56,16 +60,8 @@ func MarshalEip712(txn *Transaction, sig Signature) (ret []byte, err error) {
e := eip712{}
e.PrimaryType = "Transaction"
e.Domain = encoding.Eip712Domain

// Reformat the message JSON to be compatible with Ethereum
td := NewEip712TransactionDefinition(txn)
formatEIP712Message(jtx, td)
e.Message = jtx

e.Types, err = encoding.Eip712Types(jtx, "Transaction", td)
if err != nil {
return nil, err
}
e.Types = r.Types()
e.Types["EIP712Domain"] = *encoding.Eip712DomainType().Fields

return json.Marshal(e)
Expand Down Expand Up @@ -131,6 +127,8 @@ func makeEIP712Message(txn *Transaction, sig Signature) (map[string]any, error)
}
jtx["signature"] = jsig

// Reformat the message JSON to be compatible with Ethereum
formatEIP712Message(jtx, NewEip712TransactionDefinition(txn))
return jtx, nil
}

Expand Down
Loading

0 comments on commit 56e7894

Please sign in to comment.