diff --git a/internal/generator/backend/template/gkr/gkr.go.tmpl b/internal/generator/backend/template/gkr/gkr.go.tmpl index 3e3881d15f..16d5eb970b 100644 --- a/internal/generator/backend/template/gkr/gkr.go.tmpl +++ b/internal/generator/backend/template/gkr/gkr.go.tmpl @@ -771,13 +771,13 @@ func (gateAPI) Println(a ...frontend.Variable) { for i, v := range a { if _, err := x.SetInterface(v); err != nil { - toPrint[i] = x.String() - } else { if s, ok := v.(string); ok { toPrint[i] = s continue } panic(fmt.Errorf("not numeric or string: %w", err)) + } else { + toPrint[i] = x.String() } } fmt.Println(toPrint...) diff --git a/internal/gkr/bls12-377/gkr.go b/internal/gkr/bls12-377/gkr.go index b92ac1249d..b8ef9ea973 100644 --- a/internal/gkr/bls12-377/gkr.go +++ b/internal/gkr/bls12-377/gkr.go @@ -775,13 +775,13 @@ func (gateAPI) Println(a ...frontend.Variable) { for i, v := range a { if _, err := x.SetInterface(v); err != nil { - toPrint[i] = x.String() - } else { if s, ok := v.(string); ok { toPrint[i] = s continue } panic(fmt.Errorf("not numeric or string: %w", err)) + } else { + toPrint[i] = x.String() } } fmt.Println(toPrint...) diff --git a/internal/gkr/bls12-381/gkr.go b/internal/gkr/bls12-381/gkr.go index 82084049d9..8f72898737 100644 --- a/internal/gkr/bls12-381/gkr.go +++ b/internal/gkr/bls12-381/gkr.go @@ -775,13 +775,13 @@ func (gateAPI) Println(a ...frontend.Variable) { for i, v := range a { if _, err := x.SetInterface(v); err != nil { - toPrint[i] = x.String() - } else { if s, ok := v.(string); ok { toPrint[i] = s continue } panic(fmt.Errorf("not numeric or string: %w", err)) + } else { + toPrint[i] = x.String() } } fmt.Println(toPrint...) diff --git a/internal/gkr/bls24-315/gkr.go b/internal/gkr/bls24-315/gkr.go index f182c9176b..7aee277ba4 100644 --- a/internal/gkr/bls24-315/gkr.go +++ b/internal/gkr/bls24-315/gkr.go @@ -775,13 +775,13 @@ func (gateAPI) Println(a ...frontend.Variable) { for i, v := range a { if _, err := x.SetInterface(v); err != nil { - toPrint[i] = x.String() - } else { if s, ok := v.(string); ok { toPrint[i] = s continue } panic(fmt.Errorf("not numeric or string: %w", err)) + } else { + toPrint[i] = x.String() } } fmt.Println(toPrint...) diff --git a/internal/gkr/bls24-317/gkr.go b/internal/gkr/bls24-317/gkr.go index a284f14ae9..7c679216fc 100644 --- a/internal/gkr/bls24-317/gkr.go +++ b/internal/gkr/bls24-317/gkr.go @@ -775,13 +775,13 @@ func (gateAPI) Println(a ...frontend.Variable) { for i, v := range a { if _, err := x.SetInterface(v); err != nil { - toPrint[i] = x.String() - } else { if s, ok := v.(string); ok { toPrint[i] = s continue } panic(fmt.Errorf("not numeric or string: %w", err)) + } else { + toPrint[i] = x.String() } } fmt.Println(toPrint...) diff --git a/internal/gkr/bn254/gkr.go b/internal/gkr/bn254/gkr.go index 14269151b3..0174caa564 100644 --- a/internal/gkr/bn254/gkr.go +++ b/internal/gkr/bn254/gkr.go @@ -775,13 +775,13 @@ func (gateAPI) Println(a ...frontend.Variable) { for i, v := range a { if _, err := x.SetInterface(v); err != nil { - toPrint[i] = x.String() - } else { if s, ok := v.(string); ok { toPrint[i] = s continue } panic(fmt.Errorf("not numeric or string: %w", err)) + } else { + toPrint[i] = x.String() } } fmt.Println(toPrint...) diff --git a/internal/gkr/bw6-633/gkr.go b/internal/gkr/bw6-633/gkr.go index ec1067f736..2c2bda2037 100644 --- a/internal/gkr/bw6-633/gkr.go +++ b/internal/gkr/bw6-633/gkr.go @@ -775,13 +775,13 @@ func (gateAPI) Println(a ...frontend.Variable) { for i, v := range a { if _, err := x.SetInterface(v); err != nil { - toPrint[i] = x.String() - } else { if s, ok := v.(string); ok { toPrint[i] = s continue } panic(fmt.Errorf("not numeric or string: %w", err)) + } else { + toPrint[i] = x.String() } } fmt.Println(toPrint...) diff --git a/internal/gkr/bw6-761/gkr.go b/internal/gkr/bw6-761/gkr.go index ad5197feef..099b015b02 100644 --- a/internal/gkr/bw6-761/gkr.go +++ b/internal/gkr/bw6-761/gkr.go @@ -775,13 +775,13 @@ func (gateAPI) Println(a ...frontend.Variable) { for i, v := range a { if _, err := x.SetInterface(v); err != nil { - toPrint[i] = x.String() - } else { if s, ok := v.(string); ok { toPrint[i] = s continue } panic(fmt.Errorf("not numeric or string: %w", err)) + } else { + toPrint[i] = x.String() } } fmt.Println(toPrint...) diff --git a/internal/gkr/engine_hints.go b/internal/gkr/engine_hints.go index 74b15c77ba..8c8bc1b797 100644 --- a/internal/gkr/engine_hints.go +++ b/internal/gkr/engine_hints.go @@ -187,7 +187,7 @@ func (g gateAPI) Println(a ...frontend.Variable) { for i := range a { if s, ok := a[i].(fmt.Stringer); ok { strings[i] = s.String() - } else { + } else if strings[i], ok = a[i].(string); !ok { bigInt := utils.FromInterface(a[i]) strings[i] = bigInt.String() } diff --git a/internal/gkr/small_rational/gkr.go b/internal/gkr/small_rational/gkr.go index cdf62359f2..d085c6305f 100644 --- a/internal/gkr/small_rational/gkr.go +++ b/internal/gkr/small_rational/gkr.go @@ -775,13 +775,13 @@ func (gateAPI) Println(a ...frontend.Variable) { for i, v := range a { if _, err := x.SetInterface(v); err != nil { - toPrint[i] = x.String() - } else { if s, ok := v.(string); ok { toPrint[i] = s continue } panic(fmt.Errorf("not numeric or string: %w", err)) + } else { + toPrint[i] = x.String() } } fmt.Println(toPrint...) diff --git a/std/hash/hash.go b/std/hash/hash.go index c077fd0d37..564f122a48 100644 --- a/std/hash/hash.go +++ b/std/hash/hash.go @@ -5,6 +5,8 @@ package hash import ( + "fmt" + "github.com/consensys/gnark/frontend" "github.com/consensys/gnark/std/math/uints" ) @@ -110,7 +112,7 @@ type merkleDamgardHasher struct { // NewMerkleDamgardHasher transforms a 2-1 one-way function into a hash // initialState is a value whose preimage is not known -func NewMerkleDamgardHasher(api frontend.API, f Compressor, initialState frontend.Variable) FieldHasher { +func NewMerkleDamgardHasher(api frontend.API, f Compressor, initialState frontend.Variable) StateStorer { return &merkleDamgardHasher{ state: initialState, iv: initialState, @@ -132,3 +134,18 @@ func (h *merkleDamgardHasher) Write(data ...frontend.Variable) { func (h *merkleDamgardHasher) Sum() frontend.Variable { return h.state } + +func (h *merkleDamgardHasher) State() []frontend.Variable { + return []frontend.Variable{h.state} +} + +func (h *merkleDamgardHasher) SetState(state []frontend.Variable) error { + if h.state != h.iv { + return fmt.Errorf("the hasher is not in an initial state; reset before attempting to set the state") + } + if len(state) != 1 { + return fmt.Errorf("expected one state variable, got %d", len(state)) + } + h.state = state[0] + return nil +} diff --git a/std/hash/mimc/gkr-mimc/gkr-mimc.go b/std/hash/mimc/gkr-mimc/gkr-mimc.go new file mode 100644 index 0000000000..8e6a8766d8 --- /dev/null +++ b/std/hash/mimc/gkr-mimc/gkr-mimc.go @@ -0,0 +1,17 @@ +package gkr_mimc + +import ( + "fmt" + + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/hash" + gkr_mimc "github.com/consensys/gnark/std/permutation/gkr-mimc" +) + +func New(api frontend.API) (hash.StateStorer, error) { + f, err := gkr_mimc.NewCompressor(api) + if err != nil { + return nil, fmt.Errorf("could not create mimc hasher: %w", err) + } + return hash.NewMerkleDamgardHasher(api, f, 0), nil +} diff --git a/std/hash/mimc/gkr-mimc/gkr-mimc_test.go b/std/hash/mimc/gkr-mimc/gkr-mimc_test.go new file mode 100644 index 0000000000..8a9381bfd7 --- /dev/null +++ b/std/hash/mimc/gkr-mimc/gkr-mimc_test.go @@ -0,0 +1,165 @@ +package gkr_mimc + +import ( + "errors" + "fmt" + "os" + "slices" + "testing" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/backend/plonk" + "github.com/consensys/gnark/constraint" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/scs" + "github.com/consensys/gnark/std/hash/mimc" + "github.com/consensys/gnark/test" + "github.com/stretchr/testify/require" +) + +func TestGkrMiMC(t *testing.T) { + lengths := []int{1, 2, 3} + vals := make([]frontend.Variable, len(lengths)*2) + for i := range vals { + vals[i] = i + 1 + } + + for _, length := range lengths[1:2] { + circuit := &testGkrMiMCCircuit{ + In: make([]frontend.Variable, length*2), + } + assignment := &testGkrMiMCCircuit{ + In: slices.Clone(vals[:length*2]), + } + + test.NewAssert(t).CheckCircuit(circuit, test.WithValidAssignment(assignment)) + } +} + +type testGkrMiMCCircuit struct { + In []frontend.Variable + skipCheck bool +} + +func (c *testGkrMiMCCircuit) Define(api frontend.API) error { + gkrmimc, err := New(api) + if err != nil { + return err + } + + plainMiMC, err := mimc.New(api) + if err != nil { + return err + } + + // first check that empty input is handled correctly + api.AssertIsEqual(gkrmimc.Sum(), plainMiMC.Sum()) + + ins := [][]frontend.Variable{c.In[:len(c.In)/2], c.In[len(c.In)/2:]} + for _, in := range ins { + gkrmimc.Reset() + gkrmimc.Write(in...) + res := gkrmimc.Sum() + + if !c.skipCheck { + plainMiMC.Reset() + plainMiMC.Write(in...) + expected := plainMiMC.Sum() + api.AssertIsEqual(expected, res) + } + } + + return nil +} + +func TestGkrMiMCCompiles(t *testing.T) { + const n = 52000 + circuit := testGkrMiMCCircuit{ + In: make([]frontend.Variable, n), + } + cs, err := frontend.Compile(ecc.BLS12_377.ScalarField(), scs.NewBuilder, &circuit, frontend.WithCapacity(27_000_000)) + require.NoError(t, err) + fmt.Println(cs.GetNbConstraints(), "constraints") +} + +type hashTreeCircuit struct { + Leaves []frontend.Variable +} + +func (c hashTreeCircuit) Define(api frontend.API) error { + if len(c.Leaves) == 0 { + return errors.New("no hashing to do") + } + + hsh, err := New(api) + if err != nil { + return err + } + + layer := slices.Clone(c.Leaves) + + for len(layer) > 1 { + if len(layer)%2 == 1 { + layer = append(layer, 0) // pad with zero + } + + for i := range len(layer) / 2 { + hsh.Reset() + hsh.Write(layer[2*i], layer[2*i+1]) + layer[i] = hsh.Sum() + } + + layer = layer[:len(layer)/2] + } + + api.AssertIsDifferent(layer[0], 0) + return nil +} + +func loadCs(t require.TestingT, filename string, circuit frontend.Circuit) constraint.ConstraintSystem { + f, err := os.Open(filename) + + if os.IsNotExist(err) { + // actually compile + cs, err := frontend.Compile(ecc.BLS12_377.ScalarField(), scs.NewBuilder, circuit) + require.NoError(t, err) + f, err = os.Create(filename) + require.NoError(t, err) + defer f.Close() + _, err = cs.WriteTo(f) + require.NoError(t, err) + return cs + } + + defer f.Close() + require.NoError(t, err) + + cs := plonk.NewCS(ecc.BLS12_377) + + _, err = cs.ReadFrom(f) + require.NoError(t, err) + + return cs +} + +func BenchmarkHashTree(b *testing.B) { + const size = 1 << 15 // about 2 ^ 16 total hashes + + circuit := hashTreeCircuit{ + Leaves: make([]frontend.Variable, size), + } + assignment := hashTreeCircuit{ + Leaves: make([]frontend.Variable, size), + } + + for i := range assignment.Leaves { + assignment.Leaves[i] = i + } + + cs := loadCs(b, "gkrmimc_hashtree.cs", &circuit) + + w, err := frontend.NewWitness(&assignment, ecc.BLS12_377.ScalarField()) + require.NoError(b, err) + + require.NoError(b, cs.IsSolved(w)) +} diff --git a/std/hash/mimc/mimc.go b/std/hash/mimc/mimc.go index 9d8a98e306..db72f54429 100644 --- a/std/hash/mimc/mimc.go +++ b/std/hash/mimc/mimc.go @@ -34,7 +34,7 @@ func NewMiMC(api frontend.API) (MiMC, error) { // NB! See the package documentation for length extension attack consideration. // // [gnark-crypto]: https://pkg.go.dev/github.com/consensys/gnark-crypto/hash -func New(api frontend.API) (hash.FieldHasher, error) { +func New(api frontend.API) (hash.StateStorer, error) { h, err := NewMiMC(api) if err != nil { return nil, err @@ -43,5 +43,7 @@ func New(api frontend.API) (hash.FieldHasher, error) { } func init() { - hash.Register(hash.MIMC, New) + hash.Register(hash.MIMC, func(api frontend.API) (hash.FieldHasher, error) { + return New(api) + }) } diff --git a/std/hash/poseidon2/gkr-poseidon2/gkr-poseidon2.go b/std/hash/poseidon2/gkr-poseidon2/gkr-poseidon2.go new file mode 100644 index 0000000000..bbbef1f87c --- /dev/null +++ b/std/hash/poseidon2/gkr-poseidon2/gkr-poseidon2.go @@ -0,0 +1,18 @@ +package gkr_poseidon2 + +import ( + "fmt" + + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/hash" + _ "github.com/consensys/gnark/std/hash/all" + gkr_poseidon2 "github.com/consensys/gnark/std/permutation/poseidon2/gkr-poseidon2" +) + +func New(api frontend.API) (hash.StateStorer, error) { + f, err := gkr_poseidon2.NewCompressor(api) + if err != nil { + return nil, fmt.Errorf("could not create poseidon2 hasher: %w", err) + } + return hash.NewMerkleDamgardHasher(api, f, 0), nil +} diff --git a/std/hash/poseidon2/poseidon2.go b/std/hash/poseidon2/poseidon2.go index 804740ff7c..e15c4ca587 100644 --- a/std/hash/poseidon2/poseidon2.go +++ b/std/hash/poseidon2/poseidon2.go @@ -8,9 +8,9 @@ import ( "github.com/consensys/gnark/std/permutation/poseidon2" ) -// NewMerkleDamgardHasher returns a Poseidon2 hasher using the Merkle-Damgard +// New returns a Poseidon2 hasher using the Merkle-Damgard // construction with the default parameters. -func NewMerkleDamgardHasher(api frontend.API) (hash.FieldHasher, error) { +func New(api frontend.API) (hash.StateStorer, error) { f, err := poseidon2.NewPoseidon2(api) if err != nil { return nil, fmt.Errorf("could not create poseidon2 hasher: %w", err) @@ -19,5 +19,7 @@ func NewMerkleDamgardHasher(api frontend.API) (hash.FieldHasher, error) { } func init() { - hash.Register(hash.POSEIDON2, NewMerkleDamgardHasher) + hash.Register(hash.POSEIDON2, func(api frontend.API) (hash.FieldHasher, error) { + return New(api) + }) } diff --git a/std/hash/poseidon2/poseidon2_test.go b/std/hash/poseidon2/poseidon2_test.go index 1ce1d46fef..f6c57736df 100644 --- a/std/hash/poseidon2/poseidon2_test.go +++ b/std/hash/poseidon2/poseidon2_test.go @@ -1,26 +1,34 @@ -package poseidon2 +package poseidon2_test import ( "testing" "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/poseidon2" + poseidonbls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/poseidon2" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/std/hash/poseidon2" + gkr_poseidon2 "github.com/consensys/gnark/std/hash/poseidon2/gkr-poseidon2" "github.com/consensys/gnark/test" ) -type Poseidon2Circuit struct { +type poseidon2Circuit struct { Input []frontend.Variable Expected frontend.Variable `gnark:",public"` } -func (c *Poseidon2Circuit) Define(api frontend.API) error { - hsh, err := NewMerkleDamgardHasher(api) +func (c *poseidon2Circuit) Define(api frontend.API) error { + hsh, err := poseidon2.New(api) + if err != nil { + return err + } + gkr, err := gkr_poseidon2.New(api) if err != nil { return err } hsh.Write(c.Input...) api.AssertIsEqual(hsh.Sum(), c.Expected) + gkr.Write(c.Input...) + api.AssertIsEqual(gkr.Sum(), c.Expected) return nil } @@ -29,7 +37,7 @@ func TestPoseidon2Hash(t *testing.T) { const nbInputs = 5 // prepare expected output - h := poseidon2.NewMerkleDamgardHasher() + h := poseidonbls12377.NewMerkleDamgardHasher() circInput := make([]frontend.Variable, nbInputs) for i := range nbInputs { _, err := h.Write([]byte{byte(i)}) @@ -37,5 +45,59 @@ func TestPoseidon2Hash(t *testing.T) { circInput[i] = i } res := h.Sum(nil) - assert.CheckCircuit(&Poseidon2Circuit{Input: make([]frontend.Variable, nbInputs)}, test.WithValidAssignment(&Poseidon2Circuit{Input: circInput, Expected: res}), test.WithCurves(ecc.BLS12_377)) // we have parametrized currently only for BLS12-377 + assert.CheckCircuit(&poseidon2Circuit{Input: make([]frontend.Variable, nbInputs)}, test.WithValidAssignment(&poseidon2Circuit{Input: circInput, Expected: res}), test.WithCurves(ecc.BLS12_377)) // we have parametrized currently only for BLS12-377 +} + +func TestStateStorer(t *testing.T) { + assignment := testStateStorerCircuit{ + Input: [][]frontend.Variable{ + {0, 1, 2, 3, 4}, + }, + } + + circuit := testStateStorerCircuit{ + Input: make([][]frontend.Variable, len(assignment.Input)), + } + for i := range assignment.Input { + circuit.Input[i] = make([]frontend.Variable, len(assignment.Input[i])) + } + + test.NewAssert(t).CheckCircuit(&circuit, test.WithValidAssignment(&assignment)) +} + +type testStateStorerCircuit struct { + Input [][]frontend.Variable +} + +func (c *testStateStorerCircuit) Define(api frontend.API) error { + // hashes the whole input in one go + hshFull, err := poseidon2.New(api) + if err != nil { + return err + } + + // hashes the input in two parts + hshPartial, err := poseidon2.New(api) + if err != nil { + return err + } + + for _, input := range c.Input { + // compute desired output + hshFull.Reset() + hshFull.Write(input...) + digest := hshFull.Sum() + + hshPartial.Reset() + hshPartial.Write(input[:len(input)/2]...) + state := hshPartial.State() + hshPartial.Reset() + api.AssertIsEqual(hshPartial.State()[0], 0) + if err = hshPartial.SetState(state); err != nil { + return err + } + hshPartial.Write(input[len(input)/2:]...) + api.AssertIsEqual(hshPartial.Sum(), digest) + } + return nil } diff --git a/std/internal/mimc/encrypt.go b/std/internal/mimc/encrypt.go index 0d45a81506..9c499be976 100644 --- a/std/internal/mimc/encrypt.go +++ b/std/internal/mimc/encrypt.go @@ -106,7 +106,7 @@ func newMimcBW633(api frontend.API) MiMC { } // ------------------------------------------------------------------------------------------------- -// encryptions functions +// encryption functions func pow5(api frontend.API, x frontend.Variable) frontend.Variable { r := api.Mul(x, x) diff --git a/std/permutation/gkr-mimc/gkr-mimc.go b/std/permutation/gkr-mimc/gkr-mimc.go new file mode 100644 index 0000000000..266ee00e67 --- /dev/null +++ b/std/permutation/gkr-mimc/gkr-mimc.go @@ -0,0 +1,235 @@ +package gkr_mimc + +import ( + "fmt" + "math/big" + + "github.com/consensys/gnark-crypto/ecc" + bls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/mimc" + bls12381 "github.com/consensys/gnark-crypto/ecc/bls12-381/fr/mimc" + bls24315 "github.com/consensys/gnark-crypto/ecc/bls24-315/fr/mimc" + bls24317 "github.com/consensys/gnark-crypto/ecc/bls24-317/fr/mimc" + bn254 "github.com/consensys/gnark-crypto/ecc/bn254/fr/mimc" + bw6633 "github.com/consensys/gnark-crypto/ecc/bw6-633/fr/mimc" + bw6761 "github.com/consensys/gnark-crypto/ecc/bw6-761/fr/mimc" + "github.com/consensys/gnark/constraint/solver/gkrgates" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/internal/kvstore" + "github.com/consensys/gnark/internal/utils" + "github.com/consensys/gnark/std/gkrapi" + "github.com/consensys/gnark/std/gkrapi/gkr" + "github.com/consensys/gnark/std/hash" + _ "github.com/consensys/gnark/std/hash/all" +) + +// compressor implements a compression function by applying +// the Miyaguchi–Preneel transformation to the MiMC encryption function. +type compressor struct { + gkrCircuit *gkrapi.Circuit + in0, in1, out gkr.Variable +} + +func (c *compressor) Compress(x frontend.Variable, y frontend.Variable) frontend.Variable { + res, err := c.gkrCircuit.AddInstance(map[gkr.Variable]frontend.Variable{c.in0: x, c.in1: y}) + if err != nil { + panic(err) + } + return res[c.out] +} + +func NewCompressor(api frontend.API) (hash.Compressor, error) { + + store, ok := api.(kvstore.Store) + if !ok { + return nil, fmt.Errorf("api of type %T does not implement kvstore.Store", api) + } + + cached := store.GetKeyValue(gkrMiMCKey{}) + if cached != nil { + if compressor, ok := cached.(*compressor); ok { + return compressor, nil + } + return nil, fmt.Errorf("cached value is of type %T, not a compressor", cached) + } + + gkrApi := gkrapi.New() + + in0 := gkrApi.NewInput() + in1 := gkrApi.NewInput() + + y := in1 + + curve := utils.FieldToCurve(api.Compiler().Field()) + params, _, err := getParams(curve) // params is only used for its length + if err != nil { + return nil, err + } + if err = RegisterGates(curve); err != nil { + return nil, err + } + gateNamer := newGateNamer(curve) + + for i := range len(params) - 1 { + y = gkrApi.NamedGate(gateNamer.round(i), in0, y) + } + + y = gkrApi.NamedGate(gateNamer.round(len(params)-1), in0, y, in1) + + res := + &compressor{ + gkrCircuit: gkrApi.Compile(api, "POSEIDON2"), + in0: in0, + in1: in1, + out: y, + } + + store.SetKeyValue(gkrMiMCKey{}, res) + return res, nil +} + +func RegisterGates(curves ...ecc.ID) error { + for _, curve := range curves { + constants, deg, err := getParams(curve) + if err != nil { + return err + } + gateNamer := newGateNamer(curve) + var lastLayerSBox, nonLastLayerSBox func(*big.Int) gkr.GateFunction + switch deg { + case 5: + lastLayerSBox = addPow5Add + nonLastLayerSBox = addPow5 + case 7: + lastLayerSBox = addPow7Add + nonLastLayerSBox = addPow7 + case 17: + lastLayerSBox = addPow17Add + nonLastLayerSBox = addPow17 + default: + return fmt.Errorf("s-Box of degree %d not supported", deg) + } + + for i := range len(constants) - 1 { + if _, err = gkrgates.Register(nonLastLayerSBox(&constants[i]), 2, gkrgates.WithName(gateNamer.round(i)), gkrgates.WithUnverifiedDegree(deg), gkrgates.WithCurves(curve)); err != nil { + return fmt.Errorf("failed to register keyed GKR gate for round %d of MiMC on curve %s: %w", i, curve, err) + } + } + + if _, err = gkrgates.Register(lastLayerSBox(&constants[len(constants)-1]), 3, gkrgates.WithName(gateNamer.round(len(constants)-1)), gkrgates.WithUnverifiedDegree(deg), gkrgates.WithCurves(curve)); err != nil { + return fmt.Errorf("failed to register keyed GKR gate for round %d of MiMC on curve %s: %w", len(constants)-1, curve, err) + } + } + return nil +} + +// getParams returns the parameters for the MiMC encryption function for the given curve. +// It also returns the degree of the s-Box +func getParams(curve ecc.ID) ([]big.Int, int, error) { + switch curve { + case ecc.BN254: + return bn254.GetConstants(), 5, nil + case ecc.BLS12_381: + return bls12381.GetConstants(), 5, nil + case ecc.BLS12_377: + return bls12377.GetConstants(), 17, nil + case ecc.BLS24_315: + return bls24315.GetConstants(), 5, nil + case ecc.BLS24_317: + return bls24317.GetConstants(), 7, nil + case ecc.BW6_633: + return bw6633.GetConstants(), 5, nil + case ecc.BW6_761: + return bw6761.GetConstants(), 5, nil + default: + return nil, -1, fmt.Errorf("unsupported curve ID: %s", curve) + } +} + +type gateNamer string + +func newGateNamer(o fmt.Stringer) gateNamer { + return gateNamer("MiMC-" + o.String() + "-round-") +} +func (n gateNamer) round(i int) gkr.GateName { + return gkr.GateName(fmt.Sprintf("%s%d", string(n), i)) +} + +func addPow5(key *big.Int) gkr.GateFunction { + return func(api gkr.GateAPI, in ...frontend.Variable) frontend.Variable { + if len(in) != 2 { + panic("expected two input") + } + s := api.Add(in[0], in[1], key) + t := api.Mul(s, s) + return api.Mul(t, t, s) + } +} + +// addPow5Add: (in[0]+in[1]+key)⁵ + 2*in[0] + in[2] +func addPow5Add(key *big.Int) gkr.GateFunction { + return func(api gkr.GateAPI, in ...frontend.Variable) frontend.Variable { + if len(in) != 3 { + panic("expected three input") + } + s := api.Add(in[0], in[1], key) + t := api.Mul(s, s) + t = api.Mul(t, t, s) + + return api.Add(t, in[0], in[0], in[2]) + } +} + +func addPow7(key *big.Int) gkr.GateFunction { + return func(api gkr.GateAPI, in ...frontend.Variable) frontend.Variable { + if len(in) != 2 { + panic("expected two input") + } + s := api.Add(in[0], in[1], key) + t := api.Mul(s, s) + return api.Mul(t, t, t, s) // s⁶ × s + } +} + +// addPow7Add: (in[0]+in[1]+key)⁷ + 2*in[0] + in[2] +func addPow7Add(key *big.Int) gkr.GateFunction { + return func(api gkr.GateAPI, in ...frontend.Variable) frontend.Variable { + if len(in) != 3 { + panic("expected three input") + } + s := api.Add(in[0], in[1], key) + t := api.Mul(s, s) + return api.Add(api.Mul(t, t, t, s), in[0], in[0], in[2]) // s⁶ × s + 2*in[0] + in[2] + } +} + +// addPow17: (in[0]+in[1]+key)¹⁷ +func addPow17(key *big.Int) gkr.GateFunction { + return func(api gkr.GateAPI, in ...frontend.Variable) frontend.Variable { + if len(in) != 2 { + panic("expected two input") + } + s := api.Add(in[0], in[1], key) + t := api.Mul(s, s) // s² + t = api.Mul(t, t) // s⁴ + t = api.Mul(t, t) // s⁸ + t = api.Mul(t, t) // s¹⁶ + return api.Mul(t, s) // s¹⁶ × s + } +} + +// addPow17Add: (in[0]+in[1]+key)¹⁷ + in[0] + in[2] +func addPow17Add(key *big.Int) gkr.GateFunction { + return func(api gkr.GateAPI, in ...frontend.Variable) frontend.Variable { + if len(in) != 3 { + panic("expected three input") + } + s := api.Add(in[0], in[1], key) + t := api.Mul(s, s) // s² + t = api.Mul(t, t) // s⁴ + t = api.Mul(t, t) // s⁸ + t = api.Mul(t, t) // s¹⁶ + return api.Add(api.Mul(t, s), in[0], in[0], in[2]) // s¹⁶ × s + 2*in[0] + in[2] + } +} + +type gkrMiMCKey struct{} diff --git a/std/permutation/gkr-mimc/gkr-mimc_test.go b/std/permutation/gkr-mimc/gkr-mimc_test.go new file mode 100644 index 0000000000..93143b1279 --- /dev/null +++ b/std/permutation/gkr-mimc/gkr-mimc_test.go @@ -0,0 +1,70 @@ +package gkr_mimc + +import ( + "errors" + "slices" + "testing" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/scs" + "github.com/stretchr/testify/require" +) + +type hashTreeCircuit struct { + Leaves []frontend.Variable +} + +func (c hashTreeCircuit) Define(api frontend.API) error { + if len(c.Leaves) == 0 { + return errors.New("no hashing to do") + } + + hsh, err := NewCompressor(api) + if err != nil { + return err + } + + layer := slices.Clone(c.Leaves) + + for len(layer) > 1 { + if len(layer)%2 == 1 { + layer = append(layer, 0) // pad with zero + } + + for i := range len(layer) / 2 { + layer[i] = hsh.Compress(layer[2*i], layer[2*i+1]) + } + + layer = layer[:len(layer)/2] + } + + api.AssertIsDifferent(layer[0], 0) + return nil +} + +func BenchmarkGkrPermutations(b *testing.B) { + circuit, assignment := hashTreeCircuits(50000) + + cs, err := frontend.Compile(ecc.BLS12_377.ScalarField(), scs.NewBuilder, &circuit) + require.NoError(b, err) + + witness, err := frontend.NewWitness(&assignment, ecc.BLS12_377.ScalarField()) + require.NoError(b, err) + + _, err = cs.Solve(witness) + require.NoError(b, err) +} + +func hashTreeCircuits(n int) (circuit, assignment hashTreeCircuit) { + leaves := make([]frontend.Variable, n) + for i := range n { + leaves[i] = i + } + + return hashTreeCircuit{ + Leaves: make([]frontend.Variable, len(leaves)), + }, hashTreeCircuit{ + Leaves: leaves, + } +} diff --git a/std/permutation/poseidon2/gkr-poseidon2/gkr.go b/std/permutation/poseidon2/gkr-poseidon2/gkr-poseidon2.go similarity index 68% rename from std/permutation/poseidon2/gkr-poseidon2/gkr.go rename to std/permutation/poseidon2/gkr-poseidon2/gkr-poseidon2.go index 800e79ed05..4bb0b7a6a7 100644 --- a/std/permutation/poseidon2/gkr-poseidon2/gkr.go +++ b/std/permutation/poseidon2/gkr-poseidon2/gkr-poseidon2.go @@ -1,15 +1,18 @@ package gkr_poseidon2 import ( + "errors" "fmt" - "sync" "github.com/consensys/gnark/constraint/solver/gkrgates" + "github.com/consensys/gnark/internal/kvstore" + "github.com/consensys/gnark/internal/utils" "github.com/consensys/gnark/std/gkrapi" "github.com/consensys/gnark/std/gkrapi/gkr" + "github.com/consensys/gnark/std/hash" + "github.com/consensys/gnark/std/permutation/poseidon2" "github.com/consensys/gnark-crypto/ecc" - poseidon2Bls12377 "github.com/consensys/gnark-crypto/ecc/bls12-377/fr/poseidon2" "github.com/consensys/gnark/frontend" ) @@ -46,6 +49,14 @@ func pow4TimesGate(api gkr.GateAPI, x ...frontend.Variable) frontend.Variable { return api.Mul(y, x[1]) } +// pow3Gate computes a -> a³ +func pow3Gate(api gkr.GateAPI, x ...frontend.Variable) frontend.Variable { + if len(x) != 1 { + panic("expected 1 input") + } + return api.Mul(x[0], x[0], x[0]) +} + // pow2Gate computes a -> a² func pow2Gate(api gkr.GateAPI, x ...frontend.Variable) frontend.Variable { if len(x) != 1 { @@ -108,34 +119,46 @@ func extAddGate(api gkr.GateAPI, x ...frontend.Variable) frontend.Variable { return api.Add(api.Mul(x[0], 2), x[1], x[2]) } -type GkrCompressor struct { +type compressor struct { api frontend.API gkrCircuit *gkrapi.Circuit in1, in2, out gkr.Variable } -// NewGkrCompressor returns an object that can compute the Poseidon2 compression function (currently only for BLS12-377) +// NewCompressor returns an object that can compute the Poseidon2 compression function (currently only for BLS12-377) // which consists of a permutation along with the input fed forward. // The correctness of the compression functions is proven using GKR. -// Note that the solver will need the function RegisterGkrGates to be called with the desired curves -func NewGkrCompressor(api frontend.API) *GkrCompressor { - if api.Compiler().Field().Cmp(ecc.BLS12_377.ScalarField()) != 0 { - panic("currently only BL12-377 is supported") +// Note that the solver will need the function RegisterGates to be called with the desired curves +func NewCompressor(api frontend.API) (hash.Compressor, error) { + store, ok := api.(kvstore.Store) + if !ok { + return nil, fmt.Errorf("api of type %T does not implement kvstore.Store", api) } - gkrApi, in1, in2, out, err := defineCircuitBls12377() + + cached := store.GetKeyValue(gkrPoseidon2Key{}) + if cached != nil { + if compressor, ok := cached.(*compressor); ok { + return compressor, nil + } + return nil, fmt.Errorf("cached value is of type %T, not a mimcCompressor", cached) + } + + gkrCircuit, in1, in2, out, err := defineCircuit(api) if err != nil { - panic(fmt.Errorf("failed to define GKR circuit: %v", err)) + return nil, fmt.Errorf("failed to define GKR circuit: %w", err) } - return &GkrCompressor{ + res := &compressor{ api: api, - gkrCircuit: gkrApi.Compile(api, "MIMC"), + gkrCircuit: gkrCircuit, in1: in1, in2: in2, out: out, } + store.SetKeyValue(gkrPoseidon2Key{}, res) + return res, nil } -func (p *GkrCompressor) Compress(a, b frontend.Variable) frontend.Variable { +func (p *compressor) Compress(a, b frontend.Variable) frontend.Variable { outs, err := p.gkrCircuit.AddInstance(map[gkr.Variable]frontend.Variable{p.in1: a, p.in2: b}) if err != nil { panic(err) @@ -144,27 +167,28 @@ func (p *GkrCompressor) Compress(a, b frontend.Variable) frontend.Variable { return outs[p.out] } -// defineCircuitBls12377 defines the GKR circuit for the Poseidon2 permutation over BLS12-377 +// defineCircuit defines the GKR circuit for the Poseidon2 permutation over BLS12-377 // insLeft and insRight are the inputs to the permutation // they must be padded to a power of 2 -func defineCircuitBls12377() (gkrApi *gkrapi.API, in1, in2, out gkr.Variable, err error) { +func defineCircuit(api frontend.API) (gkrCircuit *gkrapi.Circuit, in1, in2, out gkr.Variable, err error) { // variable indexes const ( xI = iota yI ) - if err = registerGatesBls12377(); err != nil { + curve := utils.FieldToCurve(api.Compiler().Field()) + p, err := poseidon2.GetDefaultParameters(curve) + if err != nil { return } + gateNamer := newRoundGateNamer(&p, curve) - // poseidon2 parameters - gateNamer := newRoundGateNamer(poseidon2Bls12377.GetDefaultParameters()) - rF := poseidon2Bls12377.GetDefaultParameters().NbFullRounds - rP := poseidon2Bls12377.GetDefaultParameters().NbPartialRounds - halfRf := rF / 2 + if err = registerGates(&p, curve); err != nil { + return + } - gkrApi = gkrapi.New() + gkrApi := gkrapi.New() x := gkrApi.NewInput() y := gkrApi.NewInput() @@ -180,10 +204,27 @@ func defineCircuitBls12377() (gkrApi *gkrapi.API, in1, in2, out gkr.Variable, er // in every round comes from the previous (canonical) round. // apply the s-Box to u - // the s-Box gates: u¹⁷ = (u⁴)⁴ * u - sBox := func(u gkr.Variable) gkr.Variable { - v := gkrApi.Gate(pow4Gate, u) // u⁴ - return gkrApi.Gate(pow4TimesGate, v, u) // u¹⁷ + + var sBox func(gkr.Variable) gkr.Variable + switch p.DegreeSBox { + case 5: + sBox = func(u gkr.Variable) gkr.Variable { + v := gkrApi.Gate(pow2Gate, u) // u² + return gkrApi.Gate(pow2TimesGate, v, u) // u⁵ + } + case 7: + sBox = func(u gkr.Variable) gkr.Variable { + v := gkrApi.Gate(pow3Gate, u) // u³ + return gkrApi.Gate(pow2TimesGate, v, u) // u⁷ + } + case 17: + sBox = func(u gkr.Variable) gkr.Variable { + v := gkrApi.Gate(pow4Gate, u) // u⁴ + return gkrApi.Gate(pow4TimesGate, v, u) // u¹⁷ + } + default: + err = fmt.Errorf("unsupported s-Box degree %d", p.DegreeSBox) + return } // apply external matrix multiplication and round key addition @@ -208,97 +249,76 @@ func defineCircuitBls12377() (gkrApi *gkrapi.API, in1, in2, out gkr.Variable, er // *** construct the circuit *** - for i := range halfRf { + for i := range p.NbFullRounds / 2 { fullRound(i) } { // i = halfRf: first partial round // still using the external matrix, since the linear operation still belongs to a full (canonical) round - x1 := extKeySBox(halfRf, xI, x, y) + x1 := extKeySBox(p.NbFullRounds/2, xI, x, y) x, y = x1, gkrApi.Gate(extGate2, x, y) } - for i := halfRf + 1; i < halfRf+rP; i++ { + for i := p.NbFullRounds/2 + 1; i < p.NbFullRounds/2+p.NbPartialRounds; i++ { x1 := extKeySBox(i, xI, x, y) // the first row of the internal matrix is the same as that of the external matrix x, y = x1, gkrApi.Gate(intGate2, x, y) } { - i := halfRf + rP + i := p.NbFullRounds/2 + p.NbPartialRounds // first iteration of the final batch of full rounds // still using the internal matrix, since the linear operation still belongs to a partial (canonical) round x1 := extKeySBox(i, xI, x, y) x, y = x1, intKeySBox2(i, x, y) } - for i := halfRf + rP + 1; i < rP+rF; i++ { + for i := p.NbFullRounds/2 + p.NbPartialRounds + 1; i < p.NbPartialRounds+p.NbFullRounds; i++ { fullRound(i) } // apply the external matrix one last time to obtain the final value of y - out = gkrApi.NamedGate(gateNamer.linear(yI, rP+rF), y, x, in2) + out = gkrApi.Gate(extAddGate, y, x, in2) + + gkrCircuit = gkrApi.Compile(api, "MIMC") return } -var bls12377Permutation = sync.OnceValue(func() *poseidon2Bls12377.Permutation { - params := poseidon2Bls12377.GetDefaultParameters() - return poseidon2Bls12377.NewPermutation(2, params.NbFullRounds, params.NbPartialRounds) // TODO @Tabaie add NewDefaultPermutation to gnark-crypto -}) - -// RegisterGkrGates registers the GKR gates corresponding to the given curves for the solver -func RegisterGkrGates(curves ...ecc.ID) { +// RegisterGates registers the GKR gates corresponding to the given curves for the solver. +func RegisterGates(curves ...ecc.ID) error { if len(curves) == 0 { - panic("expected at least one curve") + return errors.New("expected at least one curve") } for _, curve := range curves { - switch curve { - case ecc.BLS12_377: - if err := registerGatesBls12377(); err != nil { - panic(err) - } - default: - panic(fmt.Sprintf("curve %s not currently supported", curve)) + p, err := poseidon2.GetDefaultParameters(curve) + if err != nil { + return fmt.Errorf("failed to get default parameters for curve %s: %w", curve, err) + } + if err = registerGates(&p, curve); err != nil { + return fmt.Errorf("failed to register gates for curve %s: %w", curve, err) } } + return nil } -func registerGatesBls12377() error { +func registerGates(p *poseidon2.Parameters, curve ecc.ID) error { const ( x = iota y ) - p := poseidon2Bls12377.GetDefaultParameters() + gateNames := newRoundGateNamer(p, curve) halfRf := p.NbFullRounds / 2 - gateNames := newRoundGateNamer(p) - - if _, err := gkrgates.Register(pow2Gate, 1, gkrgates.WithUnverifiedDegree(2), gkrgates.WithNoSolvableVar(), gkrgates.WithCurves(ecc.BLS12_377)); err != nil { - return err - } - if _, err := gkrgates.Register(pow4Gate, 1, gkrgates.WithUnverifiedDegree(4), gkrgates.WithNoSolvableVar(), gkrgates.WithCurves(ecc.BLS12_377)); err != nil { - return err - } - if _, err := gkrgates.Register(pow2TimesGate, 2, gkrgates.WithUnverifiedDegree(3), gkrgates.WithNoSolvableVar(), gkrgates.WithCurves(ecc.BLS12_377)); err != nil { - return err - } - if _, err := gkrgates.Register(pow4TimesGate, 2, gkrgates.WithUnverifiedDegree(5), gkrgates.WithNoSolvableVar(), gkrgates.WithCurves(ecc.BLS12_377)); err != nil { - return err - } - - if _, err := gkrgates.Register(intGate2, 2, gkrgates.WithUnverifiedDegree(1), gkrgates.WithUnverifiedSolvableVar(0), gkrgates.WithCurves(ecc.BLS12_377)); err != nil { - return err - } extKeySBox := func(round int, varIndex int) error { - _, err := gkrgates.Register(extKeyGate(&p.RoundKeys[round][varIndex]), 2, gkrgates.WithUnverifiedDegree(1), gkrgates.WithUnverifiedSolvableVar(0), gkrgates.WithName(gateNames.linear(varIndex, round)), gkrgates.WithCurves(ecc.BLS12_377)) + _, err := gkrgates.Register(extKeyGate(&p.RoundKeys[round][varIndex]), 2, gkrgates.WithUnverifiedDegree(1), gkrgates.WithUnverifiedSolvableVar(0), gkrgates.WithName(gateNames.linear(varIndex, round)), gkrgates.WithCurves(curve)) return err } intKeySBox2 := func(round int) error { - _, err := gkrgates.Register(intKeyGate2(&p.RoundKeys[round][1]), 2, gkrgates.WithUnverifiedDegree(1), gkrgates.WithUnverifiedSolvableVar(0), gkrgates.WithName(gateNames.linear(y, round)), gkrgates.WithCurves(ecc.BLS12_377)) + _, err := gkrgates.Register(intKeyGate2(&p.RoundKeys[round][1]), 2, gkrgates.WithUnverifiedDegree(1), gkrgates.WithUnverifiedSolvableVar(0), gkrgates.WithName(gateNames.linear(y, round)), gkrgates.WithCurves(curve)) return err } @@ -343,15 +363,14 @@ func registerGatesBls12377() error { } } - _, err := gkrgates.Register(extAddGate, 3, gkrgates.WithUnverifiedDegree(1), gkrgates.WithUnverifiedSolvableVar(0), gkrgates.WithName(gateNames.linear(y, p.NbPartialRounds+p.NbFullRounds)), gkrgates.WithCurves(ecc.BLS12_377)) - return err + return nil } type roundGateNamer string // newRoundGateNamer returns an object that returns standardized names for gates in the GKR circuit -func newRoundGateNamer(p fmt.Stringer) roundGateNamer { - return roundGateNamer(p.String()) +func newRoundGateNamer(p *poseidon2.Parameters, curve ecc.ID) roundGateNamer { + return roundGateNamer(fmt.Sprintf("Poseidon2-%s[t=%d,rF=%d,rP=%d,d=%d]", curve.String(), p.Width, p.NbFullRounds, p.NbPartialRounds, p.DegreeSBox)) } // linear is the name of a gate where a polynomial of total degree 1 is applied to the input @@ -363,3 +382,5 @@ func (n roundGateNamer) linear(varIndex, round int) gkr.GateName { func (n roundGateNamer) integrated(varIndex, round int) gkr.GateName { return gkr.GateName(fmt.Sprintf("x%d-i-op-round=%d;%s", varIndex, round, n)) } + +type gkrPoseidon2Key struct{} diff --git a/std/permutation/poseidon2/gkr-poseidon2/gkr-poseidon2_test.go b/std/permutation/poseidon2/gkr-poseidon2/gkr-poseidon2_test.go new file mode 100644 index 0000000000..b224bf1414 --- /dev/null +++ b/std/permutation/poseidon2/gkr-poseidon2/gkr-poseidon2_test.go @@ -0,0 +1,81 @@ +package gkr_poseidon2 + +import ( + "fmt" + "testing" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/scs" + _ "github.com/consensys/gnark/std/hash/all" + "github.com/consensys/gnark/std/permutation/poseidon2" + "github.com/consensys/gnark/test" + "github.com/stretchr/testify/require" +) + +func gkrCompressionsCircuits(n int) (circuit, assignment testGkrCompressionCircuit) { + ins := make([][2]frontend.Variable, n) + for i := range n { + ins[i] = [2]frontend.Variable{i * 2, i*2 + 1} + } + + return testGkrCompressionCircuit{ + Ins: make([][2]frontend.Variable, len(ins)), + }, testGkrCompressionCircuit{ + Ins: ins, + } +} + +func TestGkrCompression(t *testing.T) { + circuit, assignment := gkrCompressionsCircuits(2) + + test.NewAssert(t).CheckCircuit(&circuit, test.WithValidAssignment(&assignment)) +} + +type testGkrCompressionCircuit struct { + Ins [][2]frontend.Variable + skipCheck bool +} + +func (c *testGkrCompressionCircuit) Define(api frontend.API) error { + + gkr, err := NewCompressor(api) + if err != nil { + return err + } + pos2, err := poseidon2.NewPoseidon2(api) + if err != nil { + return err + } + for i := range c.Ins { + fromGkr := gkr.Compress(c.Ins[i][0], c.Ins[i][1]) + if !c.skipCheck { + api.AssertIsEqual(pos2.Compress(c.Ins[i][0], c.Ins[i][1]), fromGkr) + } + } + + return nil +} + +func TestGkrCompressionCompiles(t *testing.T) { + // just measure the number of constraints + cs, err := frontend.Compile(ecc.BLS12_377.ScalarField(), scs.NewBuilder, &testGkrCompressionCircuit{ + Ins: make([][2]frontend.Variable, 52000), + skipCheck: true, + }) + require.NoError(t, err) + fmt.Println(cs.GetNbConstraints(), "constraints") +} + +func BenchmarkGkrCompressions(b *testing.B) { + circuit, assignment := gkrCompressionsCircuits(50000) + + cs, err := frontend.Compile(ecc.BLS12_377.ScalarField(), scs.NewBuilder, &circuit) + require.NoError(b, err) + + witness, err := frontend.NewWitness(&assignment, ecc.BLS12_377.ScalarField()) + require.NoError(b, err) + + _, err = cs.Solve(witness) + require.NoError(b, err) +} diff --git a/std/permutation/poseidon2/gkr-poseidon2/gkr_test.go b/std/permutation/poseidon2/gkr-poseidon2/gkr_test.go deleted file mode 100644 index 0a230c4381..0000000000 --- a/std/permutation/poseidon2/gkr-poseidon2/gkr_test.go +++ /dev/null @@ -1,74 +0,0 @@ -package gkr_poseidon2 - -import ( - "fmt" - "testing" - - "github.com/consensys/gnark-crypto/ecc" - "github.com/consensys/gnark-crypto/ecc/bls12-377/fr" - "github.com/consensys/gnark/frontend" - "github.com/consensys/gnark/frontend/cs/scs" - _ "github.com/consensys/gnark/std/hash/all" - "github.com/consensys/gnark/test" - "github.com/stretchr/testify/require" -) - -func gkrCompressionCircuits(t require.TestingT, n int) (circuit, assignment testGkrCompressionCircuit) { - var k int64 - ins := make([][2]frontend.Variable, n) - outs := make([]frontend.Variable, n) - for i := range n { - var x [2]fr.Element - ins[i] = [2]frontend.Variable{k, k + 1} - - x[0].SetInt64(k) - x[1].SetInt64(k + 1) - y0 := x[1] - - require.NoError(t, bls12377Permutation().Permutation(x[:])) - x[1].Add(&x[1], &y0) - outs[i] = x[1] - - k += 2 - } - - return testGkrCompressionCircuit{ - Ins: make([][2]frontend.Variable, len(ins)), - Outs: make([]frontend.Variable, len(outs)), - }, testGkrCompressionCircuit{ - Ins: ins, - Outs: outs, - } -} - -func TestGkrCompression(t *testing.T) { - circuit, assignment := gkrCompressionCircuits(t, 2) - - test.NewAssert(t).CheckCircuit(&circuit, test.WithValidAssignment(&assignment), test.WithCurves(ecc.BLS12_377)) -} - -type testGkrCompressionCircuit struct { - Ins [][2]frontend.Variable - Outs []frontend.Variable -} - -func (c *testGkrCompressionCircuit) Define(api frontend.API) error { - - pos2 := NewGkrCompressor(api) - api.AssertIsEqual(len(c.Ins), len(c.Outs)) - for i := range c.Ins { - api.AssertIsEqual(c.Outs[i], pos2.Compress(c.Ins[i][0], c.Ins[i][1])) - } - - return nil -} - -func TestGkrCompressionCompiles(t *testing.T) { - // just measure the number of constraints - cs, err := frontend.Compile(ecc.BLS12_377.ScalarField(), scs.NewBuilder, &testGkrCompressionCircuit{ - Ins: make([][2]frontend.Variable, 52000), - Outs: make([]frontend.Variable, 52000), - }) - require.NoError(t, err) - fmt.Println(cs.GetNbConstraints(), "constraints") -} diff --git a/std/permutation/poseidon2/poseidon2.go b/std/permutation/poseidon2/poseidon2.go index 55afe73be5..317cc977dd 100644 --- a/std/permutation/poseidon2/poseidon2.go +++ b/std/permutation/poseidon2/poseidon2.go @@ -23,38 +23,157 @@ var ( type Permutation struct { api frontend.API - params parameters + params Parameters } -// parameters describing the poseidon2 implementation -type parameters struct { +// Parameters describing the poseidon2 implementation +type Parameters struct { // len(preimage)+len(digest)=len(preimage)+ceil(log(2*/r)) - width int + Width int // sbox degree - degreeSBox int + DegreeSBox int // number of full rounds (even number) - nbFullRounds int + NbFullRounds int // number of partial rounds - nbPartialRounds int + NbPartialRounds int // round keys: ordered by round then variable - roundKeys [][]big.Int + RoundKeys [][]big.Int +} + +func GetDefaultParameters(curve ecc.ID) (Parameters, error) { + switch curve { // TODO: assumes pairing based builder, reconsider when supporting other backends + case ecc.BN254: + p := poseidonbn254.GetDefaultParameters() + res := Parameters{ + Width: p.Width, + DegreeSBox: poseidonbn254.DegreeSBox(), + NbFullRounds: p.NbFullRounds, + NbPartialRounds: p.NbPartialRounds, + RoundKeys: make([][]big.Int, len(p.RoundKeys)), + } + for i := range res.RoundKeys { + res.RoundKeys[i] = make([]big.Int, len(p.RoundKeys[i])) + for j := range res.RoundKeys[i] { + p.RoundKeys[i][j].BigInt(&res.RoundKeys[i][j]) + } + } + return res, nil + case ecc.BLS12_381: + p := poseidonbls12381.GetDefaultParameters() + res := Parameters{ + Width: p.Width, + DegreeSBox: poseidonbls12381.DegreeSBox(), + NbFullRounds: p.NbFullRounds, + NbPartialRounds: p.NbPartialRounds, + RoundKeys: make([][]big.Int, len(p.RoundKeys)), + } + for i := range res.RoundKeys { + res.RoundKeys[i] = make([]big.Int, len(p.RoundKeys[i])) + for j := range res.RoundKeys[i] { + p.RoundKeys[i][j].BigInt(&res.RoundKeys[i][j]) + } + } + return res, nil + case ecc.BLS12_377: + p := poseidonbls12377.GetDefaultParameters() + res := Parameters{ + Width: p.Width, + DegreeSBox: poseidonbls12377.DegreeSBox(), + NbFullRounds: p.NbFullRounds, + NbPartialRounds: p.NbPartialRounds, + RoundKeys: make([][]big.Int, len(p.RoundKeys)), + } + for i := range res.RoundKeys { + res.RoundKeys[i] = make([]big.Int, len(p.RoundKeys[i])) + for j := range res.RoundKeys[i] { + p.RoundKeys[i][j].BigInt(&res.RoundKeys[i][j]) + } + } + return res, nil + case ecc.BW6_761: + p := poseidonbw6761.GetDefaultParameters() + res := Parameters{ + Width: p.Width, + DegreeSBox: poseidonbw6761.DegreeSBox(), + NbFullRounds: p.NbFullRounds, + NbPartialRounds: p.NbPartialRounds, + RoundKeys: make([][]big.Int, len(p.RoundKeys)), + } + for i := range res.RoundKeys { + res.RoundKeys[i] = make([]big.Int, len(p.RoundKeys[i])) + for j := range res.RoundKeys[i] { + p.RoundKeys[i][j].BigInt(&res.RoundKeys[i][j]) + } + } + return res, nil + case ecc.BW6_633: + p := poseidonbw6633.GetDefaultParameters() + res := Parameters{ + Width: p.Width, + DegreeSBox: poseidonbw6633.DegreeSBox(), + NbFullRounds: p.NbFullRounds, + NbPartialRounds: p.NbPartialRounds, + RoundKeys: make([][]big.Int, len(p.RoundKeys)), + } + for i := range res.RoundKeys { + res.RoundKeys[i] = make([]big.Int, len(p.RoundKeys[i])) + for j := range res.RoundKeys[i] { + p.RoundKeys[i][j].BigInt(&res.RoundKeys[i][j]) + } + } + return res, nil + case ecc.BLS24_315: + p := poseidonbls24315.GetDefaultParameters() + res := Parameters{ + Width: p.Width, + DegreeSBox: poseidonbls24315.DegreeSBox(), + NbFullRounds: p.NbFullRounds, + NbPartialRounds: p.NbPartialRounds, + RoundKeys: make([][]big.Int, len(p.RoundKeys)), + } + for i := range res.RoundKeys { + res.RoundKeys[i] = make([]big.Int, len(p.RoundKeys[i])) + for j := range res.RoundKeys[i] { + p.RoundKeys[i][j].BigInt(&res.RoundKeys[i][j]) + } + } + return res, nil + case ecc.BLS24_317: + p := poseidonbls24317.GetDefaultParameters() + res := Parameters{ + Width: p.Width, + DegreeSBox: poseidonbls24317.DegreeSBox(), + NbFullRounds: p.NbFullRounds, + NbPartialRounds: p.NbPartialRounds, + RoundKeys: make([][]big.Int, len(p.RoundKeys)), + } + for i := range res.RoundKeys { + res.RoundKeys[i] = make([]big.Int, len(p.RoundKeys[i])) + for j := range res.RoundKeys[i] { + p.RoundKeys[i][j].BigInt(&res.RoundKeys[i][j]) + } + } + return res, nil + default: + return Parameters{}, fmt.Errorf("curve %s not supported", curve) + } } // NewPoseidon2 returns a new Poseidon2 hasher with default parameters as // defined in the gnark-crypto library. func NewPoseidon2(api frontend.API) (*Permutation, error) { - switch utils.FieldToCurve(api.Compiler().Field()) { // TODO: assumes pairing based builder, reconsider when supporting other backends - case ecc.BLS12_377: - params := poseidonbls12377.GetDefaultParameters() - return NewPoseidon2FromParameters(api, 2, params.NbFullRounds, params.NbPartialRounds) - // TODO: we don't have default parameters for other curves yet. Update this when we do. - default: - return nil, fmt.Errorf("field %s not supported", api.Compiler().Field().String()) + params, err := GetDefaultParameters(utils.FieldToCurve(api.Compiler().Field())) + if err != nil { + return nil, err } + return &Permutation{ + api: api, + params: params, + }, nil } // NewPoseidon2FromParameters returns a new Poseidon2 hasher with the given parameters. @@ -62,76 +181,76 @@ func NewPoseidon2(api frontend.API) (*Permutation, error) { // is deterministic and depends on the curve ID. See the corresponding NewParameters // function in the gnark-crypto library poseidon2 packages for more details. func NewPoseidon2FromParameters(api frontend.API, width, nbFullRounds, nbPartialRounds int) (*Permutation, error) { - params := parameters{width: width, nbFullRounds: nbFullRounds, nbPartialRounds: nbPartialRounds} + params := Parameters{Width: width, NbFullRounds: nbFullRounds, NbPartialRounds: nbPartialRounds} switch utils.FieldToCurve(api.Compiler().Field()) { // TODO: assumes pairing based builder, reconsider when supporting other backends case ecc.BN254: - params.degreeSBox = poseidonbn254.DegreeSBox() + params.DegreeSBox = poseidonbn254.DegreeSBox() concreteParams := poseidonbn254.NewParameters(width, nbFullRounds, nbPartialRounds) - params.roundKeys = make([][]big.Int, len(concreteParams.RoundKeys)) - for i := range params.roundKeys { - params.roundKeys[i] = make([]big.Int, len(concreteParams.RoundKeys[i])) - for j := range params.roundKeys[i] { - concreteParams.RoundKeys[i][j].BigInt(¶ms.roundKeys[i][j]) + params.RoundKeys = make([][]big.Int, len(concreteParams.RoundKeys)) + for i := range params.RoundKeys { + params.RoundKeys[i] = make([]big.Int, len(concreteParams.RoundKeys[i])) + for j := range params.RoundKeys[i] { + concreteParams.RoundKeys[i][j].BigInt(¶ms.RoundKeys[i][j]) } } case ecc.BLS12_381: - params.degreeSBox = poseidonbls12381.DegreeSBox() + params.DegreeSBox = poseidonbls12381.DegreeSBox() concreteParams := poseidonbls12381.NewParameters(width, nbFullRounds, nbPartialRounds) - params.roundKeys = make([][]big.Int, len(concreteParams.RoundKeys)) - for i := range params.roundKeys { - params.roundKeys[i] = make([]big.Int, len(concreteParams.RoundKeys[i])) - for j := range params.roundKeys[i] { - concreteParams.RoundKeys[i][j].BigInt(¶ms.roundKeys[i][j]) + params.RoundKeys = make([][]big.Int, len(concreteParams.RoundKeys)) + for i := range params.RoundKeys { + params.RoundKeys[i] = make([]big.Int, len(concreteParams.RoundKeys[i])) + for j := range params.RoundKeys[i] { + concreteParams.RoundKeys[i][j].BigInt(¶ms.RoundKeys[i][j]) } } case ecc.BLS12_377: - params.degreeSBox = poseidonbls12377.DegreeSBox() + params.DegreeSBox = poseidonbls12377.DegreeSBox() concreteParams := poseidonbls12377.NewParameters(width, nbFullRounds, nbPartialRounds) - params.roundKeys = make([][]big.Int, len(concreteParams.RoundKeys)) - for i := range params.roundKeys { - params.roundKeys[i] = make([]big.Int, len(concreteParams.RoundKeys[i])) - for j := range params.roundKeys[i] { - concreteParams.RoundKeys[i][j].BigInt(¶ms.roundKeys[i][j]) + params.RoundKeys = make([][]big.Int, len(concreteParams.RoundKeys)) + for i := range params.RoundKeys { + params.RoundKeys[i] = make([]big.Int, len(concreteParams.RoundKeys[i])) + for j := range params.RoundKeys[i] { + concreteParams.RoundKeys[i][j].BigInt(¶ms.RoundKeys[i][j]) } } case ecc.BW6_761: - params.degreeSBox = poseidonbw6761.DegreeSBox() + params.DegreeSBox = poseidonbw6761.DegreeSBox() concreteParams := poseidonbw6761.NewParameters(width, nbFullRounds, nbPartialRounds) - params.roundKeys = make([][]big.Int, len(concreteParams.RoundKeys)) - for i := range params.roundKeys { - params.roundKeys[i] = make([]big.Int, len(concreteParams.RoundKeys[i])) - for j := range params.roundKeys[i] { - concreteParams.RoundKeys[i][j].BigInt(¶ms.roundKeys[i][j]) + params.RoundKeys = make([][]big.Int, len(concreteParams.RoundKeys)) + for i := range params.RoundKeys { + params.RoundKeys[i] = make([]big.Int, len(concreteParams.RoundKeys[i])) + for j := range params.RoundKeys[i] { + concreteParams.RoundKeys[i][j].BigInt(¶ms.RoundKeys[i][j]) } } case ecc.BW6_633: - params.degreeSBox = poseidonbw6633.DegreeSBox() + params.DegreeSBox = poseidonbw6633.DegreeSBox() concreteParams := poseidonbw6633.NewParameters(width, nbFullRounds, nbPartialRounds) - params.roundKeys = make([][]big.Int, len(concreteParams.RoundKeys)) - for i := range params.roundKeys { - params.roundKeys[i] = make([]big.Int, len(concreteParams.RoundKeys[i])) - for j := range params.roundKeys[i] { - concreteParams.RoundKeys[i][j].BigInt(¶ms.roundKeys[i][j]) + params.RoundKeys = make([][]big.Int, len(concreteParams.RoundKeys)) + for i := range params.RoundKeys { + params.RoundKeys[i] = make([]big.Int, len(concreteParams.RoundKeys[i])) + for j := range params.RoundKeys[i] { + concreteParams.RoundKeys[i][j].BigInt(¶ms.RoundKeys[i][j]) } } case ecc.BLS24_315: - params.degreeSBox = poseidonbls24315.DegreeSBox() + params.DegreeSBox = poseidonbls24315.DegreeSBox() concreteParams := poseidonbls24315.NewParameters(width, nbFullRounds, nbPartialRounds) - params.roundKeys = make([][]big.Int, len(concreteParams.RoundKeys)) - for i := range params.roundKeys { - params.roundKeys[i] = make([]big.Int, len(concreteParams.RoundKeys[i])) - for j := range params.roundKeys[i] { - concreteParams.RoundKeys[i][j].BigInt(¶ms.roundKeys[i][j]) + params.RoundKeys = make([][]big.Int, len(concreteParams.RoundKeys)) + for i := range params.RoundKeys { + params.RoundKeys[i] = make([]big.Int, len(concreteParams.RoundKeys[i])) + for j := range params.RoundKeys[i] { + concreteParams.RoundKeys[i][j].BigInt(¶ms.RoundKeys[i][j]) } } case ecc.BLS24_317: - params.degreeSBox = poseidonbls24317.DegreeSBox() + params.DegreeSBox = poseidonbls24317.DegreeSBox() concreteParams := poseidonbls24317.NewParameters(width, nbFullRounds, nbPartialRounds) - params.roundKeys = make([][]big.Int, len(concreteParams.RoundKeys)) - for i := range params.roundKeys { - params.roundKeys[i] = make([]big.Int, len(concreteParams.RoundKeys[i])) - for j := range params.roundKeys[i] { - concreteParams.RoundKeys[i][j].BigInt(¶ms.roundKeys[i][j]) + params.RoundKeys = make([][]big.Int, len(concreteParams.RoundKeys)) + for i := range params.RoundKeys { + params.RoundKeys[i] = make([]big.Int, len(concreteParams.RoundKeys[i])) + for j := range params.RoundKeys[i] { + concreteParams.RoundKeys[i][j].BigInt(¶ms.RoundKeys[i][j]) } } default: @@ -143,25 +262,25 @@ func NewPoseidon2FromParameters(api frontend.API, width, nbFullRounds, nbPartial // sBox applies the sBox on buffer[index] func (h *Permutation) sBox(index int, input []frontend.Variable) { tmp := input[index] - if h.params.degreeSBox == 3 { + if h.params.DegreeSBox == 3 { input[index] = h.api.Mul(input[index], input[index]) input[index] = h.api.Mul(tmp, input[index]) - } else if h.params.degreeSBox == 5 { + } else if h.params.DegreeSBox == 5 { input[index] = h.api.Mul(input[index], input[index]) input[index] = h.api.Mul(input[index], input[index]) input[index] = h.api.Mul(input[index], tmp) - } else if h.params.degreeSBox == 7 { + } else if h.params.DegreeSBox == 7 { input[index] = h.api.Mul(input[index], input[index]) input[index] = h.api.Mul(input[index], tmp) input[index] = h.api.Mul(input[index], input[index]) input[index] = h.api.Mul(input[index], tmp) - } else if h.params.degreeSBox == 17 { + } else if h.params.DegreeSBox == 17 { input[index] = h.api.Mul(input[index], input[index]) input[index] = h.api.Mul(input[index], input[index]) input[index] = h.api.Mul(input[index], input[index]) input[index] = h.api.Mul(input[index], input[index]) input[index] = h.api.Mul(input[index], tmp) - } else if h.params.degreeSBox == -1 { + } else if h.params.DegreeSBox == -1 { input[index] = h.api.Inverse(input[index]) } } @@ -204,30 +323,30 @@ func (h *Permutation) matMulM4InPlace(s []frontend.Variable) { // see https://eprint.iacr.org/2023/323.pdf func (h *Permutation) matMulExternalInPlace(input []frontend.Variable) { - if h.params.width == 2 { + if h.params.Width == 2 { tmp := h.api.Add(input[0], input[1]) input[0] = h.api.Add(tmp, input[0]) input[1] = h.api.Add(tmp, input[1]) - } else if h.params.width == 3 { + } else if h.params.Width == 3 { tmp := h.api.Add(input[0], input[1]) tmp = h.api.Add(tmp, input[2]) input[0] = h.api.Add(input[0], tmp) input[1] = h.api.Add(input[1], tmp) input[2] = h.api.Add(input[2], tmp) - } else if h.params.width == 4 { + } else if h.params.Width == 4 { h.matMulM4InPlace(input) } else { // at this stage t is supposed to be a multiple of 4 // the MDS matrix is circ(2M4,M4,..,M4) h.matMulM4InPlace(input) tmp := make([]frontend.Variable, 4) - for i := 0; i < h.params.width/4; i++ { + for i := 0; i < h.params.Width/4; i++ { tmp[0] = h.api.Add(tmp[0], input[4*i]) tmp[1] = h.api.Add(tmp[1], input[4*i+1]) tmp[2] = h.api.Add(tmp[2], input[4*i+2]) tmp[3] = h.api.Add(tmp[3], input[4*i+3]) } - for i := 0; i < h.params.width/4; i++ { + for i := 0; i < h.params.Width/4; i++ { input[4*i] = h.api.Add(input[4*i], tmp[0]) input[4*i+1] = h.api.Add(input[4*i], tmp[1]) input[4*i+2] = h.api.Add(input[4*i], tmp[2]) @@ -239,12 +358,12 @@ func (h *Permutation) matMulExternalInPlace(input []frontend.Variable) { // when t=2,3 the matrix are respectively [[2,1][1,3]] and [[2,1,1][1,2,1][1,1,3]] // otherwise the matrix is filled with ones except on the diagonal, func (h *Permutation) matMulInternalInPlace(input []frontend.Variable) { - if h.params.width == 2 { + if h.params.Width == 2 { sum := h.api.Add(input[0], input[1]) input[0] = h.api.Add(input[0], sum) input[1] = h.api.Mul(2, input[1]) input[1] = h.api.Add(input[1], sum) - } else if h.params.width == 3 { + } else if h.params.Width == 3 { sum := h.api.Add(input[0], input[1]) sum = h.api.Add(sum, input[2]) input[0] = h.api.Add(input[0], sum) @@ -259,10 +378,10 @@ func (h *Permutation) matMulInternalInPlace(input []frontend.Variable) { // var sum frontend.Variable // sum = input[0] - // for i := 1; i < h.params.width; i++ { + // for i := 1; i < h.params.Width; i++ { // sum = api.Add(sum, input[i]) // } - // for i := 0; i < h.params.width; i++ { + // for i := 0; i < h.params.Width; i++ { // input[i] = api.Mul(input[i], h.params.diagInternalMatrices[i]) // input[i] = api.Add(input[i], sum) // } @@ -272,40 +391,40 @@ func (h *Permutation) matMulInternalInPlace(input []frontend.Variable) { // addRoundKeyInPlace adds the round-th key to the buffer func (h *Permutation) addRoundKeyInPlace(round int, input []frontend.Variable) { - for i := 0; i < len(h.params.roundKeys[round]); i++ { - input[i] = h.api.Add(input[i], h.params.roundKeys[round][i]) + for i := 0; i < len(h.params.RoundKeys[round]); i++ { + input[i] = h.api.Add(input[i], h.params.RoundKeys[round][i]) } } // Permutation applies the permutation on input, and stores the result in input. func (h *Permutation) Permutation(input []frontend.Variable) error { - if len(input) != h.params.width { + if len(input) != h.params.Width { return ErrInvalidSizebuffer } // external matrix multiplication, cf https://eprint.iacr.org/2023/323.pdf page 14 (part 6) h.matMulExternalInPlace(input) - rf := h.params.nbFullRounds / 2 + rf := h.params.NbFullRounds / 2 for i := 0; i < rf; i++ { // one round = matMulExternal(sBox_Full(addRoundKey)) h.addRoundKeyInPlace(i, input) - for j := 0; j < h.params.width; j++ { + for j := 0; j < h.params.Width; j++ { h.sBox(j, input) } h.matMulExternalInPlace(input) } - for i := rf; i < rf+h.params.nbPartialRounds; i++ { + for i := rf; i < rf+h.params.NbPartialRounds; i++ { // one round = matMulInternal(sBox_sparse(addRoundKey)) h.addRoundKeyInPlace(i, input) h.sBox(0, input) h.matMulInternalInPlace(input) } - for i := rf + h.params.nbPartialRounds; i < h.params.nbFullRounds+h.params.nbPartialRounds; i++ { + for i := rf + h.params.NbPartialRounds; i < h.params.NbFullRounds+h.params.NbPartialRounds; i++ { // one round = matMulExternal(sBox_Full(addRoundKey)) h.addRoundKeyInPlace(i, input) - for j := 0; j < h.params.width; j++ { + for j := 0; j < h.params.Width; j++ { h.sBox(j, input) } h.matMulExternalInPlace(input) @@ -321,7 +440,7 @@ func (h *Permutation) Permutation(input []frontend.Variable) error { // Implements the [hash.Compressor] interface for building a Merkle-Damgard // hash construction. func (h *Permutation) Compress(left, right frontend.Variable) frontend.Variable { - if h.params.width != 2 { + if h.params.Width != 2 { panic("poseidon2: Compress can only be used when t=2") } vars := [2]frontend.Variable{left, right} diff --git a/test/engine.go b/test/engine.go index 267a75a98d..4504239fc5 100644 --- a/test/engine.go +++ b/test/engine.go @@ -110,7 +110,7 @@ func IsSolved(circuit, witness frontend.Circuit, field *big.Int, opts ...TestEng defer func() { if r := recover(); r != nil { - err = fmt.Errorf("%v\n%s", r, string(debug.Stack())) + err = fmt.Errorf("%v\n%s", r, debug.Stack()) } }()