Skip to content

Commit

Permalink
Merge pull request #2971 from armanbilge/fix/leaked-socket-on-open-error
Browse files Browse the repository at this point in the history
Prevent socket leaks due to post-open exceptions
  • Loading branch information
mpilquist committed Sep 6, 2022
2 parents 8f54106 + 44738aa commit 9b1f8b1
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 30 deletions.
16 changes: 11 additions & 5 deletions io/js/src/main/scala/fs2/io/net/SocketGroupPlatform.scala
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,18 @@ private[net] trait SocketGroupCompanionPlatform { self: SocketGroup.type =>
options: List[SocketOption]
): Resource[F, Socket[F]] =
(for {
sock <- F
.delay(
new facade.net.Socket(new facade.net.SocketOptions { allowHalfOpen = true })
sock <- Resource
.make(
F.delay(
new facade.net.Socket(new facade.net.SocketOptions { allowHalfOpen = true })
)
)(sock =>
F.delay {
if (!sock.destroyed)
sock.destroy()
}
)
.flatTap(setSocketOptions(options))
.toResource
.evalTap(setSocketOptions(options))
socket <- Socket.forAsync(sock)
_ <- F
.async[Unit] { cb =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,22 @@ private[unixsocket] trait UnixSocketsCompanionPlatform {

override def client(address: UnixSocketAddress): Resource[F, Socket[F]] =
Resource
.eval(for {
socket <- F.delay(
.make(
F.delay(
new facade.net.Socket(new facade.net.SocketOptions { allowHalfOpen = true })
)
_ <- F.async_[Unit] { cb =>
)(socket =>
F.delay {
if (!socket.destroyed)
socket.destroy()
}
)
.evalTap { socket =>
F.async_[Unit] { cb =>
socket.connect(address.path, () => cb(Right(())))
()
}
} yield socket)
}
.flatMap(Socket.forAsync[F])

override def server(
Expand Down
46 changes: 25 additions & 21 deletions io/jvm/src/main/scala/fs2/io/net/SocketGroupPlatform.scala
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,13 @@ private[net] trait SocketGroupCompanionPlatform { self: SocketGroup.type =>
options: List[SocketOption]
): Resource[F, Socket[F]] = {
def setup: Resource[F, AsynchronousSocketChannel] =
Resource.make(Async[F].delay {
val ch =
AsynchronousChannelProvider.provider.openAsynchronousSocketChannel(channelGroup)
options.foreach(opt => ch.setOption(opt.key, opt.value))
ch
})(ch => Async[F].delay(if (ch.isOpen) ch.close else ()))
Resource
.make(
Async[F].delay(
AsynchronousChannelProvider.provider.openAsynchronousSocketChannel(channelGroup)
)
)(ch => Async[F].delay(if (ch.isOpen) ch.close else ()))
.evalTap(ch => Async[F].delay(options.foreach(opt => ch.setOption(opt.key, opt.value))))

def connect(ch: AsynchronousSocketChannel): F[AsynchronousSocketChannel] =
to.resolve[F].flatMap { ip =>
Expand All @@ -80,24 +81,27 @@ private[net] trait SocketGroupCompanionPlatform { self: SocketGroup.type =>
options: List[SocketOption]
): Resource[F, (SocketAddress[IpAddress], Stream[F, Socket[F]])] = {

val setup: F[AsynchronousServerSocketChannel] =
address.traverse(_.resolve[F]).flatMap { addr =>
Async[F].delay {
val ch =
AsynchronousChannelProvider.provider.openAsynchronousServerSocketChannel(channelGroup)
ch.bind(
new InetSocketAddress(
addr.map(_.toInetAddress).orNull,
port.map(_.value).getOrElse(0)
val setup: Resource[F, AsynchronousServerSocketChannel] =
Resource.eval(address.traverse(_.resolve[F])).flatMap { addr =>
Resource
.make(
Async[F].delay(
AsynchronousChannelProvider.provider
.openAsynchronousServerSocketChannel(channelGroup)
)
)(sch => Async[F].delay(if (sch.isOpen) sch.close()))
.evalTap(ch =>
Async[F].delay(
ch.bind(
new InetSocketAddress(
addr.map(_.toInetAddress).orNull,
port.map(_.value).getOrElse(0)
)
)
)
)
ch
}
}

def cleanup(sch: AsynchronousServerSocketChannel): F[Unit] =
Async[F].delay(if (sch.isOpen) sch.close())

def acceptIncoming(
sch: AsynchronousServerSocketChannel
): Stream[F, Socket[F]] = {
Expand Down Expand Up @@ -137,7 +141,7 @@ private[net] trait SocketGroupCompanionPlatform { self: SocketGroup.type =>
}
}

Resource.make(setup)(cleanup).map { sch =>
setup.map { sch =>
val jLocalAddress = sch.getLocalAddress.asInstanceOf[java.net.InetSocketAddress]
val localAddress = SocketAddress.fromInetSocketAddress(jLocalAddress)
(localAddress, acceptIncoming(sch))
Expand Down

0 comments on commit 9b1f8b1

Please sign in to comment.