Skip to content

Commit

Permalink
Merge pull request #344 from http4s/pr/fix-closure-receive
Browse files Browse the repository at this point in the history
Fix `receive` on websocket closure
  • Loading branch information
armanbilge authored Dec 9, 2023
2 parents 8aa3fa1 + aa8d05d commit 5825600
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 3 deletions.
4 changes: 4 additions & 0 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,9 @@ Global / fileServicePort := {
import cats.effect.IO
import cats.effect.unsafe.implicits.global
import com.comcast.ip4s._
import fs2.Stream
import org.http4s._
import org.http4s.websocket._
import org.http4s.dsl.io._
import org.http4s.ember.server.EmberServerBuilder
import org.http4s.server.staticcontent._
Expand All @@ -69,6 +71,8 @@ Global / fileServicePort := {
wsb.build(identity)
case Method.GET -> Root / "slows" =>
IO.sleep(3.seconds) *> wsb.build(identity)
case Method.GET -> Root / "hello-goodbye" =>
wsb.build(in => Stream(WebSocketFrame.Text("hello")).concurrently(in.drain))
case req =>
fileService[IO](FileService.Config[IO](".")).orNotFound.run(req).map { res =>
// TODO find out why mime type is not auto-inferred
Expand Down
19 changes: 16 additions & 3 deletions dom/src/main/scala/org/http4s/dom/WebSocketClient.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import cats.effect.std.Mutex
import cats.effect.std.Queue
import cats.effect.syntax.all._
import cats.syntax.all._
import fs2.Chunk
import fs2.Stream
import org.http4s.Method
import org.http4s.client.websocket.WSClientHighLevel
Expand Down Expand Up @@ -132,11 +133,23 @@ object WebSocketClient {
(close: DeferredSource[F, CloseEvent]).map(e => WSFrame.Close(e.code, e.reason))

def receive: F[Option[WSDataFrame]] =
mutex.lock.surround(OptionT(messages.take).map(decodeMessage).value)
close
.tryGet
.map(_.isDefined)
.ifM(
OptionT(messages.tryTake.map(_.flatten)).map(decodeMessage).value,
OptionT(mutex.lock.surround(messages.take)).map(decodeMessage).value
)

override def receiveStream: Stream[F, WSDataFrame] =
Stream.resource(mutex.lock) >>
Stream.fromQueueNoneTerminated(messages).map(decodeMessage)
Stream
.eval(close.tryGet.map(_.isDefined))
.ifM(
Stream.evalUnChunk(
messages.tryTakeN(None).map(m => Chunk.from(m.flatMap(_.toList)))),
Stream.resource(mutex.lock) >> Stream.fromQueueNoneTerminated(messages)
)
.map(decodeMessage)

private def decodeMessage(e: MessageEvent): WSDataFrame =
e.data match {
Expand Down
14 changes: 14 additions & 0 deletions testsBrowser/src/test/scala/org/http4s/dom/WebSocketSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -70,4 +70,18 @@ class WebSocketSuite extends CatsEffectSuite {
}
}

test("receive returns None when connection closes") {
WebSocketClient[IO]
.connectHighLevel(
WSRequest(
Uri.fromString(s"ws://localhost:${fileServicePort}/hello-goodbye").toOption.get)
)
.use { conn =>
conn.receive.assertEquals(Some(WSFrame.Text("hello"))) *>
conn.receive.assertEquals(None) *>
conn.receive.assertEquals(None) *>
conn.receiveStream.compile.toList.assertEquals(Nil)
}
}

}

0 comments on commit 5825600

Please sign in to comment.