diff --git a/core/shared/src/main/scala/fs2/Shared.scala b/core/shared/src/main/scala/fs2/Shared.scala new file mode 100644 index 0000000000..d0c05c0454 --- /dev/null +++ b/core/shared/src/main/scala/fs2/Shared.scala @@ -0,0 +1,66 @@ +/* + * Copyright (c) 2013 Functional Streams for Scala + * + * Permission is hereby granted, free of charge, to any person obtaining a copy of + * this software and associated documentation files (the "Software"), to deal in + * the Software without restriction, including without limitation the rights to + * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of + * the Software, and to permit persons to whom the Software is furnished to do so, + * subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS + * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR + * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER + * IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN + * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +package fs2 + +import cats.effect._ +import cats.syntax.all._ + +sealed trait Shared[F[_], A] { + def resource: Resource[F, A] +} + +object Shared { + def allocate[F[_], A]( + resource: Resource[F, A] + )(implicit F: Concurrent[F]): Resource[F, (Shared[F, A], A)] = { + final case class State(value: A, finalizer: F[Unit], permits: Int) { + def addPermit: State = copy(permits = permits + 1) + def releasePermit: State = copy(permits = permits - 1) + } + + MonadCancel[Resource[F, *]].uncancelable { poll => + for { + underlying <- poll(Resource.eval(resource.allocated)) + state <- Resource.eval(F.ref[Option[State]](Some(State(underlying._1, underlying._2, 0)))) + shared = new Shared[F, A] { + def acquire: F[A] = + state.modify { + case Some(st) => (Some(st.addPermit), F.pure(st.value)) + case None => + (None, F.raiseError[A](new Throwable("finalization has already occurred"))) + }.flatten + + def release: F[Unit] = + state.modify { + case Some(st) if st.permits > 1 => (Some(st.releasePermit), F.unit) + case Some(st) => (None, st.finalizer) + case None => (None, F.raiseError[Unit](new Throwable("can't finalize"))) + }.flatten + + override def resource: Resource[F, A] = + Resource.make(acquire)(_ => release) + } + _ <- shared.resource + } yield shared -> underlying._1 + } + } +} diff --git a/io/src/main/scala/fs2/io/net/Network.scala b/io/src/main/scala/fs2/io/net/Network.scala index 208df73282..ad3e9381c0 100644 --- a/io/src/main/scala/fs2/io/net/Network.scala +++ b/io/src/main/scala/fs2/io/net/Network.scala @@ -141,6 +141,13 @@ object Network { ): Resource[F, (SocketAddress[IpAddress], Stream[F, Socket[F]])] = globalSocketGroup.serverResource(address, port, options) + def serverResourceShared( + address: Option[Host], + port: Option[Port], + options: List[SocketOption] + ): Resource[F, (SocketAddress[IpAddress], Stream[F, Shared[F, Socket[F]]])] = + globalSocketGroup.serverResourceShared(address, port, options) + def openDatagramSocket( address: Option[Host], port: Option[Port], diff --git a/io/src/main/scala/fs2/io/net/Socket.scala b/io/src/main/scala/fs2/io/net/Socket.scala index 0b0a191679..20ac671433 100644 --- a/io/src/main/scala/fs2/io/net/Socket.scala +++ b/io/src/main/scala/fs2/io/net/Socket.scala @@ -24,7 +24,7 @@ package io package net import com.comcast.ip4s.{IpAddress, SocketAddress} -import cats.effect.{Async, Resource} +import cats.effect.Async import cats.effect.std.Semaphore import cats.syntax.all._ @@ -79,12 +79,10 @@ trait Socket[F[_]] { object Socket { private[net] def forAsync[F[_]: Async]( ch: AsynchronousSocketChannel - ): Resource[F, Socket[F]] = - Resource.make { - (Semaphore[F](1), Semaphore[F](1)).mapN { (readSemaphore, writeSemaphore) => - new AsyncSocket[F](ch, readSemaphore, writeSemaphore) - } - }(_ => Async[F].delay(if (ch.isOpen) ch.close else ())) + ): F[Socket[F]] = + (Semaphore[F](1), Semaphore[F](1)).mapN { (readSemaphore, writeSemaphore) => + new AsyncSocket[F](ch, readSemaphore, writeSemaphore) + } private final class AsyncSocket[F[_]]( ch: AsynchronousSocketChannel, diff --git a/io/src/main/scala/fs2/io/net/SocketGroup.scala b/io/src/main/scala/fs2/io/net/SocketGroup.scala index cb92e71f74..4b2e0cfc8c 100644 --- a/io/src/main/scala/fs2/io/net/SocketGroup.scala +++ b/io/src/main/scala/fs2/io/net/SocketGroup.scala @@ -77,6 +77,12 @@ trait SocketGroup[F[_]] { port: Option[Port] = None, options: List[SocketOption] = List.empty ): Resource[F, (SocketAddress[IpAddress], Stream[F, Socket[F]])] + + def serverResourceShared( + address: Option[Host], + port: Option[Port], + options: List[SocketOption] + ): Resource[F, (SocketAddress[IpAddress], Stream[F, Shared[F, Socket[F]]])] } private[net] object SocketGroup { @@ -114,7 +120,9 @@ private[net] object SocketGroup { } } - Resource.eval(setup.flatMap(connect)).flatMap(Socket.forAsync(_)) + Resource + .make(setup.flatMap(connect))(ch => Async[F].delay(if (ch.isOpen()) ch.close() else ())) + .evalMap(Socket.forAsync(_)) } def server( @@ -136,8 +144,16 @@ private[net] object SocketGroup { address: Option[Host], port: Option[Port], options: List[SocketOption] - ): Resource[F, (SocketAddress[IpAddress], Stream[F, Socket[F]])] = { + ): Resource[F, (SocketAddress[IpAddress], Stream[F, Socket[F]])] = + serverResourceShared(address, port, options).map { case (addr, clients) => + (addr, clients.flatMap(shared => Stream.resource(shared.resource))) + } + def serverResourceShared( + address: Option[Host], + port: Option[Port], + options: List[SocketOption] + ): Resource[F, (SocketAddress[IpAddress], Stream[F, Shared[F, Socket[F]]])] = { val setup: F[AsynchronousServerSocketChannel] = address.traverse(_.resolve[F]).flatMap { addr => Async[F].delay { @@ -161,24 +177,29 @@ private[net] object SocketGroup { def acceptIncoming( sch: AsynchronousServerSocketChannel - ): Stream[F, Socket[F]] = { - def go: Stream[F, Socket[F]] = { - def acceptChannel: F[AsynchronousSocketChannel] = - Async[F].async_[AsynchronousSocketChannel] { cb => - sch.accept( - null, - new CompletionHandler[AsynchronousSocketChannel, Void] { - def completed(ch: AsynchronousSocketChannel, attachment: Void): Unit = - cb(Right(ch)) - def failed(rsn: Throwable, attachment: Void): Unit = - cb(Left(rsn)) + ): Stream[F, Shared[F, Socket[F]]] = { + def go: Stream[F, Shared[F, Socket[F]]] = { + def acceptChannel: Resource[F, AsynchronousSocketChannel] = + Resource.makeFull[F, AsynchronousSocketChannel] { poll => + poll { + Async[F].async_[AsynchronousSocketChannel] { cb => + sch.accept( + null, + new CompletionHandler[AsynchronousSocketChannel, Void] { + def completed(ch: AsynchronousSocketChannel, attachment: Void): Unit = + cb(Right(ch)) + def failed(rsn: Throwable, attachment: Void): Unit = + cb(Left(rsn)) + } + ) } - ) - } + } + }(ch => Async[F].delay(if (ch.isOpen()) ch.close() else ())) - Stream.eval(acceptChannel.attempt).flatMap { - case Left(_) => Stream.empty[F] - case Right(accepted) => Stream.resource(Socket.forAsync(accepted)) + val sharedSocket = Shared.allocate(acceptChannel.evalMap(Socket.forAsync(_))) + Stream.resource(sharedSocket.attempt).flatMap { + case Left(_) => Stream.empty[F] + case Right((shared, _)) => Stream(shared) } ++ go }