|
| 1 | +/* |
| 2 | + * Copyright 2021 http4s.org |
| 3 | + * |
| 4 | + * Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | + * you may not use this file except in compliance with the License. |
| 6 | + * You may obtain a copy of the License at |
| 7 | + * |
| 8 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | + * |
| 10 | + * Unless required by applicable law or agreed to in writing, software |
| 11 | + * distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | + * See the License for the specific language governing permissions and |
| 14 | + * limitations under the License. |
| 15 | + */ |
| 16 | + |
| 17 | +package org.http4s.dom |
| 18 | + |
| 19 | +import cats.Foldable |
| 20 | +import cats.data.OptionT |
| 21 | +import cats.effect.kernel.Async |
| 22 | +import cats.effect.kernel.DeferredSource |
| 23 | +import cats.effect.kernel.Resource |
| 24 | +import cats.effect.std.Dispatcher |
| 25 | +import cats.effect.std.Queue |
| 26 | +import cats.effect.std.Semaphore |
| 27 | +import cats.effect.syntax.all._ |
| 28 | +import cats.syntax.all._ |
| 29 | +import fs2.INothing |
| 30 | +import fs2.Stream |
| 31 | +import org.http4s.Method |
| 32 | +import org.http4s.client.websocket.WSClientHighLevel |
| 33 | +import org.http4s.client.websocket.WSConnectionHighLevel |
| 34 | +import org.http4s.client.websocket.WSDataFrame |
| 35 | +import org.http4s.client.websocket.WSFrame |
| 36 | +import org.http4s.client.websocket.WSRequest |
| 37 | +import org.scalajs.dom.CloseEvent |
| 38 | +import org.scalajs.dom.MessageEvent |
| 39 | +import org.scalajs.dom.WebSocket |
| 40 | +import org.typelevel.ci._ |
| 41 | +import scodec.bits.ByteVector |
| 42 | + |
| 43 | +import scala.scalajs.js |
| 44 | +import scala.scalajs.js.JSConverters._ |
| 45 | + |
| 46 | +object WebSocketClient { |
| 47 | + |
| 48 | + def apply[F[_]](implicit F: Async[F]): WSClientHighLevel[F] = new WSClientHighLevel[F] { |
| 49 | + def connectHighLevel(request: WSRequest): Resource[F, WSConnectionHighLevel[F]] = |
| 50 | + for { |
| 51 | + dispatcher <- Dispatcher[F] |
| 52 | + messages <- Queue.unbounded[F, Option[MessageEvent]].toResource |
| 53 | + semaphore <- Semaphore[F](1).toResource |
| 54 | + error <- F.deferred[Either[Throwable, INothing]].toResource |
| 55 | + close <- F.deferred[CloseEvent].toResource |
| 56 | + ws <- Resource.makeCase { |
| 57 | + F.async_[WebSocket] { cb => |
| 58 | + if (request.method != Method.GET) |
| 59 | + cb(Left(new IllegalArgumentException("Must be GET Request"))) |
| 60 | + |
| 61 | + val protocols = request |
| 62 | + .headers |
| 63 | + .get(ci"Sec-WebSocket-Protocol") |
| 64 | + .toList |
| 65 | + .flatMap(_.toList.map(_.value)) |
| 66 | + |
| 67 | + val ws = new WebSocket(request.uri.renderString, protocols.toJSArray) |
| 68 | + ws.binaryType = "arraybuffer" // the default is blob |
| 69 | + |
| 70 | + ws.onopen = { _ => |
| 71 | + ws.onerror = // replace the error handler |
| 72 | + e => |
| 73 | + dispatcher.unsafeRunAndForget(error.complete(Left(js.JavaScriptException(e)))) |
| 74 | + cb(Right(ws)) |
| 75 | + } |
| 76 | + |
| 77 | + ws.onerror = e => cb(Left(js.JavaScriptException(e))) |
| 78 | + ws.onmessage = e => dispatcher.unsafeRunAndForget(messages.offer(Some(e))) |
| 79 | + ws.onclose = |
| 80 | + e => dispatcher.unsafeRunAndForget(messages.offer(None) *> close.complete(e)) |
| 81 | + } |
| 82 | + } { |
| 83 | + case (ws, exitCase) => |
| 84 | + val reason = exitCase match { |
| 85 | + case Resource.ExitCase.Succeeded => |
| 86 | + None |
| 87 | + case Resource.ExitCase.Errored(ex) => |
| 88 | + val reason = ex.toString |
| 89 | + // reason must be no longer than 123 bytes of UTF-8 text |
| 90 | + // UTF-8 character is max 4 bytes so we can fast-path |
| 91 | + if (reason.length <= 30 || reason.getBytes.length <= 123) |
| 92 | + Some(reason) |
| 93 | + else |
| 94 | + None |
| 95 | + case Resource.ExitCase.Canceled => |
| 96 | + Some("canceled") |
| 97 | + } |
| 98 | + |
| 99 | + val shutdown = F |
| 100 | + .async_[CloseEvent] { cb => |
| 101 | + ws.onerror = e => cb(Left(js.JavaScriptException(e))) |
| 102 | + ws.onclose = e => cb(Right(e)) |
| 103 | + reason match { // 1000 "normal closure" is only code supported in browser |
| 104 | + case Some(reason) => ws.close(1000, reason) |
| 105 | + case None => ws.close(1000) |
| 106 | + } |
| 107 | + } |
| 108 | + .flatMap(close.complete(_)) *> messages.offer(None) |
| 109 | + |
| 110 | + F.delay(ws.readyState).flatMap { |
| 111 | + case 0 | 1 => shutdown // CONNECTING | OPEN |
| 112 | + case 2 => close.get.void // CLOSING |
| 113 | + case 3 => F.unit // CLOSED |
| 114 | + case s => F.raiseError(new IllegalStateException(s"WebSocket.readyState: $s")) |
| 115 | + } |
| 116 | + } |
| 117 | + } yield new WSConnectionHighLevel[F] { |
| 118 | + |
| 119 | + def closeFrame: DeferredSource[F, WSFrame.Close] = |
| 120 | + (close: DeferredSource[F, CloseEvent]).map(e => WSFrame.Close(e.code, e.reason)) |
| 121 | + |
| 122 | + def receive: F[Option[WSDataFrame]] = semaphore |
| 123 | + .permit |
| 124 | + .surround(OptionT(messages.take).map(decodeMessage).value) |
| 125 | + .race(error.get.rethrow) |
| 126 | + .map(_.merge) |
| 127 | + |
| 128 | + override def receiveStream: Stream[F, WSDataFrame] = |
| 129 | + Stream |
| 130 | + .resource(semaphore.permit) |
| 131 | + .flatMap(_ => Stream.fromQueueNoneTerminated(messages)) |
| 132 | + .map(decodeMessage) |
| 133 | + .concurrently(Stream.exec(error.get.rethrow.widen)) |
| 134 | + |
| 135 | + private def decodeMessage(e: MessageEvent): WSDataFrame = |
| 136 | + e.data match { |
| 137 | + case s: String => WSFrame.Text(s) |
| 138 | + case b: js.typedarray.ArrayBuffer => |
| 139 | + WSFrame.Binary(ByteVector.fromJSArrayBuffer(b)) |
| 140 | + case _ => // this should never happen |
| 141 | + throw new RuntimeException |
| 142 | + } |
| 143 | + |
| 144 | + override def sendText(text: String): F[Unit] = |
| 145 | + errorOr(F.delay(ws.send(text))) |
| 146 | + |
| 147 | + override def sendBinary(bytes: ByteVector): F[Unit] = |
| 148 | + errorOr(F.delay(ws.send(bytes.toJSArrayBuffer))) |
| 149 | + |
| 150 | + def send(wsf: WSDataFrame): F[Unit] = |
| 151 | + wsf match { |
| 152 | + case WSFrame.Text(data, true) => sendText(data) |
| 153 | + case WSFrame.Binary(data, true) => sendBinary(data) |
| 154 | + case _ => |
| 155 | + F.raiseError(new IllegalArgumentException("DataFrames cannot be fragmented")) |
| 156 | + } |
| 157 | + |
| 158 | + private def errorOr(fu: F[Unit]): F[Unit] = error.tryGet.flatMap { |
| 159 | + case Some(error) => F.fromEither[Unit](error) |
| 160 | + case None => fu |
| 161 | + } |
| 162 | + |
| 163 | + def sendMany[G[_]: Foldable, A <: WSDataFrame](wsfs: G[A]): F[Unit] = |
| 164 | + wsfs.foldMapM(send(_)) |
| 165 | + |
| 166 | + def subprotocol: Option[String] = Option(ws.protocol).filter(_.nonEmpty) |
| 167 | + } |
| 168 | + } |
| 169 | + |
| 170 | +} |
0 commit comments