diff --git a/servlet/src/main/scala/org/http4s/servlet/ServletIo.scala b/servlet/src/main/scala/org/http4s/servlet/ServletIo.scala index e90c334d..051db1f0 100644 --- a/servlet/src/main/scala/org/http4s/servlet/ServletIo.scala +++ b/servlet/src/main/scala/org/http4s/servlet/ServletIo.scala @@ -22,45 +22,25 @@ import cats.effect.std.Dispatcher import cats.effect.std.Queue import cats.syntax.all._ import fs2._ -import org.http4s.internal.bug import org.log4s.getLogger import java.util.Arrays -import java.util.concurrent.atomic.AtomicReference import javax.servlet.ReadListener import javax.servlet.WriteListener import javax.servlet.http.HttpServletRequest import javax.servlet.http.HttpServletResponse -import scala.annotation.nowarn -import scala.annotation.tailrec /** Determines the mode of I/O used for reading request bodies and writing response bodies. */ sealed trait ServletIo[F[_]] { - - @deprecated("Prefer requestBody, which has access to a Dispatcher", "0.23.12") - protected[servlet] def reader(servletRequest: HttpServletRequest): EntityBody[F] - - @nowarn("cat=deprecation") def requestBody( servletRequest: HttpServletRequest, dispatcher: Dispatcher[F], - ): Stream[F, Byte] = { - val _ = dispatcher // unused - reader(servletRequest) - } + ): Stream[F, Byte] - /** May install a listener on the servlet response. */ - @deprecated("Prefer bodyWriter, which has access to a Dispatcher", "0.23.12") - protected[servlet] def initWriter(servletResponse: HttpServletResponse): BodyWriter[F] - - @nowarn("cat=deprecation") def bodyWriter(servletResponse: HttpServletResponse, dispatcher: Dispatcher[F])( response: Response[F] - ): F[Unit] = { - val _ = dispatcher - initWriter(servletResponse)(response) - } + ): F[Unit] } /** Use standard blocking reads and writes. @@ -69,12 +49,16 @@ sealed trait ServletIo[F[_]] { * require a larger request thread pool for the same load. */ final case class BlockingServletIo[F[_]](chunkSize: Int)(implicit F: Sync[F]) extends ServletIo[F] { - override protected[servlet] def reader(servletRequest: HttpServletRequest): EntityBody[F] = + override def requestBody( + servletRequest: HttpServletRequest, + dispatcher: Dispatcher[F], + ): EntityBody[F] = io.readInputStream[F](F.pure(servletRequest.getInputStream), chunkSize) - override protected[servlet] def initWriter( - servletResponse: HttpServletResponse - ): BodyWriter[F] = { (response: Response[F]) => + override def bodyWriter( + servletResponse: HttpServletResponse, + dispatcher: Dispatcher[F], + )(response: Response[F]): F[Unit] = { val out = servletResponse.getOutputStream val flush = response.isChunked response.body.chunks @@ -103,109 +87,6 @@ final case class NonBlockingServletIo[F[_]](chunkSize: Int)(implicit F: Async[F] extends ServletIo[F] { private[this] val logger = getLogger - private[this] def rightSome[A](a: A) = Right(Some(a)) - private[this] val rightNone = Right(None) - - override protected[servlet] def reader(servletRequest: HttpServletRequest): EntityBody[F] = - Stream.suspend { - sealed trait State - case object Init extends State - case object Ready extends State - case object Complete extends State - sealed case class Errored(t: Throwable) extends State - sealed case class Blocked(cb: Callback[Option[Chunk[Byte]]]) extends State - - val in = servletRequest.getInputStream - - val state = new AtomicReference[State](Init) - - def read(cb: Callback[Option[Chunk[Byte]]]): Unit = { - val buf = new Array[Byte](chunkSize) - val len = in.read(buf) - - if (len == chunkSize) cb(rightSome(Chunk.array(buf))) - else if (len < 0) { - state.compareAndSet(Ready, Complete) // will not overwrite an `Errored` state - cb(rightNone) - } else if (len == 0) { - logger.warn("Encountered a read of length 0") - cb(rightSome(Chunk.empty)) - } else cb(rightSome(Chunk.array(buf, 0, len))) - } - - if (in.isFinished) Stream.empty - else { - // This effect sets the callback and waits for the first bytes to read - val registerRead = - // Shift execution to a different EC - F.async_[Option[Chunk[Byte]]] { cb => - if (!state.compareAndSet(Init, Blocked(cb))) - cb(Left(bug("Shouldn't have gotten here: I should be the first to set a state"))) - else - in.setReadListener( - new ReadListener { - override def onDataAvailable(): Unit = - state.getAndSet(Ready) match { - case Blocked(cb) => read(cb) - case _ => () - } - - override def onError(t: Throwable): Unit = - state.getAndSet(Errored(t)) match { - case Blocked(cb) => cb(Left(t)) - case _ => () - } - - override def onAllDataRead(): Unit = - state.getAndSet(Complete) match { - case Blocked(cb) => cb(rightNone) - case _ => () - } - } - ) - } - - val readStream = Stream.eval(registerRead) ++ Stream - .repeatEval( // perform the initial set then transition into normal read mode - // Shift execution to a different EC - F.async_[Option[Chunk[Byte]]] { cb => - @tailrec - def go(): Unit = - state.get match { - case Ready if in.isReady => read(cb) - - case Ready => // wasn't ready so set the callback and double check that we're still not ready - val blocked = Blocked(cb) - if (state.compareAndSet(Ready, blocked)) - if (in.isReady && state.compareAndSet(blocked, Ready)) - read(cb) // data became available while we were setting up the callbacks - else { - /* NOOP: our callback is either still needed or has been handled */ - } - else go() // Our state transitioned so try again. - - case Complete => cb(rightNone) - - case Errored(t) => cb(Left(t)) - - // This should never happen so throw a huge fit if it does. - case Blocked(c1) => - val t = bug("Two callbacks found in read state") - cb(Left(t)) - c1(Left(t)) - logger.error(t)("This should never happen. Please report.") - throw t - - case Init => - cb(Left(bug("Should have left Init state by now"))) - } - go() - } - ) - readStream.unNoneTerminate.flatMap(Stream.chunk) - } - } - /* The queue implementation is influenced by ideas in jetty4s * https://github.com/IndiscriminateCoding/jetty4s/blob/0.0.10/server/src/main/scala/jetty4s/server/HttpResourceHandler.scala */ @@ -273,101 +154,6 @@ final case class NonBlockingServletIo[F[_]](chunkSize: Int)(implicit F: Async[F] } } - override protected[servlet] def initWriter( - servletResponse: HttpServletResponse - ): BodyWriter[F] = { - sealed trait State - case object Init extends State - case object Ready extends State - sealed case class Errored(t: Throwable) extends State - sealed case class Blocked(cb: Callback[Chunk[Byte] => Unit]) extends State - sealed case class AwaitingLastWrite(cb: Callback[Unit]) extends State - - val out = servletResponse.getOutputStream - /* - * If onWritePossible isn't called at least once, Tomcat begins to throw - * NullPointerExceptions from NioEndpoint$SocketProcessor.doRun under - * load. The Init state means we block callbacks until the WriteListener - * fires. - */ - val state = new AtomicReference[State](Init) - @volatile var autoFlush = false - - val writeChunk = Right { (chunk: Chunk[Byte]) => - if (!out.isReady) - logger.error(s"writeChunk called while out was not ready, bytes will be lost!") - else { - out.write(chunk.toArray) - if (autoFlush && out.isReady) - out.flush() - } - } - - val listener = new WriteListener { - override def onWritePossible(): Unit = - state.getAndSet(Ready) match { - case Blocked(cb) => cb(writeChunk) - case AwaitingLastWrite(cb) => cb(Right(())) - case old @ _ => () - } - - override def onError(t: Throwable): Unit = - state.getAndSet(Errored(t)) match { - case Blocked(cb) => cb(Left(t)) - case AwaitingLastWrite(cb) => cb(Left(t)) - case _ => () - } - } - /* - * This must be set on the container thread in Tomcat, or onWritePossible - * will not be invoked. This side effect needs to run between the acquisition - * of the servletResponse and the calculation of the http4s Response. - */ - out.setWriteListener(listener) - - val awaitLastWrite = Stream.exec { - // Shift execution to a different EC - F.async_[Unit] { cb => - state.getAndSet(AwaitingLastWrite(cb)) match { - case Ready if out.isReady => cb(Right(())) - case _ => () - } - } - } - - val chunkHandler = - F.async_[Chunk[Byte] => Unit] { cb => - val blocked = Blocked(cb) - state.getAndSet(blocked) match { - case Ready if out.isReady => - if (state.compareAndSet(blocked, Ready)) - cb(writeChunk) - case e @ Errored(t) => - if (state.compareAndSet(blocked, e)) - cb(Left(t)) - case _ => - () - } - } - - def flushPrelude = - if (autoFlush) - chunkHandler.map(_(Chunk.empty[Byte])) - else - F.unit - - { (response: Response[F]) => - if (response.isChunked) - autoFlush = true - flushPrelude *> - response.body.chunks - .evalMap(chunk => chunkHandler.map(_(chunk))) - .append(awaitLastWrite) - .compile - .drain - } - } - /* The queue implementation is influenced by ideas in jetty4s * https://github.com/IndiscriminateCoding/jetty4s/blob/0.0.10/server/src/main/scala/jetty4s/server/HttpResourceHandler.scala */ diff --git a/servlet/src/test/scala/org/http4s/servlet/ServletIoSuite.scala b/servlet/src/test/scala/org/http4s/servlet/ServletIoSuite.scala deleted file mode 100644 index 462489c1..00000000 --- a/servlet/src/test/scala/org/http4s/servlet/ServletIoSuite.scala +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Copyright 2013 http4s.org - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.http4s -package servlet - -import cats.effect.IO -import munit.CatsEffectSuite - -import java.io.ByteArrayInputStream -import java.nio.charset.StandardCharsets.UTF_8 -import javax.servlet._ - -class ServletIoSuite extends CatsEffectSuite { - - test( - "NonBlockingServletIo should decode request body which is smaller than chunk size correctly" - ) { - val request = - HttpServletRequestStub(inputStream = new TestServletInputStream("test".getBytes(UTF_8))) - - val io = NonBlockingServletIo[IO](10) - val body = io.reader(request) - body.compile.toList.map(bytes => new String(bytes.toArray, UTF_8)).assertEquals("test") - } - - test( - "NonBlockingServletIo should decode request body which is bigger than chunk size correctly" - ) { - val request = HttpServletRequestStub(inputStream = - new TestServletInputStream("testtesttest".getBytes(UTF_8)) - ) - - val io = NonBlockingServletIo[IO](10) - val body = io.reader(request) - body.compile.toList.map(bytes => new String(bytes.toArray, UTF_8)).assertEquals("testtesttest") - } - - class TestServletInputStream(body: Array[Byte]) extends ServletInputStream { - private var readListener: ReadListener = null - private val in = new ByteArrayInputStream(body) - - override def isReady: Boolean = true - - override def isFinished: Boolean = in.available() == 0 - - override def setReadListener(readListener: ReadListener): Unit = { - this.readListener = readListener - readListener.onDataAvailable() - } - - override def read(): Int = { - val result = in.read() - if (in.available() == 0) - readListener.onAllDataRead() - result - } - } -}