Skip to content

Commit

Permalink
Use weak_ptr when creating a callback for resolver in ntcr/ntcp::Stre…
Browse files Browse the repository at this point in the history
…amSocket
  • Loading branch information
smtrfnv authored Jun 21, 2024
1 parent 043113c commit a9571fa
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 24 deletions.
30 changes: 23 additions & 7 deletions groups/ntc/ntcp/ntcp_streamsocket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3265,8 +3265,8 @@ ntsa::Error StreamSocket::privateUpgrade(
bdlf::MemFnUtil::memFn(&StreamSocket::privateEncryptionHandshake,
this);

error = d_encryption_sp->initiateHandshake(upgradeOptions,
handshakeCallback);
error =
d_encryption_sp->initiateHandshake(upgradeOptions, handshakeCallback);
if (error) {
return error;
}
Expand Down Expand Up @@ -3392,17 +3392,33 @@ void StreamSocket::privateRetryConnect(
error = this->privateRetryConnectToEndpoint(self);
}
else {
error = this->privateRetryConnectToName(self);
error = this->privateRetryConnectToName();
}

if (error) {
this->privateFailConnect(self, error, false, false);
}
}

ntsa::Error StreamSocket::privateRetryConnectToName(
const bsl::shared_ptr<StreamSocket>& self)
ntsa::Error StreamSocket::privateRetryConnectToName()
{
struct WeakBinder {
static void invoke(const bsl::weak_ptr<StreamSocket>& socket,
const bsl::shared_ptr<ntci::Resolver>& resolver,
const ntsa::Endpoint& endpoint,
const ntca::GetEndpointEvent& getEndpointEvent,
bsl::size_t connectAttempts)
{
const bsl::shared_ptr<StreamSocket> strongRef = socket.lock();
if (strongRef) {
strongRef->processRemoteEndpointResolution(resolver,
endpoint,
getEndpointEvent,
connectAttempts);
}
}
};

ntsa::Error error;

ntcs::ObserverRef<ntci::Resolver> resolverRef(&d_resolver);
Expand All @@ -3415,8 +3431,8 @@ ntsa::Error StreamSocket::privateRetryConnectToName(

ntci::GetEndpointCallback getEndpointCallback =
resolverRef->createGetEndpointCallback(
NTCCFG_BIND(&StreamSocket::processRemoteEndpointResolution,
self,
NTCCFG_BIND(&WeakBinder::invoke,
this->weak_from_this(),
NTCCFG_BIND_PLACEHOLDER_1,
NTCCFG_BIND_PLACEHOLDER_2,
NTCCFG_BIND_PLACEHOLDER_3,
Expand Down
3 changes: 1 addition & 2 deletions groups/ntc/ntcp/ntcp_streamsocket.h
Original file line number Diff line number Diff line change
Expand Up @@ -490,8 +490,7 @@ class StreamSocket : public ntci::StreamSocket,
void privateRetryConnect(const bsl::shared_ptr<StreamSocket>& self);

/// Retry connecting to the remote name. Return the error.
ntsa::Error privateRetryConnectToName(
const bsl::shared_ptr<StreamSocket>& self);
ntsa::Error privateRetryConnectToName();

/// Retry connecting to the remote endpoint. Return the error.
ntsa::Error privateRetryConnectToEndpoint(
Expand Down
41 changes: 28 additions & 13 deletions groups/ntc/ntcr/ntcr_streamsocket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3992,8 +3992,8 @@ ntsa::Error StreamSocket::privateUpgrade(
bdlf::MemFnUtil::memFn(&StreamSocket::privateEncryptionHandshake,
this);

error = d_encryption_sp->initiateHandshake(
upgradeOptions, handshakeCallback);
error =
d_encryption_sp->initiateHandshake(upgradeOptions, handshakeCallback);
if (error) {
return error;
}
Expand Down Expand Up @@ -4126,17 +4126,32 @@ void StreamSocket::privateRetryConnect(
error = this->privateRetryConnectToEndpoint(self);
}
else {
error = this->privateRetryConnectToName(self);
error = this->privateRetryConnectToName();
}

if (error) {
this->privateFailConnect(self, error, false, false);
}
}

ntsa::Error StreamSocket::privateRetryConnectToName(
const bsl::shared_ptr<StreamSocket>& self)
ntsa::Error StreamSocket::privateRetryConnectToName()
{
struct WeakBinder {
static void invoke(const bsl::weak_ptr<StreamSocket>& socket,
const bsl::shared_ptr<ntci::Resolver>& resolver,
const ntsa::Endpoint& endpoint,
const ntca::GetEndpointEvent& getEndpointEvent,
bsl::size_t connectAttempts)
{
const bsl::shared_ptr<StreamSocket> strongRef = socket.lock();
if (strongRef) {
strongRef->processRemoteEndpointResolution(resolver,
endpoint,
getEndpointEvent,
connectAttempts);
}
}
};
ntsa::Error error;

ntcs::ObserverRef<ntci::Resolver> resolverRef(&d_resolver);
Expand All @@ -4149,8 +4164,8 @@ ntsa::Error StreamSocket::privateRetryConnectToName(

ntci::GetEndpointCallback getEndpointCallback =
resolverRef->createGetEndpointCallback(
NTCCFG_BIND(&StreamSocket::processRemoteEndpointResolution,
self,
NTCCFG_BIND(&WeakBinder::invoke,
this->weak_from_this(),
NTCCFG_BIND_PLACEHOLDER_1,
NTCCFG_BIND_PLACEHOLDER_2,
NTCCFG_BIND_PLACEHOLDER_3,
Expand Down Expand Up @@ -4277,9 +4292,9 @@ ntsa::Error StreamSocket::privateTimestampOutgoingData(

{
ntsa::SocketOption option(d_allocator_p);
error = d_socket_sp->getOption(
&option,
ntsa::SocketOptionType::e_TX_TIMESTAMPING);
error =
d_socket_sp->getOption(&option,
ntsa::SocketOptionType::e_TX_TIMESTAMPING);
if (error) {
if (error != ntsa::Error::e_NOT_IMPLEMENTED) {
NTCI_LOG_TRACE("Failed to get socket option: "
Expand Down Expand Up @@ -4357,9 +4372,9 @@ ntsa::Error StreamSocket::privateTimestampIncomingData(

{
ntsa::SocketOption option(d_allocator_p);
error = d_socket_sp->getOption(
&option,
ntsa::SocketOptionType::e_RX_TIMESTAMPING);
error =
d_socket_sp->getOption(&option,
ntsa::SocketOptionType::e_RX_TIMESTAMPING);
if (error) {
if (error != ntsa::Error::e_NOT_IMPLEMENTED) {
NTCI_LOG_TRACE("Failed to get socket option: "
Expand Down
3 changes: 1 addition & 2 deletions groups/ntc/ntcr/ntcr_streamsocket.h
Original file line number Diff line number Diff line change
Expand Up @@ -522,8 +522,7 @@ class StreamSocket : public ntci::StreamSocket,
void privateRetryConnect(const bsl::shared_ptr<StreamSocket>& self);

/// Retry connecting to the remote name. Return the error.
ntsa::Error privateRetryConnectToName(
const bsl::shared_ptr<StreamSocket>& self);
ntsa::Error privateRetryConnectToName();

/// Retry connecting to the remote endpoint. Return the error.
ntsa::Error privateRetryConnectToEndpoint(
Expand Down

0 comments on commit a9571fa

Please sign in to comment.