Skip to content

Commit

Permalink
more
Browse files Browse the repository at this point in the history
Signed-off-by: Daniel Mangum <[email protected]>
  • Loading branch information
hasheddan committed Aug 25, 2023
1 parent 197b9de commit 43dc643
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 60 deletions.
70 changes: 69 additions & 1 deletion connection_id.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,14 @@

package dtls

import "crypto/rand"
import (
"crypto/rand"

"github.com/pion/dtls/v2/pkg/protocol"
"github.com/pion/dtls/v2/pkg/protocol/extension"
"github.com/pion/dtls/v2/pkg/protocol/handshake"
"github.com/pion/dtls/v2/pkg/protocol/recordlayer"
)

// RandomCIDGenerator is a random Connection ID generator where CID is the
// specified size. Specifying a size of 0 will indicate to peers that sending a
Expand All @@ -26,3 +33,64 @@ func OnlySendCIDGenerator() func() []byte {
return nil
}
}

// cidConnResolver extracts connection IDs from incoming packets and uses them
// to route to the proper connection.
func cidConnResolver(size int) func([]byte) (string, bool) {
return func(packet []byte) (string, bool) {
pkts, err := recordlayer.ContentAwareUnpackDatagram(packet, size)
if err != nil || len(pkts) < 1 {
return "", false
}
for _, pkt := range pkts {
h := &recordlayer.Header{
ConnectionID: make([]byte, size),
}
if err := h.Unmarshal(pkt); err != nil {
continue

Check warning on line 50 in connection_id.go

View check run for this annotation

Codecov / codecov/patch

connection_id.go#L39-L50

Added lines #L39 - L50 were not covered by tests
}
if h.ContentType != protocol.ContentTypeConnectionID {
continue

Check warning on line 53 in connection_id.go

View check run for this annotation

Codecov / codecov/patch

connection_id.go#L52-L53

Added lines #L52 - L53 were not covered by tests
}
return string(h.ConnectionID), true

Check warning on line 55 in connection_id.go

View check run for this annotation

Codecov / codecov/patch

connection_id.go#L55

Added line #L55 was not covered by tests
}
return "", false

Check warning on line 57 in connection_id.go

View check run for this annotation

Codecov / codecov/patch

connection_id.go#L57

Added line #L57 was not covered by tests
}
}

// cidConnIdentifier extracts connection IDs from outgoing ServerHello packets
// and associates them with the associated connection.
func cidConnIdentifier() func([]byte) (string, bool) {
return func(packet []byte) (string, bool) {
pkts, err := recordlayer.UnpackDatagram(packet)
if err != nil || len(pkts) < 1 {
return "", false
}
h := &recordlayer.Header{}
if err := h.Unmarshal(pkts[0]); err != nil {

Check failure on line 70 in connection_id.go

View workflow job for this annotation

GitHub Actions / lint / Go

shadow: declaration of "err" shadows declaration at line 65 (govet)
return "", false
}
if h.ContentType != protocol.ContentTypeHandshake {
return "", false
}
hh := &handshake.Header{}
sh := &handshake.MessageServerHello{}
for _, pkt := range pkts {
if err := hh.Unmarshal(pkt[recordlayer.FixedHeaderSize:]); err != nil {

Check failure on line 79 in connection_id.go

View workflow job for this annotation

GitHub Actions / lint / Go

shadow: declaration of "err" shadows declaration at line 65 (govet)
continue

Check warning on line 80 in connection_id.go

View check run for this annotation

Codecov / codecov/patch

connection_id.go#L63-L80

Added lines #L63 - L80 were not covered by tests
}
if err = sh.Unmarshal(pkt[recordlayer.FixedHeaderSize+handshake.HeaderLength:]); err == nil {
break

Check warning on line 83 in connection_id.go

View check run for this annotation

Codecov / codecov/patch

connection_id.go#L82-L83

Added lines #L82 - L83 were not covered by tests
}
}
if err != nil {
return "", false
}
for _, ext := range sh.Extensions {
if e, ok := ext.(*extension.ConnectionID); ok {
return string(e.CID), true
}

Check warning on line 92 in connection_id.go

View check run for this annotation

Codecov / codecov/patch

connection_id.go#L86-L92

Added lines #L86 - L92 were not covered by tests
}
return "", false

Check warning on line 94 in connection_id.go

View check run for this annotation

Codecov / codecov/patch

connection_id.go#L94

Added line #L94 was not covered by tests
}
}
17 changes: 9 additions & 8 deletions internal/net/udp/packet_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ type listener struct {
doneCh chan struct{}
doneOnce sync.Once
acceptFilter func([]byte) bool
connResolver func([]byte, net.Addr) string
connIdentifier func([]byte, net.Addr) (string, bool)
connResolver func([]byte) (string, bool)
connIdentifier func([]byte) (string, bool)

connLock sync.Mutex
conns map[string]*PacketConn
Expand Down Expand Up @@ -140,12 +140,12 @@ type ListenConfig struct {

// ConnectionResolver resolves an incoming packet to a connection by
// extracting an identifier from the packet contents.
ConnectionResolver func([]byte, net.Addr) string
ConnectionResolver func([]byte) (string, bool)

// ConnectionIdentifier extracts an identifier from an outgoing packet. If
// the identifier is not already associated with the connection, it will be
// added.
ConnectionIdentifier func([]byte, net.Addr) (string, bool)
ConnectionIdentifier func([]byte) (string, bool)
}

// Listen creates a new listener based on the ListenConfig.
Expand Down Expand Up @@ -222,9 +222,10 @@ func (l *listener) getConn(raddr net.Addr, buf []byte) (*PacketConn, bool, error
defer l.connLock.Unlock()
// If we have a custom resolver, use it.
if l.connResolver != nil {
conn, ok := l.conns[l.connResolver(buf, raddr)]
if ok {
return conn, true, nil
if id, ok := l.connResolver(buf); ok {
if conn, ok := l.conns[id]; ok {
return conn, true, nil
}
}
}

Expand Down Expand Up @@ -293,7 +294,7 @@ func (c *PacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
// sets it.
if c.listener.connIdentifier != nil {
id := c.id.Load()
candidate, ok := c.listener.connIdentifier(p, addr)
candidate, ok := c.listener.connIdentifier(p)
// If this is a new identifier, add entry to connection map.
if ok && id != candidate {
c.listener.connLock.Lock()
Expand Down
8 changes: 4 additions & 4 deletions internal/net/udp/packet_conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -496,17 +496,17 @@ func TestListenerCustomConnID(t *testing.T) {
}
network, addr := getConfig()
listener, err := (&ListenConfig{
ConnectionResolver: func(buf []byte, raddr net.Addr) string {
ConnectionResolver: func(buf []byte) (string, bool) {
p := &pkt{}
if err := json.Unmarshal(buf, p); err != nil {
t.Fatal(err)
}
if p.Payload == helloPayload {
return raddr.String()
return "", false
}
return fmt.Sprint(p.ID)
return fmt.Sprint(p.ID), true
},
ConnectionIdentifier: func(buf []byte, _ net.Addr) (string, bool) {
ConnectionIdentifier: func(buf []byte) (string, bool) {
p := &pkt{}
if err := json.Unmarshal(buf, p); err != nil {
t.Fatal(err)
Expand Down
49 changes: 2 additions & 47 deletions listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,54 +9,9 @@ import (
"github.com/pion/dtls/v2/internal/net/udp"
dtlsnet "github.com/pion/dtls/v2/pkg/net"
"github.com/pion/dtls/v2/pkg/protocol"
"github.com/pion/dtls/v2/pkg/protocol/extension"
"github.com/pion/dtls/v2/pkg/protocol/handshake"
"github.com/pion/dtls/v2/pkg/protocol/recordlayer"
)

// cidConnResolver extracts connection IDs from incoming packets and uses them
// to route to the proper connection.
func cidConnResolver(packet []byte, raddr net.Addr) string {
pkts, err := recordlayer.UnpackDatagram(packet)
if err != nil || len(pkts) < 1 {
return raddr.String()
}
h := &recordlayer.Header{}
if err := h.Unmarshal(pkts[0]); err != nil {
return raddr.String()
}
if h.ContentType != protocol.ContentTypeConnectionID {
return raddr.String()
}
return string(h.ConnectionID)
}

// cidConnIdentifier extracts connection IDs from outgoing ServerHello packets
// and associates them with the associated connection.
func cidConnIdentifier(packet []byte, _ net.Addr) (string, bool) {
pkts, err := recordlayer.UnpackDatagram(packet)
if err != nil || len(pkts) < 1 {
return "", false
}
h := &recordlayer.Header{}
if err := h.Unmarshal(pkts[0]); err != nil {
return "", false
}
if h.ContentType != protocol.ContentTypeHandshake {
return "", false
}
sh := &handshake.MessageServerHello{}
if err := sh.Unmarshal(pkts[0]); err != nil {
return "", false
}
for _, ext := range sh.Extensions {
if e, ok := ext.(*extension.ConnectionID); ok {
return string(e.CID), true
}
}
return "", false
}

// Listen creates a DTLS listener
func Listen(network string, laddr *net.UDPAddr, config *Config) (net.Listener, error) {
if err := validateConfig(config); err != nil {
Expand All @@ -79,8 +34,8 @@ func Listen(network string, laddr *net.UDPAddr, config *Config) (net.Listener, e
// If connection ID support is enabled, then they must be supported in
// routing.
if config.ConnectionIDGenerator != nil {
lc.ConnectionResolver = cidConnResolver
lc.ConnectionIdentifier = cidConnIdentifier
lc.ConnectionResolver = cidConnResolver(len(config.ConnectionIDGenerator()))
lc.ConnectionIdentifier = cidConnIdentifier()
}

Check warning on line 39 in listener.go

View check run for this annotation

Codecov / codecov/patch

listener.go#L36-L39

Added lines #L36 - L39 were not covered by tests
parent, err := lc.Listen(network, laddr)
if err != nil {
Expand Down

0 comments on commit 43dc643

Please sign in to comment.