diff --git a/build.sbt b/build.sbt index 41874228..1785bb58 100644 --- a/build.sbt +++ b/build.sbt @@ -50,7 +50,9 @@ Global / fileServicePort := { import cats.effect.IO import cats.effect.unsafe.implicits.global import com.comcast.ip4s._ + import fs2.Stream import org.http4s._ + import org.http4s.websocket._ import org.http4s.dsl.io._ import org.http4s.ember.server.EmberServerBuilder import org.http4s.server.staticcontent._ @@ -69,6 +71,8 @@ Global / fileServicePort := { wsb.build(identity) case Method.GET -> Root / "slows" => IO.sleep(3.seconds) *> wsb.build(identity) + case Method.GET -> Root / "hello-goodbye" => + wsb.build(in => Stream(WebSocketFrame.Text("hello")).concurrently(in.drain)) case req => fileService[IO](FileService.Config[IO](".")).orNotFound.run(req).map { res => // TODO find out why mime type is not auto-inferred diff --git a/dom/src/main/scala/org/http4s/dom/WebSocketClient.scala b/dom/src/main/scala/org/http4s/dom/WebSocketClient.scala index bfb9c244..e942a03a 100644 --- a/dom/src/main/scala/org/http4s/dom/WebSocketClient.scala +++ b/dom/src/main/scala/org/http4s/dom/WebSocketClient.scala @@ -26,6 +26,7 @@ import cats.effect.std.Mutex import cats.effect.std.Queue import cats.effect.syntax.all._ import cats.syntax.all._ +import fs2.Chunk import fs2.Stream import org.http4s.Method import org.http4s.client.websocket.WSClientHighLevel @@ -132,11 +133,23 @@ object WebSocketClient { (close: DeferredSource[F, CloseEvent]).map(e => WSFrame.Close(e.code, e.reason)) def receive: F[Option[WSDataFrame]] = - mutex.lock.surround(OptionT(messages.take).map(decodeMessage).value) + close + .tryGet + .map(_.isDefined) + .ifM( + OptionT(messages.tryTake.map(_.flatten)).map(decodeMessage).value, + OptionT(mutex.lock.surround(messages.take)).map(decodeMessage).value + ) override def receiveStream: Stream[F, WSDataFrame] = - Stream.resource(mutex.lock) >> - Stream.fromQueueNoneTerminated(messages).map(decodeMessage) + Stream + .eval(close.tryGet.map(_.isDefined)) + .ifM( + Stream.evalUnChunk( + messages.tryTakeN(None).map(m => Chunk.from(m.flatMap(_.toList)))), + Stream.resource(mutex.lock) >> Stream.fromQueueNoneTerminated(messages) + ) + .map(decodeMessage) private def decodeMessage(e: MessageEvent): WSDataFrame = e.data match { diff --git a/testsBrowser/src/test/scala/org/http4s/dom/WebSocketSuite.scala b/testsBrowser/src/test/scala/org/http4s/dom/WebSocketSuite.scala index a102c5f6..1fd63e6c 100644 --- a/testsBrowser/src/test/scala/org/http4s/dom/WebSocketSuite.scala +++ b/testsBrowser/src/test/scala/org/http4s/dom/WebSocketSuite.scala @@ -70,4 +70,18 @@ class WebSocketSuite extends CatsEffectSuite { } } + test("receive returns None when connection closes") { + WebSocketClient[IO] + .connectHighLevel( + WSRequest( + Uri.fromString(s"ws://localhost:${fileServicePort}/hello-goodbye").toOption.get) + ) + .use { conn => + conn.receive.assertEquals(Some(WSFrame.Text("hello"))) *> + conn.receive.assertEquals(None) *> + conn.receive.assertEquals(None) *> + conn.receiveStream.compile.toList.assertEquals(Nil) + } + } + }