From 643d0832698d5306c01927add7b4aa34da1c457d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Joaqu=C3=ADn=20Ossandon?= <30879716+cacosandon@users.noreply.github.com> Date: Wed, 4 Sep 2024 19:02:43 -0400 Subject: [PATCH] Ensure text message exists before handling on `WebsocketConsumer` (#2097) * fix(channels/generic): ensure text message exists before deciding to handle * tests(channels/generic): regression test for double check of text message None * refactor(channels/generic): short condition * lint: fix flake8 errors --- channels/generic/websocket.py | 4 +-- tests/test_generic_websocket.py | 51 +++++++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 2 deletions(-) diff --git a/channels/generic/websocket.py b/channels/generic/websocket.py index 6d41c8ee..b4d99119 100644 --- a/channels/generic/websocket.py +++ b/channels/generic/websocket.py @@ -60,7 +60,7 @@ def websocket_receive(self, message): Called when a WebSocket frame is received. Decodes it and passes it to receive(). """ - if "text" in message: + if message.get("text") is not None: self.receive(text_data=message["text"]) else: self.receive(bytes_data=message["bytes"]) @@ -200,7 +200,7 @@ async def websocket_receive(self, message): Called when a WebSocket frame is received. Decodes it and passes it to receive(). """ - if "text" in message: + if message.get("text") is not None: await self.receive(text_data=message["text"]) else: await self.receive(bytes_data=message["bytes"]) diff --git a/tests/test_generic_websocket.py b/tests/test_generic_websocket.py index c553eb84..0ade1a02 100644 --- a/tests/test_generic_websocket.py +++ b/tests/test_generic_websocket.py @@ -485,3 +485,54 @@ async def connect(self): assert msg["type"] == "websocket.close" assert msg["code"] == 4007 assert msg["reason"] == "test reason" + + +@pytest.mark.django_db +@pytest.mark.asyncio +async def test_websocket_receive_with_none_text(): + """ + Tests that the receive method handles messages with None text data correctly. + """ + + class TestConsumer(WebsocketConsumer): + def receive(self, text_data=None, bytes_data=None): + if text_data: + self.send(text_data="Received text: " + text_data) + elif bytes_data: + self.send(text_data=f"Received bytes of length: {len(bytes_data)}") + + app = TestConsumer() + + # Open a connection + communicator = WebsocketCommunicator(app, "/testws/") + connected, _ = await communicator.connect() + assert connected + + # Simulate Hypercorn behavior + # (both 'text' and 'bytes' keys present, but 'text' is None) + await communicator.send_input( + { + "type": "websocket.receive", + "text": None, + "bytes": b"test data", + } + ) + response = await communicator.receive_output() + assert response["type"] == "websocket.send" + assert response["text"] == "Received bytes of length: 9" + + # Test with only 'bytes' key (simulating uvicorn/daphne behavior) + await communicator.send_input({"type": "websocket.receive", "bytes": b"more data"}) + response = await communicator.receive_output() + assert response["type"] == "websocket.send" + assert response["text"] == "Received bytes of length: 9" + + # Test with valid text data + await communicator.send_input( + {"type": "websocket.receive", "text": "Hello, world!"} + ) + response = await communicator.receive_output() + assert response["type"] == "websocket.send" + assert response["text"] == "Received text: Hello, world!" + + await communicator.disconnect()