Skip to content

Commit 4f304f2

Browse files
committed
Add tls.Config.ClientCurveGuess to allow specifying which keyshares to send
RTG-2919 [ Bas 1.21.3: Send empty keyshare extension instead of leaving it out ]
1 parent 09f4554 commit 4f304f2

File tree

7 files changed

+144
-47
lines changed

7 files changed

+144
-47
lines changed

src/crypto/tls/cfkem.go

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,14 @@ import (
2222
"fmt"
2323
"io"
2424

25-
"crypto/ecdh"
26-
2725
"github.com/cloudflare/circl/kem"
2826
"github.com/cloudflare/circl/kem/hybrid"
2927
)
3028

3129
// Either *ecdh.PrivateKey or *kemPrivateKey
32-
type clientKeySharePrivate interface{}
30+
type singleClientKeySharePrivate interface{}
31+
32+
type clientKeySharePrivate map[CurveID]singleClientKeySharePrivate
3333

3434
type kemPrivateKey struct {
3535
secretKey kem.PrivateKey
@@ -44,20 +44,9 @@ var (
4444
invalidCurveID = CurveID(0)
4545
)
4646

47-
// Extract CurveID from clientKeySharePrivate
48-
func clientKeySharePrivateCurveID(ks clientKeySharePrivate) CurveID {
49-
switch v := ks.(type) {
50-
case *kemPrivateKey:
51-
return v.curveID
52-
case *ecdh.PrivateKey:
53-
ret, ok := curveIDForCurve(v.Curve())
54-
if !ok {
55-
panic("cfkem: internal error: unknown curve")
56-
}
57-
return ret
58-
default:
59-
panic("cfkem: internal error: unknown clientKeySharePrivate")
60-
}
47+
func singleClientKeySharePrivateFor(ks clientKeySharePrivate, group CurveID) singleClientKeySharePrivate {
48+
ret, _ := ks[group]
49+
return ret
6150
}
6251

6352
// Returns scheme by CurveID if supported by Circl

src/crypto/tls/cfkem_test.go

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,3 +104,48 @@ func TestHybridKEX(t *testing.T) {
104104
run(curveID, true, true, true, true)
105105
}
106106
}
107+
108+
func TestClientCurveGuess(t *testing.T) {
109+
run := func(guess, clientPrefs, serverPrefs []CurveID) {
110+
t.Run(
111+
fmt.Sprintf("guess=%v clientPrefs=%v serverPrefs=%v",
112+
guess, clientPrefs, serverPrefs),
113+
func(t *testing.T) {
114+
testClientCurveGuess(t, guess, clientPrefs, serverPrefs)
115+
})
116+
}
117+
both := []CurveID{X25519Kyber768Draft00, X25519}
118+
run([]CurveID{}, []CurveID{X25519}, both)
119+
run([]CurveID{X25519}, []CurveID{X25519}, both)
120+
run([]CurveID{X25519Kyber768Draft00}, both, []CurveID{X25519})
121+
run(both, both, both)
122+
run(both, both, []CurveID{X25519})
123+
run(both, both, []CurveID{X25519Kyber768Draft00})
124+
}
125+
126+
func testClientCurveGuess(t *testing.T, guess, clientPrefs, serverPrefs []CurveID) {
127+
clientConfig := testConfig.Clone()
128+
serverConfig := testConfig.Clone()
129+
serverConfig.CurvePreferences = serverPrefs
130+
clientConfig.CurvePreferences = clientPrefs
131+
clientConfig.ClientCurveGuess = guess
132+
133+
c, s := localPipe(t)
134+
done := make(chan error)
135+
defer c.Close()
136+
137+
go func() {
138+
defer s.Close()
139+
done <- Server(s, serverConfig).Handshake()
140+
}()
141+
142+
cli := Client(c, clientConfig)
143+
clientErr := cli.HandshakeContext(context.Background())
144+
serverErr := <-done
145+
if clientErr != nil {
146+
t.Errorf("client error: %v", clientErr)
147+
}
148+
if serverErr != nil {
149+
t.Errorf("server error: %v", serverErr)
150+
}
151+
}

src/crypto/tls/common.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -839,6 +839,18 @@ type Config struct {
839839
// which is currently TLS 1.3.
840840
MaxVersion uint16
841841

842+
// ClientCurveGuess contains the "curves" for which the client will create
843+
// a keyshare in the initial ClientHello for TLS 1.3. If the client
844+
// guesses incorrectly, and the server does not support or does not
845+
// prefer those keyshares, then the server will return a HelloRetryRequest
846+
// incurring an extra roundtrip.
847+
//
848+
// If empty, no keyshares will be included in the ClientHello.
849+
//
850+
// If nil (default), will send the single most preferred keyshare
851+
// as configurable via CurvePreferences.
852+
ClientCurveGuess []CurveID
853+
842854
// CurvePreferences contains the elliptic curves that will be used in
843855
// an ECDHE handshake, in preference order. If empty, the default will
844856
// be used. The client will use the first preference as the type for
@@ -976,6 +988,7 @@ func (c *Config) Clone() *Config {
976988
MinVersion: c.MinVersion,
977989
MaxVersion: c.MaxVersion,
978990
CurvePreferences: c.CurvePreferences,
991+
ClientCurveGuess: c.ClientCurveGuess,
979992
PQSignatureSchemesEnabled: c.PQSignatureSchemesEnabled,
980993
DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled,
981994
Renegotiation: c.Renegotiation,

src/crypto/tls/handshake_client.go

Lines changed: 66 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ func (c *Conn) makeClientHello(minVersion uint16) (*clientHelloMsg, clientKeySha
134134
hello.supportedSignatureAlgorithms = testingOnlyForceClientHelloSignatureAlgorithms
135135
}
136136

137-
var secret clientKeySharePrivate
137+
secret := make(clientKeySharePrivate)
138138
if hello.supportedVersions[0] == VersionTLS13 {
139139
// Reset the list of ciphers when the client only supports TLS 1.3.
140140
if len(hello.supportedVersions) == 1 {
@@ -146,30 +146,74 @@ func (c *Conn) makeClientHello(minVersion uint16) (*clientHelloMsg, clientKeySha
146146
hello.cipherSuites = append(hello.cipherSuites, defaultCipherSuitesTLS13NoAES...)
147147
}
148148

149-
curveID := config.curvePreferences()[0]
150-
if scheme := curveIdToCirclScheme(curveID); scheme != nil {
151-
pk, sk, err := generateKemKeyPair(scheme, curveID, config.rand())
152-
if err != nil {
153-
return nil, nil, fmt.Errorf("generateKemKeyPair %s: %w",
154-
scheme.Name(), err)
155-
}
156-
packedPk, err := pk.MarshalBinary()
157-
if err != nil {
158-
return nil, nil, fmt.Errorf("pack circl public key %s: %w",
159-
scheme.Name(), err)
149+
curveIDs := []CurveID{config.curvePreferences()[0]}
150+
151+
if config.ClientCurveGuess != nil {
152+
curveIDs = config.ClientCurveGuess
153+
}
154+
155+
hello.keyShares = make([]keyShare, 0, len(curveIDs))
156+
157+
// Check whether ClientCurveGuess is a subsequence of CurvePreferences
158+
// as is required by RFC8446 §4.2.8
159+
offset := 0
160+
curvePreferences := config.curvePreferences()
161+
found := 0
162+
CurveGuessCheck:
163+
for _, curveID := range curveIDs {
164+
for {
165+
if offset == len(curvePreferences) {
166+
break CurveGuessCheck
167+
}
168+
169+
if curvePreferences[offset] == curveID {
170+
found++
171+
break
172+
}
173+
174+
offset++
160175
}
161-
hello.keyShares = []keyShare{{group: curveID, data: packedPk}}
162-
secret = sk
163-
} else {
164-
if _, ok := curveForCurveID(curveID); !ok {
165-
return nil, nil, errors.New("tls: CurvePreferences includes unsupported curve")
176+
}
177+
if found != len(curveIDs) {
178+
return nil, nil, errors.New("tls: ClientCurveGuess not a subsequence of CurvePreferences")
179+
}
180+
181+
for _, curveID := range curveIDs {
182+
var (
183+
singleSecret interface{}
184+
singleShare []byte
185+
)
186+
187+
if _, ok := secret[curveID]; ok {
188+
return nil, nil, errors.New("tls: ClientCurveGuess contains duplicate")
166189
}
167-
key, err := generateECDHEKey(config.rand(), curveID)
168-
if err != nil {
169-
return nil, nil, err
190+
191+
if scheme := curveIdToCirclScheme(curveID); scheme != nil {
192+
pk, sk, err := generateKemKeyPair(scheme, curveID, config.rand())
193+
if err != nil {
194+
return nil, nil, fmt.Errorf("generateKemKeyPair %s: %w",
195+
scheme.Name(), err)
196+
}
197+
packedPk, err := pk.MarshalBinary()
198+
if err != nil {
199+
return nil, nil, fmt.Errorf("pack circl public key %s: %w",
200+
scheme.Name(), err)
201+
}
202+
singleShare = packedPk
203+
singleSecret = sk
204+
} else {
205+
if _, ok := curveForCurveID(curveID); !ok {
206+
return nil, nil, errors.New("tls: CurvePreferences includes unsupported curve")
207+
}
208+
key, err := generateECDHEKey(config.rand(), curveID)
209+
if err != nil {
210+
return nil, nil, err
211+
}
212+
singleShare = key.PublicKey().Bytes()
213+
singleSecret = key
170214
}
171-
hello.keyShares = []keyShare{{group: curveID, data: key.PublicKey().Bytes()}}
172-
secret = key
215+
hello.keyShares = append(hello.keyShares, keyShare{group: curveID, data: singleShare})
216+
secret[curveID] = singleSecret
173217
}
174218

175219
hello.delegatedCredentialSupported = config.SupportDelegatedCredential

src/crypto/tls/handshake_client_tls13.go

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ func (hs *clientHandshakeStateTLS13) handshake() error {
103103
}
104104

105105
// Consistency check on the presence of a keyShare and its parameters.
106-
if hs.keySharePrivate == nil || len(hs.hello.keyShares) != 1 {
106+
if hs.keySharePrivate == nil {
107107
return c.sendAlert(alertInternalError)
108108
}
109109

@@ -379,7 +379,7 @@ func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error {
379379
c.sendAlert(alertIllegalParameter)
380380
return errors.New("tls: server selected unsupported group")
381381
}
382-
if clientKeySharePrivateCurveID(hs.keySharePrivate) == curveID {
382+
if singleClientKeySharePrivateFor(hs.keySharePrivate, curveID) != nil {
383383
c.sendAlert(alertIllegalParameter)
384384
return errors.New("tls: server sent an unnecessary HelloRetryRequest key_share")
385385
}
@@ -396,7 +396,7 @@ func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error {
396396
return fmt.Errorf("HRR pack circl public key %s: %w",
397397
scheme.Name(), err)
398398
}
399-
hs.keySharePrivate = sk
399+
hs.keySharePrivate = clientKeySharePrivate{curveID: sk}
400400
hello.keyShares = []keyShare{{group: curveID, data: packedPk}}
401401
} else {
402402
if _, ok := curveForCurveID(curveID); !ok {
@@ -408,7 +408,7 @@ func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error {
408408
c.sendAlert(alertInternalError)
409409
return err
410410
}
411-
hs.keySharePrivate = key
411+
hs.keySharePrivate = clientKeySharePrivate{curveID: key}
412412
hello.keyShares = []keyShare{{group: curveID, data: key.PublicKey().Bytes()}}
413413
}
414414
}
@@ -558,7 +558,7 @@ func (hs *clientHandshakeStateTLS13) processServerHello() error {
558558
c.sendAlert(alertIllegalParameter)
559559
return errors.New("tls: server did not send a key share")
560560
}
561-
if hs.serverHello.serverShare.group != clientKeySharePrivateCurveID(hs.keySharePrivate) {
561+
if singleClientKeySharePrivateFor(hs.keySharePrivate, hs.serverHello.serverShare.group) == nil {
562562
c.sendAlert(alertIllegalParameter)
563563
return errors.New("tls: server selected unsupported group")
564564
}
@@ -613,12 +613,16 @@ func (hs *clientHandshakeStateTLS13) establishHandshakeKeys() error {
613613

614614
var sharedKey []byte
615615
var err error
616-
if key, ok := hs.keySharePrivate.(*ecdh.PrivateKey); ok {
616+
617+
// We already checked that ks isn't nil in processServerHello()
618+
ks := singleClientKeySharePrivateFor(hs.keySharePrivate, hs.serverHello.serverShare.group)
619+
620+
if key, ok := ks.(*ecdh.PrivateKey); ok {
617621
peerKey, err := key.Curve().NewPublicKey(hs.serverHello.serverShare.data)
618622
if err == nil {
619623
sharedKey, _ = key.ECDH(peerKey)
620624
}
621-
} else if key, ok := hs.keySharePrivate.(*kemPrivateKey); ok {
625+
} else if key, ok := ks.(*kemPrivateKey); ok {
622626
sk := key.secretKey
623627
sharedKey, err = sk.Scheme().Decapsulate(sk, hs.serverHello.serverShare.data)
624628
if err != nil {

src/crypto/tls/handshake_messages.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,7 @@ func (m *clientHelloMsg) marshal() ([]byte, error) {
248248
})
249249
})
250250
}
251-
if len(m.keyShares) > 0 {
251+
if m.keyShares != nil {
252252
// RFC 8446, Section 4.2.8
253253
exts.AddUint16(extensionKeyShare)
254254
exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {

src/crypto/tls/tls_test.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -865,6 +865,8 @@ func TestCloneNonFuncFields(t *testing.T) {
865865
f.Set(reflect.ValueOf([]uint16{1, 2}))
866866
case "CurvePreferences":
867867
f.Set(reflect.ValueOf([]CurveID{CurveP256}))
868+
case "ClientCurveGuess":
869+
f.Set(reflect.ValueOf([]CurveID{CurveP256}))
868870
case "PQSignatureSchemesEnabled":
869871
f.Set(reflect.ValueOf(true))
870872
case "Renegotiation":

0 commit comments

Comments
 (0)