diff --git a/pyproject.toml b/pyproject.toml index 02b64e27..9c89e671 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,6 +14,7 @@ classifiers = [ "Operating System :: OS Independent" ] + [tool.poetry.dependencies] python = "^3.9" postgrest = "1.1.1" diff --git a/supabase/_async/client.py b/supabase/_async/client.py index 6992af06..5aab286b 100644 --- a/supabase/_async/client.py +++ b/supabase/_async/client.py @@ -323,19 +323,22 @@ def _get_auth_headers(self, authorization: Optional[str] = None) -> Dict[str, st "Authorization": authorization, } - def _listen_to_auth_events( - self, event: AuthChangeEvent, session: Optional[Session] - ): - access_token = self.supabase_key + async def _listen_to_auth_events(self, event: AuthChangeEvent, session: Optional[Session]): + original_auth = self._create_auth_header(self.supabase_key) + + if self.options.headers.get("Authorization") == original_auth: + return + + access_token = self.supabase_key # Default value to avoid unbound error + if event in ["SIGNED_IN", "TOKEN_REFRESHED", "SIGNED_OUT"]: - # reset postgrest and storage instance on event change self._postgrest = None self._storage = None self._functions = None access_token = session.access_token if session else self.supabase_key - self.options.headers["Authorization"] = self._create_auth_header(access_token) - asyncio.create_task(self.realtime.set_auth(access_token)) + self.options.headers["Authorization"] = self._create_auth_header(access_token) + await self.realtime.set_auth(access_token) async def create_client( supabase_url: str, diff --git a/supabase/_sync/client.py b/supabase/_sync/client.py index af2a841c..b6d5ac1d 100644 --- a/supabase/_sync/client.py +++ b/supabase/_sync/client.py @@ -322,12 +322,14 @@ def _get_auth_headers(self, authorization: Optional[str] = None) -> Dict[str, st "Authorization": authorization, } - def _listen_to_auth_events( - self, event: AuthChangeEvent, session: Optional[Session] - ): + def _listen_to_auth_events(self, event: AuthChangeEvent, session: Optional[Session]): + original_auth = self._create_auth_header(self.supabase_key) + + if self.options.headers.get("Authorization") == original_auth: + return + access_token = self.supabase_key if event in ["SIGNED_IN", "TOKEN_REFRESHED", "SIGNED_OUT"]: - # reset postgrest and storage instance on event change self._postgrest = None self._storage = None self._functions = None @@ -335,6 +337,7 @@ def _listen_to_auth_events( self.options.headers["Authorization"] = self._create_auth_header(access_token) + def create_client( supabase_url: str, supabase_key: str, diff --git a/tests/_async/test_auth_refresh_async.py b/tests/_async/test_auth_refresh_async.py new file mode 100644 index 00000000..b0eafa77 --- /dev/null +++ b/tests/_async/test_auth_refresh_async.py @@ -0,0 +1,42 @@ +import pytest +from supabase._async.client import create_client as create_async_client + +class DummySession: + def __init__(self, token): + self.access_token = token + +@pytest.mark.asyncio +async def test_async_auth_header_updates(monkeypatch): + client = await create_async_client("https://example.supabase.co", "svc-key") + + # Fake realtime.set_auth + called = {"token": None} + async def fake_set_auth(token): + called["token"] = token + + monkeypatch.setattr(client.realtime, "set_auth", fake_set_auth) + + client.options.headers["Authorization"] = "Bearer stale-token" + session = DummySession("refreshed-token") + + await client._listen_to_auth_events("TOKEN_REFRESHED", session) + + assert client.options.headers["Authorization"] == "Bearer refreshed-token" + assert called["token"] == "refreshed-token" + assert client._postgrest is None + assert client._storage is None + assert client._functions is None + + +@pytest.mark.asyncio +async def test_async_no_update_if_original_token_used(): + client = await create_async_client("https://example.supabase.co", "svc-role-key") + + original_auth = client._create_auth_header(client.supabase_key) + client.options.headers["Authorization"] = original_auth + + session = DummySession("some-other-token") + await client._listen_to_auth_events("TOKEN_REFRESHED", session) + + # Should skip the update logic + assert client.options.headers["Authorization"] == original_auth diff --git a/tests/_sync/test_auth_refresh_sync.py b/tests/_sync/test_auth_refresh_sync.py new file mode 100644 index 00000000..e514a842 --- /dev/null +++ b/tests/_sync/test_auth_refresh_sync.py @@ -0,0 +1,29 @@ +from supabase import create_client +import pytest + +class DummySession: + def __init__(self, access_token): + self.access_token = access_token + +def test_token_refresh_updates_header(monkeypatch): + supabase = create_client("https://test.supabase.co", "original-key") + + session = DummySession("new-token-123") + + # simulate stale token (pretend something else set a new one) + supabase.options.headers["Authorization"] = "Bearer old-token" + + supabase._listen_to_auth_events("TOKEN_REFRESHED", session) + + assert supabase.options.headers["Authorization"] == "Bearer new-token-123" + assert supabase._postgrest is None + assert supabase._storage is None + assert supabase._functions is None + +def test_no_update_when_auth_unchanged(): + supabase = create_client("https://test.supabase.co", "svc-role-key") + auth_header_before = supabase.options.headers["Authorization"] + supabase._listen_to_auth_events("SIGNED_IN", DummySession("new-token")) + auth_header_after = supabase.options.headers["Authorization"] + + assert auth_header_before == auth_header_after