From c44af72435e4bb5620dc294152d37722f7d894c4 Mon Sep 17 00:00:00 2001 From: Oleg A Date: Tue, 3 Dec 2024 01:43:23 +0300 Subject: [PATCH 1/4] feat: restrict methods --- aiohttp_sse/__init__.py | 24 ++++++++++++++++++++++-- tests/test_sse.py | 28 +++++++++++++++++++++++++++- 2 files changed, 49 insertions(+), 3 deletions(-) diff --git a/aiohttp_sse/__init__.py b/aiohttp_sse/__init__.py index ee09297..cfc8357 100644 --- a/aiohttp_sse/__init__.py +++ b/aiohttp_sse/__init__.py @@ -2,12 +2,13 @@ import io import re import sys -from collections.abc import Mapping +from collections.abc import Iterable, Mapping, Set from types import TracebackType from typing import Any, Optional, TypeVar, Union, overload from aiohttp.abc import AbstractStreamWriter from aiohttp.web import BaseRequest, ContentCoding, Request, StreamResponse +from aiohttp.web_exceptions import HTTPMethodNotAllowed from .helpers import _ContextManager @@ -32,6 +33,7 @@ async def hello(request): DEFAULT_SEPARATOR = "\r\n" DEFAULT_LAST_EVENT_HEADER = "Last-Event-Id" LINE_SEP_EXPR = re.compile(r"\r\n|\r|\n") + DEFAULT_ALLOWED_METHODS: Set[str] = frozenset(("GET",)) def __init__( self, @@ -40,6 +42,7 @@ def __init__( reason: Optional[str] = None, headers: Optional[Mapping[str, str]] = None, sep: Optional[str] = None, + allowed_methods: Optional[Iterable[str]] = None, ): super().__init__(status=status, reason=reason) @@ -55,6 +58,11 @@ def __init__( self._ping_interval: float = self.DEFAULT_PING_INTERVAL self._ping_task: Optional[asyncio.Task[None]] = None self._sep = sep if sep is not None else self.DEFAULT_SEPARATOR + self._allowed_methods = ( + frozenset(allowed_methods) + if allowed_methods is not None + else self.DEFAULT_ALLOWED_METHODS + ) def is_connected(self) -> bool: """Check connection is prepared and ping task is not done.""" @@ -73,6 +81,9 @@ async def prepare(self, request: BaseRequest) -> Optional[AbstractStreamWriter]: :param request: regular aiohttp.web.Request. """ + if request.method not in self._allowed_methods: + raise HTTPMethodNotAllowed(request.method, self._allowed_methods) + if not self.prepared: writer = await super().prepare(request) self._ping_task = asyncio.create_task(self._ping()) @@ -234,6 +245,7 @@ def sse_response( reason: Optional[str] = None, headers: Optional[Mapping[str, str]] = None, sep: Optional[str] = None, + allowed_methods: Optional[Iterable[str]] = None, ) -> _ContextManager[EventSourceResponse]: ... @@ -246,6 +258,7 @@ def sse_response( headers: Optional[Mapping[str, str]] = None, sep: Optional[str] = None, response_cls: type[ESR], + allowed_methods: Optional[Iterable[str]] = None, ) -> _ContextManager[ESR]: ... @@ -257,6 +270,7 @@ def sse_response( headers: Optional[Mapping[str, str]] = None, sep: Optional[str] = None, response_cls: type[EventSourceResponse] = EventSourceResponse, + allowed_methods: Optional[Iterable[str]] = None, ) -> Any: if not issubclass(response_cls, EventSourceResponse): raise TypeError( @@ -264,5 +278,11 @@ def sse_response( "aiohttp_sse.EventSourceResponse, got {}".format(response_cls) ) - sse = response_cls(status=status, reason=reason, headers=headers, sep=sep) + sse = response_cls( + status=status, + reason=reason, + headers=headers, + sep=sep, + allowed_methods=allowed_methods, + ) return _ContextManager(sse._prepare(request)) diff --git a/tests/test_sse.py b/tests/test_sse.py index f46909b..ae9dee6 100644 --- a/tests/test_sse.py +++ b/tests/test_sse.py @@ -510,7 +510,10 @@ async def test_get_before_prepare(self) -> None: ) async def test_http_methods(aiohttp_client: AiohttpClient, http_method: str) -> None: async def handler(request: web.Request) -> EventSourceResponse: - async with sse_response(request) as sse: + async with sse_response( + request=request, + allowed_methods=("GET", "POST", "PUT", "DELETE", "PATCH"), + ) as sse: await sse.send("foo") return sse @@ -526,6 +529,29 @@ async def handler(request: web.Request) -> EventSourceResponse: assert streamed_data == "data: foo\r\n\r\n" +@pytest.mark.parametrize( + "http_method", + ("POST", "PUT", "DELETE", "PATCH"), +) +async def test_not_allowed_methods( + aiohttp_client: AiohttpClient, + http_method: str, +) -> None: + """Check that EventSourceResponse works only with GET method.""" + + async def handler(request: web.Request) -> EventSourceResponse: + async with sse_response(request) as sse: + ... + return sse # pragma: no cover + + app = web.Application() + app.router.add_route(http_method, "/", handler) + + client = await aiohttp_client(app) + async with client.request(http_method, "/") as resp: + assert resp.status == 405 + + @pytest.mark.skipif( sys.version_info < (3, 11), reason=".cancelling() missing in older versions", From f598e4bd0cd34f9d2881c005762cc2de3ae42cf1 Mon Sep 17 00:00:00 2001 From: Oleg A Date: Tue, 3 Dec 2024 09:36:37 +0300 Subject: [PATCH 2/4] feat: restrict methods (via bool) --- aiohttp_sse/__init__.py | 23 +++++++++-------------- tests/test_sse.py | 5 +---- 2 files changed, 10 insertions(+), 18 deletions(-) diff --git a/aiohttp_sse/__init__.py b/aiohttp_sse/__init__.py index cfc8357..9f1cd82 100644 --- a/aiohttp_sse/__init__.py +++ b/aiohttp_sse/__init__.py @@ -2,7 +2,7 @@ import io import re import sys -from collections.abc import Iterable, Mapping, Set +from collections.abc import Mapping from types import TracebackType from typing import Any, Optional, TypeVar, Union, overload @@ -33,7 +33,6 @@ async def hello(request): DEFAULT_SEPARATOR = "\r\n" DEFAULT_LAST_EVENT_HEADER = "Last-Event-Id" LINE_SEP_EXPR = re.compile(r"\r\n|\r|\n") - DEFAULT_ALLOWED_METHODS: Set[str] = frozenset(("GET",)) def __init__( self, @@ -42,7 +41,7 @@ def __init__( reason: Optional[str] = None, headers: Optional[Mapping[str, str]] = None, sep: Optional[str] = None, - allowed_methods: Optional[Iterable[str]] = None, + allow_all_methods: bool = False, ): super().__init__(status=status, reason=reason) @@ -58,11 +57,7 @@ def __init__( self._ping_interval: float = self.DEFAULT_PING_INTERVAL self._ping_task: Optional[asyncio.Task[None]] = None self._sep = sep if sep is not None else self.DEFAULT_SEPARATOR - self._allowed_methods = ( - frozenset(allowed_methods) - if allowed_methods is not None - else self.DEFAULT_ALLOWED_METHODS - ) + self._allow_all_methods = allow_all_methods def is_connected(self) -> bool: """Check connection is prepared and ping task is not done.""" @@ -81,8 +76,8 @@ async def prepare(self, request: BaseRequest) -> Optional[AbstractStreamWriter]: :param request: regular aiohttp.web.Request. """ - if request.method not in self._allowed_methods: - raise HTTPMethodNotAllowed(request.method, self._allowed_methods) + if not self._allow_all_methods and request.method != "GET": + raise HTTPMethodNotAllowed(request.method, ["GET"]) if not self.prepared: writer = await super().prepare(request) @@ -245,7 +240,7 @@ def sse_response( reason: Optional[str] = None, headers: Optional[Mapping[str, str]] = None, sep: Optional[str] = None, - allowed_methods: Optional[Iterable[str]] = None, + allow_all_methods: bool = False, ) -> _ContextManager[EventSourceResponse]: ... @@ -258,7 +253,7 @@ def sse_response( headers: Optional[Mapping[str, str]] = None, sep: Optional[str] = None, response_cls: type[ESR], - allowed_methods: Optional[Iterable[str]] = None, + allow_all_methods: bool = False, ) -> _ContextManager[ESR]: ... @@ -270,7 +265,7 @@ def sse_response( headers: Optional[Mapping[str, str]] = None, sep: Optional[str] = None, response_cls: type[EventSourceResponse] = EventSourceResponse, - allowed_methods: Optional[Iterable[str]] = None, + allow_all_methods: bool = False, ) -> Any: if not issubclass(response_cls, EventSourceResponse): raise TypeError( @@ -283,6 +278,6 @@ def sse_response( reason=reason, headers=headers, sep=sep, - allowed_methods=allowed_methods, + allow_all_methods=allow_all_methods, ) return _ContextManager(sse._prepare(request)) diff --git a/tests/test_sse.py b/tests/test_sse.py index ae9dee6..c2a25f3 100644 --- a/tests/test_sse.py +++ b/tests/test_sse.py @@ -510,10 +510,7 @@ async def test_get_before_prepare(self) -> None: ) async def test_http_methods(aiohttp_client: AiohttpClient, http_method: str) -> None: async def handler(request: web.Request) -> EventSourceResponse: - async with sse_response( - request=request, - allowed_methods=("GET", "POST", "PUT", "DELETE", "PATCH"), - ) as sse: + async with sse_response(request, allow_all_methods=True) as sse: await sse.send("foo") return sse From 1987ce9cfb191edf6efbd6dd2e58dfeab866b140 Mon Sep 17 00:00:00 2001 From: Sam Bull Date: Wed, 4 Dec 2024 13:23:19 +0000 Subject: [PATCH 3/4] Bump version --- aiohttp_sse/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aiohttp_sse/__init__.py b/aiohttp_sse/__init__.py index 9f1cd82..7d3a204 100644 --- a/aiohttp_sse/__init__.py +++ b/aiohttp_sse/__init__.py @@ -12,7 +12,7 @@ from .helpers import _ContextManager -__version__ = "2.2.0" +__version__ = "3.0" __all__ = ["EventSourceResponse", "sse_response"] From d88dd0336effe1974b975fcc3fb2b500abde3bd4 Mon Sep 17 00:00:00 2001 From: Sam Bull Date: Wed, 4 Dec 2024 13:27:12 +0000 Subject: [PATCH 4/4] Update CHANGES.rst --- CHANGES.rst | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/CHANGES.rst b/CHANGES.rst index 01b6c65..30279c0 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -4,6 +4,14 @@ CHANGES .. towncrier release notes start +3.0 (2025-xx-xx) +================ + +- Disallow HTTP methods other than GET by default. + Allowing other methods can be done by passing ``allow_all_methods`` to ``EventSourceResponse``. + Note that SSE on browsers will only work with GET endpoints, this flag just serves as a + reminder. + 2.2.0 (2024-02-29) ==================