Skip to content

Commit 272ebee

Browse files
committed
Implement active TCP candidate type
1 parent 9b4e7d9 commit 272ebee

File tree

4 files changed

+173
-32
lines changed

4 files changed

+173
-32
lines changed

agent.go

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"context"
1010
"fmt"
1111
"net"
12+
"strconv"
1213
"strings"
1314
"sync"
1415
"sync/atomic"
@@ -138,6 +139,7 @@ type Agent struct {
138139

139140
interfaceFilter func(string) bool
140141
ipFilter func(net.IP) bool
142+
ActiveTCP bool
141143
includeLoopback bool
142144

143145
insecureSkipVerify bool
@@ -312,6 +314,8 @@ func NewAgent(config *AgentConfig) (*Agent, error) { //nolint:gocognit
312314

313315
ipFilter: config.IPFilter,
314316

317+
ActiveTCP: config.ActiveTCP,
318+
315319
insecureSkipVerify: config.InsecureSkipVerify,
316320

317321
includeLoopback: config.IncludeLoopback,
@@ -579,6 +583,40 @@ func (a *Agent) getBestValidCandidatePair() *CandidatePair {
579583
}
580584

581585
func (a *Agent) addPair(local, remote Candidate) *CandidatePair {
586+
if local.TCPType() == TCPTypeActive && remote.TCPType() == TCPTypePassive {
587+
addressToConnect := net.JoinHostPort(remote.Address(), strconv.Itoa(remote.Port()))
588+
589+
conn, err := net.Dial("tcp", addressToConnect)
590+
if err != nil {
591+
a.log.Errorf("Failed to dial TCP address %s: %v", addressToConnect, err)
592+
return nil
593+
}
594+
595+
packetConn := newTCPPacketConn(tcpPacketParams{
596+
ReadBuffer: 8,
597+
LocalAddr: conn.LocalAddr(),
598+
Logger: a.log,
599+
})
600+
601+
if err = packetConn.AddConn(conn, nil); err != nil {
602+
a.log.Errorf("Failed to add TCP connection: %v", err)
603+
return nil
604+
}
605+
606+
localAddress, ok := conn.LocalAddr().(*net.TCPAddr)
607+
if !ok {
608+
a.log.Errorf("Failed to cast local address to TCP address")
609+
return nil
610+
}
611+
612+
localCandidateHost, ok := local.(*CandidateHost)
613+
if !ok {
614+
a.log.Errorf("Failed to cast local candidate to CandidateHost")
615+
return nil
616+
}
617+
localCandidateHost.port = localAddress.Port // this causes a data race with candidateBase.Port()
618+
local.start(a, packetConn, a.startedCh)
619+
}
582620
p := newCandidatePair(local, remote, a.isControlling)
583621
a.checklist = append(a.checklist, p)
584622
return p
@@ -755,7 +793,9 @@ func (a *Agent) addCandidate(ctx context.Context, c Candidate, candidateConn net
755793
}
756794
}
757795

758-
c.start(a, candidateConn, a.startedCh)
796+
if c.TCPType() != TCPTypeActive {
797+
c.start(a, candidateConn, a.startedCh)
798+
}
759799

760800
set = append(set, c)
761801
a.localCandidates[c.NetworkType()] = set

agent_active_tcp_test.go

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
2+
// SPDX-License-Identifier: MIT
3+
4+
//go:build !js
5+
// +build !js
6+
7+
package ice
8+
9+
import (
10+
"net"
11+
"testing"
12+
13+
"github.com/pion/logging"
14+
"github.com/stretchr/testify/require"
15+
)
16+
17+
func TestAgentActiveTCP(t *testing.T) {
18+
r := require.New(t)
19+
20+
const port = 7686
21+
22+
listener, err := net.ListenTCP("tcp", &net.TCPAddr{
23+
IP: net.IPv4(127, 0, 0, 1),
24+
Port: port,
25+
})
26+
r.NoError(err)
27+
defer func() {
28+
_ = listener.Close()
29+
}()
30+
31+
loggerFactory := logging.NewDefaultLoggerFactory()
32+
loggerFactory.DefaultLogLevel.Set(logging.LogLevelTrace)
33+
34+
tcpMux := NewTCPMuxDefault(TCPMuxParams{
35+
Listener: listener,
36+
Logger: loggerFactory.NewLogger("passive-ice"),
37+
ReadBufferSize: 20,
38+
})
39+
40+
defer func() {
41+
_ = tcpMux.Close()
42+
}()
43+
44+
r.NotNil(tcpMux.LocalAddr(), "tcpMux.LocalAddr() is nil")
45+
46+
passiveAgent, err := NewAgent(&AgentConfig{
47+
TCPMux: tcpMux,
48+
CandidateTypes: []CandidateType{CandidateTypeHost},
49+
NetworkTypes: []NetworkType{NetworkTypeTCP4},
50+
LoggerFactory: loggerFactory,
51+
ActiveTCP: false,
52+
IncludeLoopback: true,
53+
})
54+
r.NoError(err)
55+
r.NotNil(passiveAgent)
56+
57+
activeAgent, err := NewAgent(&AgentConfig{
58+
CandidateTypes: []CandidateType{CandidateTypeHost},
59+
NetworkTypes: []NetworkType{NetworkTypeTCP4},
60+
LoggerFactory: loggerFactory,
61+
ActiveTCP: true,
62+
})
63+
r.NoError(err)
64+
r.NotNil(activeAgent)
65+
66+
passiveAgentConn, activeAgenConn := connect(passiveAgent, activeAgent)
67+
r.NotNil(passiveAgentConn)
68+
r.NotNil(activeAgenConn)
69+
70+
pair := passiveAgent.getSelectedPair()
71+
r.NotNil(pair)
72+
r.Equal(port, pair.Local.Port())
73+
74+
data := []byte("hello world")
75+
_, err = passiveAgentConn.Write(data)
76+
r.NoError(err)
77+
78+
buffer := make([]byte, 1024)
79+
n, err := activeAgenConn.Read(buffer)
80+
r.NoError(err)
81+
r.Equal(data, buffer[:n])
82+
83+
data2 := []byte("hello world 2")
84+
_, err = activeAgenConn.Write(data2)
85+
r.NoError(err)
86+
87+
n, err = passiveAgentConn.Read(buffer)
88+
r.NoError(err)
89+
r.Equal(data2, buffer[:n])
90+
91+
r.NoError(activeAgenConn.Close())
92+
r.NoError(passiveAgentConn.Close())
93+
r.NoError(tcpMux.Close())
94+
}

agent_config.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,8 @@ type AgentConfig struct {
145145
// the ips which are used to gather ICE candidates.
146146
IPFilter func(net.IP) bool
147147

148+
ActiveTCP bool
149+
148150
// InsecureSkipVerify controls if self-signed certificates are accepted when connecting
149151
// to TURN servers via TLS or DTLS
150152
InsecureSkipVerify bool

gather.go

Lines changed: 36 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -166,44 +166,49 @@ func (a *Agent) gatherCandidatesLocal(ctx context.Context, networkTypes []Networ
166166

167167
switch network {
168168
case tcp:
169-
if a.tcpMux == nil {
170-
continue
171-
}
172-
173-
// Handle ICE TCP passive mode
174-
var muxConns []net.PacketConn
175-
if multi, ok := a.tcpMux.(AllConnsGetter); ok {
176-
a.log.Debugf("GetAllConns by ufrag: %s", a.localUfrag)
177-
muxConns, err = multi.GetAllConns(a.localUfrag, mappedIP.To4() == nil, ip)
178-
if err != nil {
179-
a.log.Warnf("Failed to get all TCP connections by ufrag: %s %s %s", network, ip, a.localUfrag)
180-
continue
181-
}
169+
if a.ActiveTCP {
170+
conns = append(conns, connAndPort{nil, 0})
171+
tcpType = TCPTypeActive
182172
} else {
183-
a.log.Debugf("GetConn by ufrag: %s", a.localUfrag)
184-
conn, err := a.tcpMux.GetConnByUfrag(a.localUfrag, mappedIP.To4() == nil, ip)
185-
if err != nil {
186-
a.log.Warnf("Failed to get TCP connections by ufrag: %s %s %s", network, ip, a.localUfrag)
173+
// Handle ICE TCP passive mode
174+
if a.tcpMux == nil {
187175
continue
188176
}
189-
muxConns = []net.PacketConn{conn}
190-
}
191177

192-
// Extract the port for each PacketConn we got.
193-
for _, conn := range muxConns {
194-
if tcpConn, ok := conn.LocalAddr().(*net.TCPAddr); ok {
195-
conns = append(conns, connAndPort{conn, tcpConn.Port})
178+
var muxConns []net.PacketConn
179+
if multi, ok := a.tcpMux.(AllConnsGetter); ok {
180+
a.log.Debugf("GetAllConns by ufrag: %s", a.localUfrag)
181+
muxConns, err = multi.GetAllConns(a.localUfrag, mappedIP.To4() == nil, ip)
182+
if err != nil {
183+
a.log.Warnf("Failed to get all TCP connections by ufrag: %s %s %s", network, ip, a.localUfrag)
184+
continue
185+
}
196186
} else {
197-
a.log.Warnf("Failed to get port of connection from TCPMux: %s %s %s", network, ip, a.localUfrag)
187+
a.log.Debugf("GetConn by ufrag: %s", a.localUfrag)
188+
conn, err := a.tcpMux.GetConnByUfrag(a.localUfrag, mappedIP.To4() == nil, ip)
189+
if err != nil {
190+
a.log.Warnf("Failed to get TCP connections by ufrag: %s %s %s", network, ip, a.localUfrag)
191+
continue
192+
}
193+
muxConns = []net.PacketConn{conn}
198194
}
195+
196+
// Extract the port for each PacketConn we got.
197+
for _, conn := range muxConns {
198+
if tcpConn, ok := conn.LocalAddr().(*net.TCPAddr); ok {
199+
conns = append(conns, connAndPort{conn, tcpConn.Port})
200+
} else {
201+
a.log.Warnf("Failed to get port of connection from TCPMux: %s %s %s", network, ip, a.localUfrag)
202+
}
203+
}
204+
if len(conns) == 0 {
205+
// Didn't succeed with any, try the next network.
206+
continue
207+
}
208+
tcpType = TCPTypePassive
209+
// Is there a way to verify that the listen address is even
210+
// accessible from the current interface.
199211
}
200-
if len(conns) == 0 {
201-
// Didn't succeed with any, try the next network.
202-
continue
203-
}
204-
tcpType = TCPTypePassive
205-
// Is there a way to verify that the listen address is even
206-
// accessible from the current interface.
207212
case udp:
208213
conn, err := listenUDPInPortRange(a.net, a.log, int(a.portMax), int(a.portMin), network, &net.UDPAddr{IP: ip, Port: 0})
209214
if err != nil {

0 commit comments

Comments
 (0)