Skip to content

Commit

Permalink
Merge pull request #50 from http4s/queue-servlet-reader
Browse files Browse the repository at this point in the history
Use a queue-based reader for ServletInputStream
  • Loading branch information
rossabaker authored Jun 9, 2022
2 parents 6f5026e + cc7c445 commit c0bf9a8
Show file tree
Hide file tree
Showing 4 changed files with 163 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class AsyncHttp4sServlet[F[_]](
val ctx = servletRequest.startAsync()
ctx.setTimeout(asyncTimeoutMillis)
// Must be done on the container thread for Tomcat's sake when using async I/O.
val bodyWriter = servletIo.initWriter(servletResponse)
val bodyWriter = servletIo.bodyWriter(servletResponse, dispatcher) _
val result = F
.attempt(
toRequest(servletRequest).fold(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class BlockingHttp4sServlet[F[_]] private (
): Unit = {
val result = F
.defer {
val bodyWriter = servletIo.initWriter(servletResponse)
val bodyWriter = servletIo.bodyWriter(servletResponse, dispatcher) _

val render = toRequest(servletRequest).fold(
onParseFailure(_, servletResponse, bodyWriter),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ abstract class Http4sServlet[F[_]](
uri = uri,
httpVersion = version,
headers = toHeaders(req),
body = servletIo.reader(req),
body = servletIo.requestBody(req, dispatcher),
attributes = attributes,
)

Expand Down
160 changes: 160 additions & 0 deletions servlet/src/main/scala/org/http4s/servlet/ServletIo.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,27 +18,50 @@ package org.http4s
package servlet

import cats.effect._
import cats.effect.std.Dispatcher
import cats.effect.std.Queue
import cats.syntax.all._
import fs2._
import org.http4s.internal.bug
import org.log4s.getLogger

import java.util.Arrays
import java.util.concurrent.atomic.AtomicReference
import javax.servlet.ReadListener
import javax.servlet.WriteListener
import javax.servlet.http.HttpServletRequest
import javax.servlet.http.HttpServletResponse
import scala.annotation.nowarn
import scala.annotation.tailrec

/** Determines the mode of I/O used for reading request bodies and writing response bodies.
*/
sealed abstract class ServletIo[F[_]: Async] {
protected[servlet] val F: Async[F] = Async[F]

@deprecated("Prefer requestBody, which has access to a Dispatcher", "0.23.12")
protected[servlet] def reader(servletRequest: HttpServletRequest): EntityBody[F]

@nowarn("cat=deprecation")
def requestBody(
servletRequest: HttpServletRequest,
dispatcher: Dispatcher[F],
): Stream[F, Byte] = {
val _ = dispatcher // unused
reader(servletRequest)
}

/** May install a listener on the servlet response. */
@deprecated("Prefer bodyWriter, which has access to a Dispatcher", "0.23.12")
protected[servlet] def initWriter(servletResponse: HttpServletResponse): BodyWriter[F]

@nowarn("cat=deprecation")
def bodyWriter(servletResponse: HttpServletResponse, dispatcher: Dispatcher[F])(
response: Response[F]
): F[Unit] = {
val _ = dispatcher
initWriter(servletResponse)(response)
}
}

/** Use standard blocking reads and writes.
Expand Down Expand Up @@ -183,6 +206,73 @@ final case class NonBlockingServletIo[F[_]: Async](chunkSize: Int) extends Servl
}
}

/* The queue implementation is influenced by ideas in jetty4s
* https://github.com/IndiscriminateCoding/jetty4s/blob/0.0.10/server/src/main/scala/jetty4s/server/HttpResourceHandler.scala
*/
override def requestBody(
servletRequest: HttpServletRequest,
dispatcher: Dispatcher[F],
): Stream[F, Byte] = {
sealed trait Read
final case class Bytes(chunk: Chunk[Byte]) extends Read
case object End extends Read
final case class Error(t: Throwable) extends Read

Stream.eval(F.delay(servletRequest.getInputStream)).flatMap { in =>
Stream.eval(Queue.bounded[F, Read](4)).flatMap { q =>
val readBody = Stream.exec(F.delay(in.setReadListener(new ReadListener {
var buf: Array[Byte] = _
unsafeReplaceBuffer()

def unsafeReplaceBuffer() =
buf = new Array[Byte](chunkSize)

def onDataAvailable(): Unit = {
def loopIfReady =
F.delay(in.isReady()).flatMap {
case true => go
case false => F.unit
}

def go: F[Unit] =
F.delay(in.read(buf)).flatMap {
case len if len == chunkSize =>
// We used the whole buffer. Replace it new before next read.
q.offer(Bytes(Chunk.array(buf))) >> F.delay(unsafeReplaceBuffer()) >> loopIfReady
case len if len >= 0 =>
// Got a partial chunk. Copy it, and reuse the current buffer.
q.offer(Bytes(Chunk.array(Arrays.copyOf(buf, len)))) >> loopIfReady
case _ =>
F.unit
}

unsafeRunAndForget(go)
}

def onAllDataRead(): Unit =
unsafeRunAndForget(q.offer(End))

def onError(t: Throwable): Unit =
unsafeRunAndForget(q.offer(Error(t)))

def unsafeRunAndForget[A](fa: F[A]): Unit =
dispatcher.unsafeRunAndForget(
fa.onError(t => F.delay(logger.error(t)("Error in servlet read listener")))
)
})))

def pullBody: Pull[F, Byte, Unit] =
Pull.eval(q.take).flatMap {
case Bytes(chunk) => Pull.output(chunk) >> pullBody
case End => Pull.done
case Error(t) => Pull.raiseError[F](t)
}

pullBody.stream.concurrently(readBody)
}
}
}

override protected[servlet] def initWriter(
servletResponse: HttpServletResponse
): BodyWriter[F] = {
Expand Down Expand Up @@ -277,4 +367,74 @@ final case class NonBlockingServletIo[F[_]: Async](chunkSize: Int) extends Servl
.drain
}
}

/* The queue implementation is influenced by ideas in jetty4s
* https://github.com/IndiscriminateCoding/jetty4s/blob/0.0.10/server/src/main/scala/jetty4s/server/HttpResourceHandler.scala
*/
override def bodyWriter(
servletResponse: HttpServletResponse,
dispatcher: Dispatcher[F],
)(response: Response[F]): F[Unit] = {
sealed trait Write
final case class Bytes(chunk: Chunk[Byte]) extends Write
case object End extends Write
case object Init extends Write

val autoFlush = response.isChunked

F.delay(servletResponse.getOutputStream).flatMap { out =>
Queue.bounded[F, Write](4).flatMap { q =>
Deferred[F, Either[Throwable, Unit]].flatMap { done =>
val writeBody = F.delay(out.setWriteListener(new WriteListener {
def onWritePossible(): Unit = {
def loopIfReady = F.delay(out.isReady()).flatMap {
case true => go
case false => F.unit
}

def flush =
if (autoFlush) {
F.delay(out.isReady()).flatMap {
case true => F.delay(out.flush()) >> loopIfReady
case false => F.unit
}
} else
loopIfReady

def go: F[Unit] =
q.take.flatMap {
case Bytes(slice: Chunk.ArraySlice[_]) =>
F.delay(
out.write(slice.values.asInstanceOf[Array[Byte]], slice.offset, slice.length)
) >> flush
case Bytes(chunk) =>
F.delay(out.write(chunk.toArray)) >> flush
case End =>
F.delay(out.flush()) >> done.complete(Either.unit).attempt.void
case Init =>
if (autoFlush) flush else go
}

unsafeRunAndForget(go)
}
def onError(t: Throwable): Unit =
unsafeRunAndForget(done.complete(Left(t)))

def unsafeRunAndForget[A](fa: F[A]): Unit =
dispatcher.unsafeRunAndForget(
fa.onError(t => F.delay(logger.error(t)("Error in servlet write listener")))
)
}))

val writes = Stream.emit(Init) ++ response.body.chunks.map(Bytes(_)) ++ Stream.emit(End)

Stream
.eval(writeBody >> done.get.rethrow)
.mergeHaltL(writes.foreach(q.offer))
.compile
.drain
}
}
}
}
}

0 comments on commit c0bf9a8

Please sign in to comment.