Skip to content

Commit 70cc34a

Browse files
committed
fix: adding singleflight for creating connections to remove issue of race #171
1 parent 20759b9 commit 70cc34a

File tree

8 files changed

+276
-172
lines changed

8 files changed

+276
-172
lines changed

sip/transport_connection_pool.go

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,52 @@ func (p *ConnectionPool) init() {
4747
p.m = make(map[string]Connection)
4848
}
4949

50+
func (p *ConnectionPool) addSingleflight(raddr Addr, laddr Addr, reuse bool, do func() (Connection, error)) (Connection, error) {
51+
a := raddr.String()
52+
// If user wants to create connection per remote addr, allow this always unless reuse is forced
53+
if laddr.Port == 0 && !reuse {
54+
// There is nothing here to block
55+
c, err := do()
56+
if err != nil {
57+
return nil, err
58+
}
59+
p.m[a] = c
60+
p.m[c.LocalAddr().String()] = c
61+
return c, nil
62+
}
63+
64+
p.Lock()
65+
defer p.Unlock()
66+
// If local port connection is needed check only this connection,
67+
// otherwise return existing only if reuse is wanted.
68+
69+
// TODO: Improve. There is no need to lock if no reuse is needed
70+
if laddr.Port > 0 {
71+
la := laddr.String()
72+
if c, exists := p.m[la]; exists {
73+
return c, nil
74+
}
75+
} else if reuse {
76+
//
77+
if c, exists := p.m[a]; exists {
78+
return c, nil
79+
}
80+
}
81+
82+
c, err := do()
83+
if err != nil {
84+
return nil, err
85+
}
86+
87+
if c.Ref(0) < 1 {
88+
c.Ref(1) // Make 1 reference count by default
89+
}
90+
// Add both references
91+
p.m[a] = c
92+
p.m[c.LocalAddr().String()] = c
93+
return c, err
94+
}
95+
5096
func (p *ConnectionPool) Add(a string, c Connection) {
5197
// TODO how about multi connection support for same remote address
5298
// We can then check ref count

sip/transport_layer.go

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@ type TransportLayer struct {
3636

3737
log *slog.Logger
3838

39-
// ConnectionReuse will force connection reuse when passing request
40-
ConnectionReuse bool
39+
// connectionReuse will force connection reuse when passing request
40+
connectionReuse bool
4141

4242
// PreferSRV does always SRV lookup first
4343
DNSPreferSRV bool
@@ -53,6 +53,12 @@ func WithTransportLayerLogger(logger *slog.Logger) TransportLayerOption {
5353
}
5454
}
5555

56+
func WithTransportLayerConnectionReuse(f bool) TransportLayerOption {
57+
return func(l *TransportLayer) {
58+
l.connectionReuse = f
59+
}
60+
}
61+
5662
// NewLayer creates transport layer.
5763
// dns Resolver
5864
// sip parser
@@ -67,7 +73,7 @@ func NewTransportLayer(
6773
transports: make(map[string]Transport),
6874
listenPorts: make(map[string][]int),
6975
dnsResolver: dnsResolver,
70-
ConnectionReuse: true,
76+
connectionReuse: true,
7177
log: slog.With("caller", "TransportLayer"),
7278
}
7379

@@ -83,21 +89,24 @@ func NewTransportLayer(
8389
// Exporting transport configuration
8490
// UDP
8591
l.udp = &transportUDP{
86-
log: l.log.With("caller", "Transport<UDP>"),
92+
log: l.log.With("caller", "Transport<UDP>"),
93+
connectionReuse: l.connectionReuse,
8794
}
8895
l.udp.init(sipparser)
8996

9097
// TCP
9198
l.tcp = &transportTCP{
92-
log: l.log.With("caller", "Transport<TCP>"),
99+
log: l.log.With("caller", "Transport<TCP>"),
100+
connectionReuse: l.connectionReuse,
93101
}
94102
l.tcp.init(sipparser)
95103

96104
// TLS
97105
// TODO. Using default dial tls, but it needs to configurable via client
98106
l.tls = &transportTLS{
99107
transportTCP: &transportTCP{
100-
log: l.log.With("caller", "Transport<TLS>"),
108+
log: l.log.With("caller", "Transport<TLS>"),
109+
connectionReuse: l.connectionReuse,
101110
},
102111
}
103112
l.tls.init(sipparser, tlsConfig)
@@ -112,7 +121,8 @@ func NewTransportLayer(
112121
// TODO. Using default dial tls, but it needs to configurable via client
113122
l.wss = &transportWSS{
114123
transportWS: &transportWS{
115-
log: l.log.With("caller", "Transport<WSS>"),
124+
log: l.log.With("caller", "Transport<WSS>"),
125+
connectionReuse: l.connectionReuse,
116126
},
117127
}
118128
l.wss.init(sipparser, tlsConfig)
@@ -389,7 +399,7 @@ func (l *TransportLayer) ClientRequestConnection(ctx context.Context, req *Reque
389399
// This is probably client forcing host:port
390400
if laddr.IP != nil && laddr.Port > 0 {
391401
c = transport.GetConnection(laddr.String())
392-
} else if l.ConnectionReuse {
402+
} else if l.connectionReuse {
393403
// viaHop.Params.Add("alias", "")
394404
addr := raddr.String()
395405
c = transport.GetConnection(addr)

sip/transport_layer_test.go

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,7 @@ func TestTransportLayerClientConnectionReuse(t *testing.T) {
3838
defer func() {
3939
require.NoError(t, tp.Close())
4040
}()
41-
42-
require.True(t, tp.ConnectionReuse)
41+
require.True(t, tp.connectionReuse)
4342

4443
t.Run("Default", func(t *testing.T) {
4544
req := NewRequest(OPTIONS, Uri{Host: "localhost", Port: 5066})
@@ -86,14 +85,13 @@ func TestTransportLayerClientConnectionReuse(t *testing.T) {
8685

8786
func TestTransportLayerClientConnectionNoReuse(t *testing.T) {
8887
// NOTE it creates real network connection
89-
tp := NewTransportLayer(net.DefaultResolver, NewParser(), nil)
88+
tp := NewTransportLayer(net.DefaultResolver, NewParser(), nil, WithTransportLayerConnectionReuse(false))
9089
defer func() {
9190
require.Empty(t, tp.udp.pool.Size())
9291
}()
9392
defer func() {
9493
require.NoError(t, tp.Close())
9594
}()
96-
tp.ConnectionReuse = false
9795

9896
t.Run("Default", func(t *testing.T) {
9997
req := NewRequest(OPTIONS, Uri{Host: "localhost", Port: 5066})

sip/transport_tcp.go

Lines changed: 48 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,12 @@ import (
1313

1414
// TCP transport implementation
1515
type transportTCP struct {
16-
addr string
17-
transport string
18-
parser *Parser
19-
log *slog.Logger
16+
addr string
17+
transport string
18+
parser *Parser
19+
log *slog.Logger
20+
createMu sync.Mutex
21+
connectionReuse bool
2022

2123
pool *ConnectionPool
2224
}
@@ -68,44 +70,55 @@ func (t *transportTCP) CreateConnection(ctx context.Context, laddr Addr, raddr A
6870
// if err != nil {
6971
// return nil, err
7072
// }
71-
var tladdr *net.TCPAddr = nil
72-
if laddr.IP != nil {
73-
tladdr = &net.TCPAddr{
74-
IP: laddr.IP,
75-
Port: laddr.Port,
73+
// We do singleflight if laddr is required or connection reuse
74+
conn, err := t.pool.addSingleflight(raddr, laddr, t.connectionReuse, func() (Connection, error) {
75+
var tladdr *net.TCPAddr = nil
76+
if laddr.IP != nil {
77+
tladdr = &net.TCPAddr{
78+
IP: laddr.IP,
79+
Port: laddr.Port,
80+
}
7681
}
77-
}
7882

79-
traddr := &net.TCPAddr{
80-
IP: raddr.IP,
81-
Port: raddr.Port,
82-
}
83-
return t.createConnection(ctx, tladdr, traddr, handler)
84-
}
83+
traddr := &net.TCPAddr{
84+
IP: raddr.IP,
85+
Port: raddr.Port,
86+
}
8587

86-
func (t *transportTCP) createConnection(ctx context.Context, laddr *net.TCPAddr, raddr *net.TCPAddr, handler MessageHandler) (Connection, error) {
87-
addr := raddr.String()
88-
t.log.Debug("Dialing new connection", "raddr", addr)
88+
addr := traddr.String()
89+
t.log.Debug("Dialing new connection", "raddr", addr)
8990

90-
d := net.Dialer{
91-
LocalAddr: laddr,
92-
}
93-
conn, err := d.DialContext(ctx, "tcp", addr)
94-
if err != nil {
95-
return nil, fmt.Errorf("%s dial err=%w", t, err)
96-
}
91+
d := net.Dialer{
92+
LocalAddr: tladdr,
93+
}
94+
conn, err := d.DialContext(ctx, "tcp", addr)
95+
if err != nil {
96+
return nil, fmt.Errorf("%s dial err=%w", t, err)
97+
}
9798

98-
// if err := conn.SetKeepAlive(true); err != nil {
99-
// return nil, fmt.Errorf("%s keepalive err=%w", t, err)
100-
// }
99+
// if err := conn.SetKeepAlive(true); err != nil {
100+
// return nil, fmt.Errorf("%s keepalive err=%w", t, err)
101+
// }
101102

102-
// if err := conn.SetKeepAlivePeriod(30 * time.Second); err != nil {
103-
// return nil, fmt.Errorf("%s keepalive period err=%w", t, err)
104-
// }
105-
c := t.initConnection(conn, addr, handler)
103+
// if err := conn.SetKeepAlivePeriod(30 * time.Second); err != nil {
104+
// return nil, fmt.Errorf("%s keepalive period err=%w", t, err)
105+
// }
106106

107-
// Increase ref by 1 before returnin
108-
c.Ref(1)
107+
t.log.Debug("New connection", "raddr", raddr)
108+
c := &TCPConnection{
109+
Conn: conn,
110+
refcount: 2 + IdleConnection,
111+
}
112+
113+
// Increase ref by 1 before returnin
114+
// c.Ref(1)
115+
return c, nil
116+
})
117+
if err != nil {
118+
return nil, err
119+
}
120+
c := conn.(*TCPConnection)
121+
go t.readConnection(c, c.LocalAddr().String(), c.RemoteAddr().String(), handler)
109122
return c, nil
110123
}
111124

sip/transport_tls.go

Lines changed: 41 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -45,43 +45,54 @@ func (t *transportTLS) String() string {
4545

4646
// CreateConnection creates TLS connection for TCP transport
4747
func (t *transportTLS) CreateConnection(ctx context.Context, laddr Addr, raddr Addr, handler MessageHandler) (Connection, error) {
48-
hostname := raddr.Hostname
49-
if hostname == "" {
50-
hostname = raddr.IP.String()
51-
}
48+
conn, err := t.pool.addSingleflight(raddr, laddr, t.connectionReuse, func() (Connection, error) {
49+
hostname := raddr.Hostname
50+
if hostname == "" {
51+
hostname = raddr.IP.String()
52+
}
5253

53-
var tladdr *net.TCPAddr = nil
54-
if laddr.IP != nil {
55-
tladdr = &net.TCPAddr{
56-
IP: laddr.IP,
57-
Port: laddr.Port,
54+
var tladdr *net.TCPAddr = nil
55+
if laddr.IP != nil {
56+
tladdr = &net.TCPAddr{
57+
IP: laddr.IP,
58+
Port: laddr.Port,
59+
}
5860
}
59-
}
6061

61-
traddr := &net.TCPAddr{
62-
IP: raddr.IP,
63-
Port: raddr.Port,
64-
}
62+
traddr := &net.TCPAddr{
63+
IP: raddr.IP,
64+
Port: raddr.Port,
65+
}
6566

66-
netDialer := &net.Dialer{
67-
LocalAddr: tladdr,
68-
}
67+
netDialer := &net.Dialer{
68+
LocalAddr: tladdr,
69+
}
6970

70-
addr := traddr.String()
71-
t.log.Debug("Dialing new connection", "raddr", addr)
72-
// No resolving should happen here
73-
conn, err := netDialer.DialContext(ctx, "tcp", addr)
74-
if err != nil {
75-
return nil, fmt.Errorf("dial TCP error: %w", err)
76-
}
71+
addr := traddr.String()
72+
t.log.Debug("Dialing new connection", "raddr", addr)
73+
// No resolving should happen here
74+
conn, err := netDialer.DialContext(ctx, "tcp", addr)
75+
if err != nil {
76+
return nil, fmt.Errorf("dial TCP error: %w", err)
77+
}
7778

78-
tlsConn := t.tlsClient(conn, hostname)
79+
tlsConn := t.tlsClient(conn, hostname)
7980

80-
if err := tlsConn.HandshakeContext(ctx); err != nil {
81-
return nil, fmt.Errorf("TLS handshake error: %w", err)
82-
}
81+
if err := tlsConn.HandshakeContext(ctx); err != nil {
82+
return nil, fmt.Errorf("TLS handshake error: %w", err)
83+
}
8384

84-
c := t.initConnection(tlsConn, addr, handler)
85-
c.Ref(1)
85+
t.log.Debug("New connection", "raddr", raddr)
86+
c := &TCPConnection{
87+
Conn: tlsConn,
88+
refcount: 2 + IdleConnection,
89+
}
90+
return c, nil
91+
})
92+
if err != nil {
93+
return nil, err
94+
}
95+
c := conn.(*TCPConnection)
96+
go t.readConnection(c, c.LocalAddr().String(), c.RemoteAddr().String(), handler)
8697
return c, nil
8798
}

0 commit comments

Comments
 (0)