Skip to content

Commit 806ff2f

Browse files
committed
Refactor cfg.onFlightState, avoid data race
1 parent f5e908f commit 806ff2f

File tree

1 file changed

+5
-8
lines changed

1 file changed

+5
-8
lines changed

conn.go

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ type Conn struct {
7272
maximumTransmissionUnit int
7373
paddingLengthGenerator func(uint) uint
7474

75-
handshakeCompletedSuccessfully atomic.Value
75+
handshakeCompletedSuccessfully atomic.Bool
7676
handshakeMutex sync.Mutex
7777
handshakeDone chan struct{}
7878

@@ -1077,14 +1077,12 @@ func (c *Conn) notify(ctx context.Context, level alert.Level, desc alert.Descrip
10771077
})
10781078
}
10791079

1080-
func (c *Conn) setHandshakeCompletedSuccessfully() {
1081-
c.handshakeCompletedSuccessfully.Store(struct{ bool }{true})
1080+
func (c *Conn) setHandshakeCompletedSuccessfully() bool {
1081+
return c.handshakeCompletedSuccessfully.CompareAndSwap(false, true)
10821082
}
10831083

10841084
func (c *Conn) isHandshakeCompletedSuccessfully() bool {
1085-
boolean, _ := c.handshakeCompletedSuccessfully.Load().(struct{ bool })
1086-
1087-
return boolean.bool
1085+
return c.handshakeCompletedSuccessfully.Load()
10881086
}
10891087

10901088
//nolint:cyclop,gocognit,contextcheck
@@ -1099,8 +1097,7 @@ func (c *Conn) handshake(
10991097
done := make(chan struct{})
11001098
ctxRead, cancelRead := context.WithCancel(context.Background())
11011099
cfg.onFlightState = func(_ flightVal, s handshakeState) {
1102-
if s == handshakeFinished && !c.isHandshakeCompletedSuccessfully() {
1103-
c.setHandshakeCompletedSuccessfully()
1100+
if s == handshakeFinished && c.setHandshakeCompletedSuccessfully() {
11041101
close(done)
11051102
}
11061103
}

0 commit comments

Comments
 (0)