Skip to content

Commit adb5692

Browse files
committed
shamir: First pass at Shamir secret sharing
1 parent eca348c commit adb5692

File tree

4 files changed

+497
-0
lines changed

4 files changed

+497
-0
lines changed

shamir/shamir.go

Lines changed: 214 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,214 @@
1+
package shamir
2+
3+
import (
4+
"crypto/rand"
5+
"fmt"
6+
)
7+
8+
// polynomial represents a polynomial of arbitrary degree
9+
type polynomial struct {
10+
coefficients []uint8
11+
}
12+
13+
// makePolynomial constructs a random polynomial of the given
14+
// degree but with the provided intercept value.
15+
func makePolynomial(intercept, degree uint8) (polynomial, error) {
16+
// Create a wrapper
17+
p := polynomial{
18+
coefficients: make([]byte, degree+1),
19+
}
20+
21+
// Ensure the intercept is set
22+
p.coefficients[0] = intercept
23+
24+
// Assign random co-efficients to the polynomial, ensuring
25+
// the highest order co-efficient is non-zero
26+
for p.coefficients[degree] == 0 {
27+
if _, err := rand.Read(p.coefficients[1:]); err != nil {
28+
return p, err
29+
}
30+
}
31+
return p, nil
32+
}
33+
34+
// evaluate returns the value of the polynomial for the given x
35+
func (p *polynomial) evaluate(x uint8) uint8 {
36+
// Special case the origin
37+
if x == 0 {
38+
return p.coefficients[0]
39+
}
40+
41+
// Compute the polynomial value using Horner's method.
42+
degree := len(p.coefficients) - 1
43+
out := p.coefficients[degree]
44+
for i := degree - 1; i >= 0; i-- {
45+
coeff := p.coefficients[i]
46+
out = add(mult(out, x), coeff)
47+
}
48+
return out
49+
}
50+
51+
// interpolatePolynomial takes N sample points and returns
52+
// the value at a given x using a lagrange interpolation.
53+
func interpolatePolynomial(x_samples, y_samples []uint8, x uint8) uint8 {
54+
limit := len(x_samples)
55+
var result, basis uint8
56+
for i := 0; i < limit; i++ {
57+
basis = 1
58+
for j := 0; j < limit; j++ {
59+
if i == j {
60+
continue
61+
}
62+
num := add(x, x_samples[j])
63+
denom := add(x_samples[i], x_samples[j])
64+
term := div(num, denom)
65+
basis = mult(basis, term)
66+
//println(fmt.Sprintf("Num: %d Denom: %d Term: %d Basis: %d",
67+
// num, denom, term, basis))
68+
}
69+
group := mult(y_samples[i], basis)
70+
//println(fmt.Sprintf("Group: %d", group))
71+
result = add(result, group)
72+
}
73+
return result
74+
}
75+
76+
// div divides two numbers in GF(2^8)
77+
func div(a, b uint8) uint8 {
78+
if b == 0 {
79+
panic("divide by zero")
80+
}
81+
if a == 0 {
82+
return 0
83+
}
84+
85+
log_a := logTable[a]
86+
log_b := logTable[b]
87+
diff := (int(log_a) - int(log_b)) % 255
88+
if diff < 0 {
89+
diff += 255
90+
}
91+
return expTable[diff]
92+
}
93+
94+
// mult multiplies two numbers in GF(2^8)
95+
func mult(a, b uint8) (out uint8) {
96+
if a == 0 || b == 0 {
97+
return 0
98+
}
99+
log_a := logTable[a]
100+
log_b := logTable[b]
101+
sum := (int(log_a) + int(log_b)) % 255
102+
return expTable[sum]
103+
}
104+
105+
// add combines two numbers in GF(2^8)
106+
// This can also be used for subtraction since it is symmetric.
107+
func add(a, b uint8) uint8 {
108+
return a ^ b
109+
}
110+
111+
// Split takes an arbitrarily long secret and generates a `parts`
112+
// number of shares, `threshold` of which are required to reconstruct
113+
// the secret. The parts and threshold must be at least 2, and less
114+
// than 256. The returned shares are each one byte longer than the secret
115+
// as they attach a tag used to reconstruct the secret.
116+
func Split(secret []byte, parts, threshold int) ([][]byte, error) {
117+
// Sanity check the input
118+
if parts < threshold {
119+
return nil, fmt.Errorf("parts cannot be less than threshold")
120+
}
121+
if parts > 255 {
122+
return nil, fmt.Errorf("parts cannot exceed 255")
123+
}
124+
if threshold < 2 {
125+
return nil, fmt.Errorf("threshold must be at least 2")
126+
}
127+
if threshold > 255 {
128+
return nil, fmt.Errorf("threshold cannot exceed 255")
129+
}
130+
if len(secret) == 0 {
131+
return nil, fmt.Errorf("cannot split an empty secret")
132+
}
133+
134+
// Allocate the output array, initialize the final byte
135+
// of the output with the offset. The representation of each
136+
// output is {y1, y2, .., yN, x}.
137+
out := make([][]byte, parts)
138+
for idx := range out {
139+
out[idx] = make([]byte, len(secret)+1)
140+
out[idx][len(secret)] = uint8(idx) + 1
141+
}
142+
143+
// Construct a random polynomial for each byte of the secret.
144+
// Because we are using a field of size 256, we can only represent
145+
// a single byte as the intercept of the polynomial, so we must
146+
// use a new polynomial for each byte.
147+
for idx, val := range secret {
148+
p, err := makePolynomial(val, uint8(threshold-1))
149+
if err != nil {
150+
return nil, fmt.Errorf("failed to generate polynomial: %v", err)
151+
}
152+
153+
// Generate a `parts` number of (x,y) pairs
154+
// We cheat by encoding the x value once as the final index,
155+
// so that it only needs to be stored once.
156+
for i := 0; i < parts; i++ {
157+
x := uint8(i) + 1
158+
y := p.evaluate(x)
159+
out[i][idx] = y
160+
}
161+
}
162+
163+
// Return the encoded secrets
164+
return out, nil
165+
}
166+
167+
// Combine is used to reverse a Split and reconstruct a secret
168+
// once a `threshold` number of parts are available.
169+
func Combine(parts [][]byte) ([]byte, error) {
170+
// Verify enough parts provided
171+
if len(parts) < 2 {
172+
return nil, fmt.Errorf("less than two parts cannot be used to reconstruct the secret")
173+
}
174+
175+
// Verify the parts are all the same length
176+
firstPartLen := len(parts[0])
177+
if firstPartLen < 2 {
178+
return nil, fmt.Errorf("parts must be at least two bytes")
179+
}
180+
for i := 1; i < len(parts); i++ {
181+
if len(parts[i]) != firstPartLen {
182+
return nil, fmt.Errorf("all parts must be the same length")
183+
}
184+
}
185+
186+
// Create a buffer to store the reconstructed secret
187+
secret := make([]byte, firstPartLen-1)
188+
189+
// Buffer to store the samples
190+
x_samples := make([]uint8, len(parts))
191+
y_samples := make([]uint8, len(parts))
192+
193+
// Set the x value for each sample
194+
for i, part := range parts {
195+
x_samples[i] = part[firstPartLen-1]
196+
}
197+
198+
// Reconstruct each byte
199+
for idx := range secret {
200+
// Set the y value for each sample
201+
for i, part := range parts {
202+
y_samples[i] = part[idx]
203+
}
204+
205+
// Interpolte the polynomial and compute the value at 0
206+
println(fmt.Sprintf("byte: %d x: %v y: %v", idx, x_samples, y_samples))
207+
val := interpolatePolynomial(x_samples, y_samples, 0)
208+
println(fmt.Sprintf("byte: %d out: %v", idx, val))
209+
210+
// Evaluate the 0th value to get the intercept
211+
secret[idx] = val
212+
}
213+
return secret, nil
214+
}

0 commit comments

Comments
 (0)