diff --git a/test/org/apache/catalina/valves/TestFilterValve.java b/test/org/apache/catalina/valves/TestFilterValve.java new file mode 100644 index 000000000000..dd2d918c5c87 --- /dev/null +++ b/test/org/apache/catalina/valves/TestFilterValve.java @@ -0,0 +1,229 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.catalina.valves; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; + +import jakarta.servlet.Filter; +import jakarta.servlet.FilterChain; +import jakarta.servlet.ServletException; +import jakarta.servlet.ServletRequest; +import jakarta.servlet.ServletResponse; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletRequestWrapper; +import jakarta.servlet.http.HttpServletResponse; + +import org.junit.Assert; +import org.junit.Test; + +import org.apache.catalina.Context; +import org.apache.catalina.LifecycleException; +import org.apache.catalina.startup.Tomcat; +import org.apache.catalina.startup.TomcatBaseTest; +import org.apache.tomcat.util.buf.ByteChunk; + +public class TestFilterValve extends TomcatBaseTest { + + + @Test + public void testFilterPassthrough() throws Exception { + Tomcat tomcat = getTomcatInstance(); + + Context ctx = getProgrammaticRootContext(); + + Tomcat.addServlet(ctx, "hello", new HelloWorldServlet()); + ctx.addServletMappingDecoded("/", "hello"); + + FilterValve valve = new FilterValve(); + valve.setFilterClass(PassthroughFilter.class.getName()); + ctx.getPipeline().addValve(valve); + + tomcat.start(); + + ByteChunk res = new ByteChunk(); + int rc = getUrl("http://localhost:" + getPort(), res, null); + + Assert.assertEquals(HttpServletResponse.SC_OK, rc); + Assert.assertEquals(HelloWorldServlet.RESPONSE_TEXT, res.toString()); + } + + + @Test + public void testFilterBlocks() throws Exception { + Tomcat tomcat = getTomcatInstance(); + + Context ctx = getProgrammaticRootContext(); + + Tomcat.addServlet(ctx, "hello", new HelloWorldServlet()); + ctx.addServletMappingDecoded("/", "hello"); + + FilterValve valve = new FilterValve(); + valve.setFilterClass(BlockingFilter.class.getName()); + ctx.getPipeline().addValve(valve); + + tomcat.start(); + + ByteChunk res = new ByteChunk(); + int rc = getUrl("http://localhost:" + getPort(), res, null); + + Assert.assertEquals(HttpServletResponse.SC_FORBIDDEN, rc); + } + + @Test + public void testFilterWrappingRequestThrows() throws Exception { + Tomcat tomcat = getTomcatInstance(); + + Context ctx = getProgrammaticRootContext(); + + Tomcat.addServlet(ctx, "hello", new HelloWorldServlet()); + ctx.addServletMappingDecoded("/", "hello"); + + FilterValve valve = new FilterValve(); + valve.setFilterClass(WrappingFilter.class.getName()); + ctx.getPipeline().addValve(valve); + + tomcat.start(); + + int rc = getUrl("http://localhost:" + getPort(), new ByteChunk(), null); + + Assert.assertEquals(HttpServletResponse.SC_INTERNAL_SERVER_ERROR, rc); + } + + + @Test(expected = LifecycleException.class) + public void testNullFilterClassThrowsOnStart() throws Exception { + Tomcat tomcat = getTomcatInstance(); + + Context ctx = getProgrammaticRootContext(); + + FilterValve valve = new FilterValve(); + // Do NOT set filterClassName + ctx.getPipeline().addValve(valve); + + tomcat.start(); + } + + + @Test(expected = LifecycleException.class) + public void testInvalidFilterClassThrowsOnStart() throws Exception { + Tomcat tomcat = getTomcatInstance(); + + Context ctx = getProgrammaticRootContext(); + + FilterValve valve = new FilterValve(); + valve.setFilterClass("com.nonexistent.FakeFilter"); + ctx.getPipeline().addValve(valve); + + tomcat.start(); + } + + + @Test + public void testGetFilterNameReturnsNull() { + FilterValve valve = new FilterValve(); + Assert.assertNull(valve.getFilterName()); + } + + + @Test + public void testInitParams() { + FilterValve valve = new FilterValve(); + + valve.addInitParam("key1", "value1"); + valve.addInitParam("key2", "value2"); + + Assert.assertEquals("value1", valve.getInitParameter("key1")); + Assert.assertEquals("value2", valve.getInitParameter("key2")); + Assert.assertNull(valve.getInitParameter("nonexistent")); + + List names = Collections.list(valve.getInitParameterNames()); + Assert.assertEquals(2, names.size()); + Assert.assertTrue(names.contains("key1")); + Assert.assertTrue(names.contains("key2")); + } + + + @Test + public void testInitParamsEmpty() { + FilterValve valve = new FilterValve(); + + Assert.assertNull(valve.getInitParameter("anything")); + Assert.assertFalse(valve.getInitParameterNames().hasMoreElements()); + } + + + @Test + public void testGetSetFilterClassName() { + FilterValve valve = new FilterValve(); + + Assert.assertNull(valve.getFilterClassName()); + + valve.setFilterClassName("com.example.MyFilter"); + Assert.assertEquals("com.example.MyFilter", valve.getFilterClassName()); + + valve.setFilterClass("com.example.OtherFilter"); + Assert.assertEquals("com.example.OtherFilter", valve.getFilterClassName()); + } + + @Test(expected = IllegalStateException.class) + public void testGetServletContextThrowsBeforeStart() { + FilterValve valve = new FilterValve(); + valve.getServletContext(); + } + + + /** + * A Filter that passes the request through to the next element in the chain. + */ + public static final class PassthroughFilter implements Filter { + + @Override + public void doFilter(ServletRequest request, ServletResponse response, + FilterChain chain) throws IOException, ServletException { + chain.doFilter(request, response); + } + } + + + /** + * A Filter that blocks the request by sending a 403 response without calling chain.doFilter(). + */ + public static final class BlockingFilter implements Filter { + + @Override + public void doFilter(ServletRequest request, ServletResponse response, + FilterChain chain) throws IOException, ServletException { + ((HttpServletResponse) response).sendError(HttpServletResponse.SC_FORBIDDEN); + } + } + + /** + * A Filter that wraps the request before calling chain.doFilter(), which FilterValve explicitly forbids. + */ + public static final class WrappingFilter implements Filter { + + @Override + public void doFilter(ServletRequest request, ServletResponse response, + FilterChain chain) throws IOException, ServletException { + HttpServletRequestWrapper wrapped = new HttpServletRequestWrapper((HttpServletRequest) request); + chain.doFilter(wrapped, response); + } + } + +} diff --git a/test/org/apache/catalina/valves/TestProxyErrorReportValve.java b/test/org/apache/catalina/valves/TestProxyErrorReportValve.java new file mode 100644 index 000000000000..8829fa2d6394 --- /dev/null +++ b/test/org/apache/catalina/valves/TestProxyErrorReportValve.java @@ -0,0 +1,277 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.catalina.valves; + +import java.io.IOException; +import java.io.Serial; + +import jakarta.servlet.http.HttpServlet; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; + +import org.junit.Assert; +import org.junit.Test; + +import org.apache.catalina.Context; +import org.apache.catalina.core.StandardHost; +import org.apache.catalina.startup.Tomcat; +import org.apache.catalina.startup.TomcatBaseTest; +import org.apache.tomcat.util.buf.ByteChunk; + +public class TestProxyErrorReportValve extends TomcatBaseTest { + + private static final String PROXY_VALVE = + "org.apache.catalina.valves.ProxyErrorReportValve"; + + + @Test + public void testRedirectMode() throws Exception { + Tomcat tomcat = getTomcatInstance(); + StandardHost host = (StandardHost) tomcat.getHost(); + host.setErrorReportValveClass(PROXY_VALVE); + + Context ctx = getProgrammaticRootContext(); + + Tomcat.addServlet(ctx, "error", new SendErrorServlet( + HttpServletResponse.SC_INTERNAL_SERVER_ERROR, "Server broke")); + ctx.addServletMappingDecoded("/", "error"); + + // Register an error page at the Host's error report valve level + // so findErrorPage() returns a URL for the redirect + Tomcat.addServlet(ctx, "errorPage", new ErrorPageServlet()); + ctx.addServletMappingDecoded("/error-page", "errorPage"); + + tomcat.start(); + + ProxyErrorReportValve valve = (ProxyErrorReportValve) host.getPipeline().getFirst(); + valve.setProperty("errorCode." + HttpServletResponse.SC_INTERNAL_SERVER_ERROR, + "http://localhost:" + getPort() + "/error-page"); + + int rc = getUrl("http://localhost:" + getPort(), new ByteChunk(), false); + + Assert.assertEquals(HttpServletResponse.SC_FOUND, rc); + } + + @Test + public void testProxyMode() throws Exception { + Tomcat tomcat = getTomcatInstance(); + StandardHost host = (StandardHost) tomcat.getHost(); + host.setErrorReportValveClass(PROXY_VALVE); + + Context ctx = getProgrammaticRootContext(); + + Tomcat.addServlet(ctx, "error", new SendErrorServlet( + HttpServletResponse.SC_NOT_FOUND, "Not found")); + ctx.addServletMappingDecoded("/", "error"); + + Tomcat.addServlet(ctx, "errorPage", new ErrorPageServlet()); + ctx.addServletMappingDecoded("/error-page", "errorPage"); + + tomcat.start(); + + ProxyErrorReportValve valve = (ProxyErrorReportValve) host.getPipeline().getFirst(); + valve.setUseRedirect(false); + valve.setProperty("errorCode." + HttpServletResponse.SC_NOT_FOUND, + "http://localhost:" + getPort() + "/error-page"); + + ByteChunk res = new ByteChunk(); + int rc = getUrl("http://localhost:" + getPort(), res, null); + + Assert.assertEquals(HttpServletResponse.SC_NOT_FOUND, rc); + Assert.assertTrue(res.toString().contains("ERROR_PAGE_OK")); + } + + + @Test + public void testNoErrorPageFallsBackToSuper() throws Exception { + Tomcat tomcat = getTomcatInstance(); + ((StandardHost) tomcat.getHost()).setErrorReportValveClass(PROXY_VALVE); + + Context ctx = getProgrammaticRootContext(); + + Tomcat.addServlet(ctx, "error", new SendErrorServlet( + HttpServletResponse.SC_INTERNAL_SERVER_ERROR, "No page configured")); + ctx.addServletMappingDecoded("/", "error"); + + tomcat.start(); + + ByteChunk res = new ByteChunk(); + int rc = getUrl("http://localhost:" + getPort(), res, null); + + Assert.assertEquals(HttpServletResponse.SC_INTERNAL_SERVER_ERROR, rc); + + String body = res.toString(); + Assert.assertNotNull(body); + Assert.assertTrue("Should contain HTML error report", + body.contains("html") && + body.contains(String.valueOf(HttpServletResponse.SC_INTERNAL_SERVER_ERROR))); + } + + + @Test + public void testStatusBelow400Ignored() throws Exception { + Tomcat tomcat = getTomcatInstance(); + ((StandardHost) tomcat.getHost()).setErrorReportValveClass(PROXY_VALVE); + + Context ctx = getProgrammaticRootContext(); + + Tomcat.addServlet(ctx, "hello", new HelloWorldServlet()); + ctx.addServletMappingDecoded("/", "hello"); + + tomcat.start(); + + ByteChunk res = new ByteChunk(); + int rc = getUrl("http://localhost:" + getPort(), res, null); + + Assert.assertEquals(HttpServletResponse.SC_OK, rc); + Assert.assertEquals(HelloWorldServlet.RESPONSE_TEXT, res.toString()); + } + + + @Test + public void testStatusNotFound() throws Exception { + Tomcat tomcat = getTomcatInstance(); + ((StandardHost) tomcat.getHost()).setErrorReportValveClass(PROXY_VALVE); + + Context ctx = getProgrammaticRootContext(); + + Tomcat.addServlet(ctx, "notFound", new SendErrorServlet( + HttpServletResponse.SC_NOT_FOUND, "Resource not found")); + ctx.addServletMappingDecoded("/", "notFound"); + + tomcat.start(); + + ByteChunk res = new ByteChunk(); + int rc = getUrl("http://localhost:" + getPort(), res, null); + + Assert.assertEquals(HttpServletResponse.SC_NOT_FOUND, rc); + + String body = res.toString(); + Assert.assertNotNull(body); + Assert.assertTrue("Should contain error report", + body.contains(String.valueOf(HttpServletResponse.SC_NOT_FOUND))); + } + + + @Test + public void testGetSetProperties() { + ProxyErrorReportValve valve = new ProxyErrorReportValve(); + + Assert.assertTrue(valve.getUseRedirect()); + Assert.assertFalse(valve.getUsePropertiesFile()); + + valve.setUseRedirect(false); + Assert.assertFalse(valve.getUseRedirect()); + + valve.setUsePropertiesFile(true); + Assert.assertTrue(valve.getUsePropertiesFile()); + } + + + @Test + public void testMessageInErrorReport() throws Exception { + final String customErrorMessage = "Custom error message"; + Tomcat tomcat = getTomcatInstance(); + ((StandardHost) tomcat.getHost()).setErrorReportValveClass(PROXY_VALVE); + + Context ctx = getProgrammaticRootContext(); + + Tomcat.addServlet(ctx, "error", new SendErrorServlet( + HttpServletResponse.SC_INTERNAL_SERVER_ERROR, customErrorMessage)); + ctx.addServletMappingDecoded("/", "error"); + + tomcat.start(); + + ByteChunk res = new ByteChunk(); + int rc = getUrl("http://localhost:" + getPort(), res, null); + + Assert.assertEquals(HttpServletResponse.SC_INTERNAL_SERVER_ERROR, rc); + + String body = res.toString(); + Assert.assertNotNull(body); + // Falls back to super.report() which includes the message + Assert.assertTrue(body.contains(customErrorMessage)); + } + + + @Test + public void testExceptionErrorReport() throws Exception { + Tomcat tomcat = getTomcatInstance(); + ((StandardHost) tomcat.getHost()).setErrorReportValveClass(PROXY_VALVE); + + Context ctx = getProgrammaticRootContext(); + + Tomcat.addServlet(ctx, "exception", new ExceptionServlet()); + ctx.addServletMappingDecoded("/", "exception"); + + tomcat.start(); + + ByteChunk res = new ByteChunk(); + int rc = getUrl("http://localhost:" + getPort(), res, null); + + Assert.assertEquals(HttpServletResponse.SC_INTERNAL_SERVER_ERROR, rc); + + String body = res.toString(); + Assert.assertNotNull(body); + Assert.assertTrue(body.contains("RuntimeException")); + } + + + private static final class SendErrorServlet extends HttpServlet { + + @Serial + private static final long serialVersionUID = 1L; + + private final int statusCode; + private final String message; + + private SendErrorServlet(int statusCode, String message) { + this.statusCode = statusCode; + this.message = message; + } + + @Override + protected void doGet(HttpServletRequest req, HttpServletResponse resp) + throws IOException { + resp.sendError(statusCode, message); + } + } + + private static final class ErrorPageServlet extends HttpServlet { + + @Serial + private static final long serialVersionUID = 1L; + + @Override + protected void doGet(HttpServletRequest req, HttpServletResponse resp) + throws IOException { + resp.getWriter().print("ERROR_PAGE_OK"); + } + } + + + private static final class ExceptionServlet extends HttpServlet { + + @Serial + private static final long serialVersionUID = 1L; + + @Override + protected void doGet(HttpServletRequest req, HttpServletResponse resp) { + throw new RuntimeException("Test exception"); + } + } +} diff --git a/test/org/apache/catalina/valves/TestSemaphoreValve.java b/test/org/apache/catalina/valves/TestSemaphoreValve.java new file mode 100644 index 000000000000..cf9ae9d92a36 --- /dev/null +++ b/test/org/apache/catalina/valves/TestSemaphoreValve.java @@ -0,0 +1,440 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.catalina.valves; + +import java.io.IOException; +import java.io.Serial; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Semaphore; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; + +import jakarta.servlet.http.HttpServlet; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; + +import org.junit.Assert; +import org.junit.Test; + +import org.apache.catalina.Context; +import org.apache.catalina.connector.Request; +import org.apache.catalina.connector.Response; +import org.apache.catalina.startup.Tomcat; +import org.apache.catalina.startup.TomcatBaseTest; +import org.apache.tomcat.util.buf.ByteChunk; + +public class TestSemaphoreValve extends TomcatBaseTest { + + + @Test + public void testBasicConcurrency() throws Exception { + Tomcat tomcat = getTomcatInstance(); + + Context ctx = getProgrammaticRootContext(); + + Tomcat.addServlet(ctx, "hello", new HelloWorldServlet()); + ctx.addServletMappingDecoded("/", "hello"); + + SemaphoreValve valve = new SemaphoreValve(); + valve.setConcurrency(10); + ctx.getPipeline().addValve(valve); + + tomcat.start(); + + ByteChunk res = new ByteChunk(); + int rc = getUrl("http://localhost:" + getPort(), res, null); + + Assert.assertEquals(HttpServletResponse.SC_OK, rc); + Assert.assertEquals(HelloWorldServlet.RESPONSE_TEXT, res.toString()); + } + + @Test + public void testInterruptedConcurrency() throws Exception { + Tomcat tomcat = getTomcatInstance(); + + Context ctx = getProgrammaticRootContext(); + + Tomcat.addServlet(ctx, "hello", new HelloWorldServlet()); + ctx.addServletMappingDecoded("/", "hello"); + + SemaphoreValve valve = new SemaphoreValve(); + valve.setConcurrency(10); + valve.setInterruptible(true); + ctx.getPipeline().addValve(valve); + + tomcat.start(); + + ByteChunk res = new ByteChunk(); + int rc = getUrl("http://localhost:" + getPort(), res, null); + + Assert.assertEquals(HttpServletResponse.SC_OK, rc); + Assert.assertEquals(HelloWorldServlet.RESPONSE_TEXT, res.toString()); + } + + + @Test + public void testNonBlockingDenied() throws Exception { + Tomcat tomcat = getTomcatInstance(); + + Context ctx = getProgrammaticRootContext(); + + CountDownLatch insideServlet = new CountDownLatch(1); + CountDownLatch canReturn = new CountDownLatch(1); + Tomcat.addServlet(ctx, "slow", new SlowServlet(insideServlet, canReturn)); + ctx.addServletMappingDecoded("/", "slow"); + + SemaphoreValve valve = new SemaphoreValve(); + valve.setConcurrency(1); + valve.setBlock(false); + valve.setHighConcurrencyStatus(HttpServletResponse.SC_SERVICE_UNAVAILABLE); + ctx.getPipeline().addValve(valve); + + tomcat.start(); + + // First request — should acquire the permit and block inside the servlet + AtomicInteger firstRc = new AtomicInteger(); + Thread firstThread = new Thread(() -> { + try { + firstRc.set(getUrl("http://localhost:" + getPort(), new ByteChunk(), null)); + } catch (IOException e) { + // Ignore + } + }); + firstThread.start(); + + // Wait until the first request is inside the servlet + Assert.assertTrue("First request should reach servlet", + insideServlet.await(10, TimeUnit.SECONDS)); + + // Second request — should be denied because concurrency=1 and block=false + int rc2 = getUrl("http://localhost:" + getPort(), new ByteChunk(), null); + + Assert.assertEquals(HttpServletResponse.SC_SERVICE_UNAVAILABLE, rc2); + + // Release the first request + canReturn.countDown(); + firstThread.join(10000); + Assert.assertFalse(firstThread.isAlive()); + Assert.assertEquals(HttpServletResponse.SC_OK, firstRc.get()); + } + + + @Test + public void testHighConcurrencyStatusNotSet() throws Exception { + Tomcat tomcat = getTomcatInstance(); + + Context ctx = getProgrammaticRootContext(); + + CountDownLatch insideServlet = new CountDownLatch(1); + CountDownLatch canReturn = new CountDownLatch(1); + Tomcat.addServlet(ctx, "slow", new SlowServlet(insideServlet, canReturn)); + ctx.addServletMappingDecoded("/", "slow"); + + SemaphoreValve valve = new SemaphoreValve(); + valve.setConcurrency(1); + valve.setBlock(false); + // highConcurrencyStatus is -1 by default (no error sent) + ctx.getPipeline().addValve(valve); + + tomcat.start(); + + // First request holds the permit + Thread firstThread = new Thread(() -> { + try { + getUrl("http://localhost:" + getPort(), new ByteChunk(), null); + } catch (IOException e) { + // Ignore + } + }); + firstThread.start(); + + Assert.assertTrue("First request should reach servlet", + insideServlet.await(10, TimeUnit.SECONDS)); + + // Second request — denied but no error status is sent + int rc2 = getUrl("http://localhost:" + getPort(), new ByteChunk(), null); + + // With no highConcurrencyStatus, response is 200 without body + Assert.assertEquals(HttpServletResponse.SC_OK, rc2); + + canReturn.countDown(); + firstThread.join(10000); + } + + + @Test + public void testGetSetProperties() { + SemaphoreValve valve = new SemaphoreValve(); + + // Defaults + Assert.assertEquals(10, valve.getConcurrency()); + Assert.assertFalse(valve.getFairness()); + Assert.assertTrue(valve.getBlock()); + Assert.assertFalse(valve.getInterruptible()); + Assert.assertEquals(-1, valve.getHighConcurrencyStatus()); + + // Setters + valve.setConcurrency(5); + Assert.assertEquals(5, valve.getConcurrency()); + + valve.setFairness(true); + Assert.assertTrue(valve.getFairness()); + + valve.setBlock(false); + Assert.assertFalse(valve.getBlock()); + + valve.setInterruptible(true); + Assert.assertTrue(valve.getInterruptible()); + + valve.setHighConcurrencyStatus(HttpServletResponse.SC_TOO_MANY_REQUESTS); + Assert.assertEquals(HttpServletResponse.SC_TOO_MANY_REQUESTS, valve.getHighConcurrencyStatus()); + } + + + @Test + public void testFairSemaphore() throws Exception { + Tomcat tomcat = getTomcatInstance(); + + Context ctx = getProgrammaticRootContext(); + + Tomcat.addServlet(ctx, "hello", new HelloWorldServlet()); + ctx.addServletMappingDecoded("/", "hello"); + + SemaphoreValve valve = new SemaphoreValve(); + valve.setConcurrency(5); + valve.setFairness(true); + ctx.getPipeline().addValve(valve); + + tomcat.start(); + + Assert.assertNotNull(valve.semaphore); + Assert.assertTrue(valve.semaphore.isFair()); + Assert.assertEquals(5, valve.semaphore.availablePermits()); + + ByteChunk res = new ByteChunk(); + int rc = getUrl("http://localhost:" + getPort(), res, null); + + Assert.assertEquals(HttpServletResponse.SC_OK, rc); + Assert.assertEquals(HelloWorldServlet.RESPONSE_TEXT, res.toString()); + } + + @Test + public void testBlockingWaitsForPermit() throws Exception { + Tomcat tomcat = getTomcatInstance(); + + Context ctx = getProgrammaticRootContext(); + + CountDownLatch insideServlet = new CountDownLatch(1); + CountDownLatch canReturn = new CountDownLatch(1); + Tomcat.addServlet(ctx, "slow", new SlowServlet(insideServlet, canReturn)); + ctx.addServletMappingDecoded("/", "slow"); + + SemaphoreValve valve = new SemaphoreValve(); + valve.setConcurrency(1); + valve.setBlock(true); + ctx.getPipeline().addValve(valve); + + tomcat.start(); + + AtomicReference firstError = new AtomicReference<>(); + Thread firstThread = new Thread(() -> { + try { + getUrl("http://localhost:" + getPort(), new ByteChunk(), null); + } catch (IOException e) { + firstError.set(e); + } + }); + firstThread.start(); + + Assert.assertTrue("First request should reach servlet", + insideServlet.await(10, TimeUnit.SECONDS)); + + AtomicInteger secondRc = new AtomicInteger(); + AtomicReference secondError = new AtomicReference<>(); + Thread secondThread = new Thread(() -> { + try { + secondRc.set(getUrl("http://localhost:" + getPort(), new ByteChunk(), null)); + } catch (IOException e) { + secondError.set(e); + } + }); + secondThread.start(); + + // Give the second request time to arrive and block on the semaphore + Thread.sleep(500); + + Assert.assertTrue("Second request should be blocked waiting for permit", secondThread.isAlive()); + + canReturn.countDown(); + firstThread.join(10000); + Assert.assertNull(firstError.get()); + + secondThread.join(10000); + Assert.assertFalse(secondThread.isAlive()); + Assert.assertNull(secondError.get()); + Assert.assertEquals(HttpServletResponse.SC_OK, secondRc.get()); + } + + @Test + public void testControlConcurrencyBypass() throws Exception { + Tomcat tomcat = getTomcatInstance(); + + Context ctx = getProgrammaticRootContext(); + + CountDownLatch insideServlet = new CountDownLatch(1); + CountDownLatch canReturn = new CountDownLatch(1); + Tomcat.addServlet(ctx, "slow", new SlowServlet(insideServlet, canReturn)); + ctx.addServletMappingDecoded("/slow", "slow"); + + Tomcat.addServlet(ctx, "hello", new HelloWorldServlet()); + ctx.addServletMappingDecoded("/bypass", "hello"); + + SemaphoreValve valve = new SemaphoreValve() { + @Override + public boolean controlConcurrency(Request request, Response response) { + return !request.getDecodedRequestURI().equals("/bypass"); + } + }; + valve.setConcurrency(1); + valve.setBlock(false); + valve.setHighConcurrencyStatus(HttpServletResponse.SC_SERVICE_UNAVAILABLE); + ctx.getPipeline().addValve(valve); + + tomcat.start(); + + Thread firstThread = new Thread(() -> { + try { + getUrl("http://localhost:" + getPort() + "/slow", new ByteChunk(), null); + } catch (IOException e) { + // Ignored + } + }); + firstThread.start(); + + Assert.assertTrue("First request should reach servlet", + insideServlet.await(10, TimeUnit.SECONDS)); + + // Request to /bypass should succeed despite all permits being held, + // because controlConcurrency() returns false for this path + int bypassRc = getUrl("http://localhost:" + getPort() + "/bypass", new ByteChunk(), null); + Assert.assertEquals(HttpServletResponse.SC_OK, bypassRc); + + int deniedRc = getUrl("http://localhost:" + getPort() + "/slow", new ByteChunk(), null); + Assert.assertEquals(HttpServletResponse.SC_SERVICE_UNAVAILABLE, deniedRc); + + canReturn.countDown(); + firstThread.join(10000); + } + + @Test + public void testInterruptibleDenied() throws Exception { + SemaphoreValve semaphoreValve = new SemaphoreValve(); + semaphoreValve.setConcurrency(1); + semaphoreValve.setBlock(true); + semaphoreValve.setInterruptible(true); + semaphoreValve.setHighConcurrencyStatus(HttpServletResponse.SC_SERVICE_UNAVAILABLE); + + semaphoreValve.semaphore = new Semaphore(1, false); + + AtomicBoolean nextInvoked = new AtomicBoolean(false); + semaphoreValve.setNext(new ValveBase() { + @Override + public void invoke(Request request, Response response) { + nextInvoked.set(true); + } + }); + + MockResponse response = new MockResponse(); + + semaphoreValve.semaphore.acquire(); + + // On a new thread, valve will block on semaphore.acquire() because the permit is already held. + CountDownLatch invokeStarted = new CountDownLatch(1); + Thread blocked = new Thread(() -> { + invokeStarted.countDown(); + try { + semaphoreValve.invoke(null, response); + } catch (Throwable t) { + // Ignored + } + }); + blocked.start(); + + Assert.assertTrue(invokeStarted.await(10, TimeUnit.SECONDS)); + Thread.sleep(200); + + blocked.interrupt(); + blocked.join(10000); + Assert.assertFalse(blocked.isAlive()); + + Assert.assertEquals(HttpServletResponse.SC_SERVICE_UNAVAILABLE, response.getStatus()); + + Assert.assertFalse("Next valve should not be invoked when permit denied", nextInvoked.get()); + + Assert.assertEquals(0, semaphoreValve.semaphore.availablePermits()); + + semaphoreValve.semaphore.release(); + } + + private static final class SlowServlet extends HttpServlet { + + @Serial + private static final long serialVersionUID = 1L; + private final CountDownLatch insideServlet; + private final CountDownLatch canReturn; + + private SlowServlet(CountDownLatch insideServlet, CountDownLatch canReturn) { + this.insideServlet = insideServlet; + this.canReturn = canReturn; + } + + @Override + protected void doGet(HttpServletRequest req, HttpServletResponse resp) + throws IOException { + insideServlet.countDown(); + try { + Assert.assertTrue(canReturn.await(30, TimeUnit.SECONDS)); + } catch (InterruptedException e) { + // Ignore + } + resp.setContentType("text/plain"); + resp.getWriter().print("OK"); + } + } + + public static class MockResponse extends Response { + + public MockResponse() { + super(null); + } + + private int status = HttpServletResponse.SC_OK; + + @Override + public void sendError(int status) throws IOException { + this.status = status; + } + + @Override + public int getStatus() { + return status; + } + } + +} diff --git a/webapps/docs/config/valve.xml b/webapps/docs/config/valve.xml index 07947471f333..b32cea42acc1 100644 --- a/webapps/docs/config/valve.xml +++ b/webapps/docs/config/valve.xml @@ -2694,6 +2694,68 @@ +
+ + + +

The Filter Valve allows a Servlet Filter to be run as + part of the Valve pipeline. This enables reuse of existing Filter + implementations at the Valve level without duplicating their logic.

+ +

There are several caveats when using this Valve:

+
    +
  • A separate instance of the Filter class is created, distinct + from any instance that may be instantiated within a web application.
  • +
  • Calls to FilterConfig.getFilterName() will return + null.
  • +
  • FilterConfig.getServletContext() will return the proper + ServletContext for a Valve attached to a + <Context>, but will return a + ServletContext of limited use for a Valve specified on an + <Engine> or <Host>.
  • +
  • The Filter MUST NOT wrap the + ServletRequest or ServletResponse objects, or + an IllegalStateException will be thrown.
  • +
+ +
+ + + +

The Filter Valve supports the following + configuration attributes:

+ + + + +

Java class name of the implementation to use. This MUST be set to + org.apache.catalina.valves.FilterValve.

+
+ + +

The fully qualified class name of the Filter + implementation to use. The class must have a no-argument + constructor.

+
+ +
+ +

The Filter Valve also supports nested + <init-param> elements to pass initialization + parameters to the Filter:

+ + + + myParam + myValue + + ]]> + +
+ +
+